@@ -151,10 +151,6 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
151
151
auto b = TritonLLVMOpBuilder (loc, rewriter);
152
152
assert (layout.getNumInDims () == indices.size ());
153
153
assert (llvm::equal (layout.getInDimNames (), llvm::make_first_range (indices)));
154
- // Trivial layout
155
- if (layout.getNumOutDims () == 0 ) {
156
- return {};
157
- }
158
154
159
155
// This function can emit a lot of MLIR code, which ultimately makes
160
156
// compilation slow. (We think this shouldn't be the case -- it's not *that*
@@ -168,65 +164,62 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
168
164
SmallVector<std::pair<StringAttr, int32_t >> constantIns;
169
165
SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
170
166
for (auto [inDimName, idx] : indices) {
171
- APInt constant;
172
- if ( matchPattern (idx, m_ConstantInt (&constant))) {
173
- constantIns. push_back ( {inDimName, constant.getSExtValue ()});
167
+ if ( auto constant = idx. getDefiningOp <LLVM::ConstantOp>()) {
168
+ constantIns. push_back (
169
+ {inDimName, cast<IntegerAttr>( constant.getValue ()). getInt ()});
174
170
} else {
175
171
constantIns.push_back ({inDimName, 0 });
176
172
nonConstantIns.push_back ({inDimName, idx});
177
173
}
178
174
}
175
+ SmallVector<int32_t > constantComponent =
176
+ llvm::to_vector (llvm::make_second_range (layout.apply (constantIns)));
179
177
180
- // Compute constant part of the output and wrap it as values
181
178
Value zero = b.i32_val (0 );
182
179
SmallVector<std::pair<StringAttr, Value>> outIndices;
183
- for (auto [outDimName, constant ] : layout.apply (constantIns )) {
184
- if (constant == 0 )
180
+ for (auto [i, outDimName ] : llvm::enumerate ( layout.getOutDimNames () )) {
181
+ if (constantComponent[i] == 0 )
185
182
outIndices.push_back ({outDimName, zero});
186
183
else
187
- outIndices.push_back ({outDimName, b.i32_val (constant)});
188
- }
189
-
190
- if (nonConstantIns.size () == 0 ) {
191
- return outIndices;
184
+ outIndices.push_back ({outDimName, b.i32_val (constantComponent[i])});
192
185
}
193
-
194
- // Concatenate input
195
- Value x = b.i32_val (0 );
196
- if (nonConstantIns.size () == 1 ) {
197
- x = nonConstantIns[0 ].second ;
198
- } else {
186
+ // Happy path: Only one output.
187
+ if (outIndices.size () == 1 ) {
188
+ SmallVector<StringAttr> inDimNames;
189
+ // Concatenate input
190
+ Value x = b.i32_val (0 );
199
191
int shift = 0 ;
200
192
for (auto [inDimName, idx] : nonConstantIns) {
193
+ inDimNames.push_back (inDimName);
201
194
x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
202
195
shift += layout.getInDimSizeLog2 (inDimName);
203
196
}
197
+ // Flatten ins
198
+ auto matrix = layout.sublayout (inDimNames, outIndices[0 ].first );
199
+ matrix = matrix.flattenIns ();
200
+ auto out = triton::gpu::matrixVectorProd (b, matrix, x);
201
+ outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
202
+ return outIndices;
204
203
}
205
204
206
- // Remove constant input dims from the layout and flatten it
207
- auto inDimNames = llvm::to_vector (llvm::make_first_range (nonConstantIns));
208
- auto matrix = layout.sublayout (
209
- inDimNames, llvm::to_vector (llvm::make_first_range (outIndices)));
210
- auto flatMatrix = matrix.flattenIns ().flattenOuts ();
211
-
212
- // Lower the matrix-vector product
213
- auto out = triton::gpu::matrixVectorProd (b, flatMatrix, x);
205
+ for (auto [inDimName, idx] : indices) {
206
+ if (idx.getDefiningOp <LLVM::ConstantOp>()) {
207
+ continue ;
208
+ }
214
209
215
- // Unpack the output
216
- if (matrix.getNumOutDims () == 1 ) {
217
- outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
218
- } else {
219
- assert (llvm::equal (matrix.getOutDimNames (),
220
- llvm::make_first_range (outIndices)));
221
- int shift = 0 ;
222
- for (auto &[dimName, outIdx] : outIndices) {
223
- auto outDimSizeLog2 = layout.getOutDimSizeLog2 (dimName);
224
- auto mask = (1 << outDimSizeLog2) - 1 ;
225
- outIdx = b.xor_ (outIdx,
226
- b.and_ (b.lshr (out, b.i32_val (shift)), b.i32_val (mask)));
227
- shift += outDimSizeLog2;
210
+ int nBits = layout.getInDimSizeLog2 (inDimName);
211
+ for (int i = 0 ; i < nBits; i++) {
212
+ Value bit = b.and_ (idx, b.i32_val (1 << i));
213
+ Value bit_is_zero = b.icmp_eq (bit, zero);
214
+ for (auto &[outDimName, outIdx] : outIndices) {
215
+ int32_t basis = layout.getBasis (inDimName, i, outDimName);
216
+ if (basis == 0 )
217
+ continue ;
218
+ outIdx = b.xor_ (outIdx, b.select (bit_is_zero, zero, b.i32_val (basis)));
219
+ }
228
220
}
229
221
}
222
+
230
223
return outIndices;
231
224
}
232
225
0 commit comments