@@ -151,6 +151,10 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
151
151
auto b = TritonLLVMOpBuilder (loc, rewriter);
152
152
assert (layout.getNumInDims () == indices.size ());
153
153
assert (llvm::equal (layout.getInDimNames (), llvm::make_first_range (indices)));
154
+ // Trivial layout: with no output dimensions there is nothing to compute
155
+ if (layout.getNumOutDims () == 0 ) {
156
+ return {};
157
+ }
154
158
155
159
// This function can emit a lot of MLIR code, which ultimately makes
156
160
// compilation slow. (We think this shouldn't be the case -- it's not *that*
@@ -164,62 +168,65 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
164
168
SmallVector<std::pair<StringAttr, int32_t >> constantIns;
165
169
SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
166
170
for (auto [inDimName, idx] : indices) {
167
- if ( auto constant = idx. getDefiningOp <LLVM::ConstantOp>()) {
168
- constantIns. push_back (
169
- {inDimName, cast<IntegerAttr>( constant.getValue ()). getInt ()});
171
+ APInt constant;
172
+ if ( matchPattern (idx, m_ConstantInt (&constant))) {
173
+ constantIns. push_back ( {inDimName, constant.getSExtValue ()});
170
174
} else {
171
175
constantIns.push_back ({inDimName, 0 });
172
176
nonConstantIns.push_back ({inDimName, idx});
173
177
}
174
178
}
175
- SmallVector<int32_t > constantComponent =
176
- llvm::to_vector (llvm::make_second_range (layout.apply (constantIns)));
177
179
180
+ // Compute constant part of the output and wrap it as values
178
181
Value zero = b.i32_val (0 );
179
182
SmallVector<std::pair<StringAttr, Value>> outIndices;
180
- for (auto [i, outDimName ] : llvm::enumerate ( layout.getOutDimNames () )) {
181
- if (constantComponent[i] == 0 )
183
+ for (auto [outDimName, constant ] : layout.apply (constantIns )) {
184
+ if (constant == 0 )
182
185
outIndices.push_back ({outDimName, zero});
183
186
else
184
- outIndices.push_back ({outDimName, b.i32_val (constantComponent[i])});
187
+ outIndices.push_back ({outDimName, b.i32_val (constant)});
188
+ }
189
+
190
+ if (nonConstantIns.size () == 0 ) {
191
+ return outIndices;
185
192
}
186
- // Happy path: Only one output.
187
- if (outIndices.size () == 1 ) {
188
- SmallVector<StringAttr> inDimNames;
189
- // Concatenate input
190
- Value x = b.i32_val (0 );
193
+
194
+ // Concatenate the non-constant input indices into a single value
195
+ Value x = b.i32_val (0 );
196
+ if (nonConstantIns.size () == 1 ) {
197
+ x = nonConstantIns[0 ].second ;
198
+ } else {
191
199
int shift = 0 ;
192
200
for (auto [inDimName, idx] : nonConstantIns) {
193
- inDimNames.push_back (inDimName);
194
201
x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
195
202
shift += layout.getInDimSizeLog2 (inDimName);
196
203
}
197
- // Flatten ins
198
- auto matrix = layout.sublayout (inDimNames, outIndices[0 ].first );
199
- matrix = matrix.flattenIns ();
200
- auto out = triton::gpu::matrixVectorProd (b, matrix, x);
201
- outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
202
- return outIndices;
203
204
}
204
205
205
- for (auto [inDimName, idx] : indices) {
206
- if (idx.getDefiningOp <LLVM::ConstantOp>()) {
207
- continue ;
208
- }
206
+ // Remove constant input dims from the layout and flatten it
207
+ auto inDimNames = llvm::to_vector (llvm::make_first_range (nonConstantIns));
208
+ auto matrix = layout.sublayout (
209
+ inDimNames, llvm::to_vector (llvm::make_first_range (outIndices)));
210
+ auto flatMatrix = matrix.flattenIns ().flattenOuts ();
211
+
212
+ // Lower the matrix-vector product
213
+ auto out = triton::gpu::matrixVectorProd (b, flatMatrix, x);
209
214
210
- int nBits = layout.getInDimSizeLog2 (inDimName);
211
- for (int i = 0 ; i < nBits; i++) {
212
- Value bit = b.and_ (idx, b.i32_val (1 << i));
213
- Value bit_is_zero = b.icmp_eq (bit, zero);
214
- for (auto &[outDimName, outIdx] : outIndices) {
215
- int32_t basis = layout.getBasis (inDimName, i, outDimName);
216
- if (basis == 0 )
217
- continue ;
218
- outIdx = b.xor_ (outIdx, b.select (bit_is_zero, zero, b.i32_val (basis)));
219
- }
215
+ // Unpack the output
216
+ if (matrix.getNumOutDims () == 1 ) {
217
+ outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
218
+ } else {
219
+ assert (llvm::equal (matrix.getOutDimNames (),
220
+ llvm::make_first_range (outIndices)));
221
+ int shift = 0 ;
222
+ for (auto &[dimName, outIdx] : outIndices) {
223
+ auto outDimSizeLog2 = layout.getOutDimSizeLog2 (dimName);
224
+ auto mask = (1 << outDimSizeLog2) - 1 ;
225
+ outIdx = b.xor_ (outIdx,
226
+ b.and_ (b.lshr (out, b.i32_val (shift)), b.i32_val (mask)));
227
+ shift += outDimSizeLog2;
220
228
}
221
229
}
222
-
223
230
return outIndices;
224
231
}
225
232
0 commit comments