@@ -236,41 +236,21 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
236
236
return outIndices;
237
237
}
238
238
239
- // Happy path: Only one output.
240
- if (outIndices.size () == 1 ) {
241
- SmallVector<StringAttr> inDimNames;
242
- // Concatenate input
243
- Value x = b.i32_val (0 );
244
- int shift = 0 ;
245
- for (auto [inDimName, idx] : nonConstantIns) {
246
- inDimNames.push_back (inDimName);
247
- x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
248
- shift += layout.getInDimSizeLog2 (inDimName);
249
- }
250
- // Flatten ins
251
- auto matrix = layout.sublayout (inDimNames, outIndices[0 ].first );
252
- matrix = matrix.flattenIns ();
239
+ SmallVector<StringAttr> inDimNames;
240
+ // Concatenate input
241
+ Value x = b.i32_val (0 );
242
+ int shift = 0 ;
243
+ for (auto [inDimName, idx] : nonConstantIns) {
244
+ inDimNames.push_back (inDimName);
245
+ x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
246
+ shift += layout.getInDimSizeLog2 (inDimName);
247
+ }
248
+
249
+ for (auto &[outDimName, outIdx] : outIndices) {
250
+ // Apply flattened sublayout for this output
251
+ auto matrix = layout.sublayout (inDimNames, outDimName).flattenIns ();
253
252
auto out = triton::gpu::matrixVectorProd (b, matrix, x);
254
- outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
255
- return outIndices;
256
- }
257
-
258
- for (auto [inDimName, idx] : indices) {
259
- APInt constant;
260
- if (matchPattern (idx, m_ConstantInt (&constant))) {
261
- continue ;
262
- }
263
- int nBits = layout.getInDimSizeLog2 (inDimName);
264
- for (int i = 0 ; i < nBits; i++) {
265
- Value bit = b.and_ (idx, b.i32_val (1 << i));
266
- Value bit_is_zero = b.icmp_eq (bit, zero);
267
- for (auto &[outDimName, outIdx] : outIndices) {
268
- int32_t basis = layout.getBasis (inDimName, i, outDimName);
269
- if (basis == 0 )
270
- continue ;
271
- outIdx = b.xor_ (outIdx, b.select (bit_is_zero, zero, b.i32_val (basis)));
272
- }
273
- }
253
+ outIdx = b.xor_ (outIdx, out);
274
254
}
275
255
276
256
return outIndices;
0 commit comments