@@ -236,41 +236,21 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
236236 return outIndices;
237237 }
238238
239- // Happy path: Only one output.
240- if (outIndices.size () == 1 ) {
241- SmallVector<StringAttr> inDimNames;
242- // Concatenate input
243- Value x = b.i32_val (0 );
244- int shift = 0 ;
245- for (auto [inDimName, idx] : nonConstantIns) {
246- inDimNames.push_back (inDimName);
247- x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
248- shift += layout.getInDimSizeLog2 (inDimName);
249- }
250- // Flatten ins
251- auto matrix = layout.sublayout (inDimNames, outIndices[0 ].first );
252- matrix = matrix.flattenIns ();
239+ SmallVector<StringAttr> inDimNames;
240+ // Concatenate input
241+ Value x = b.i32_val (0 );
242+ int shift = 0 ;
243+ for (auto [inDimName, idx] : nonConstantIns) {
244+ inDimNames.push_back (inDimName);
245+ x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
246+ shift += layout.getInDimSizeLog2 (inDimName);
247+ }
248+
249+ for (auto &[outDimName, outIdx] : outIndices) {
250+ // Apply flattened sublayout for this output
251+ auto matrix = layout.sublayout (inDimNames, outDimName).flattenIns ();
253252 auto out = triton::gpu::matrixVectorProd (b, matrix, x);
254- outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
255- return outIndices;
256- }
257-
258- for (auto [inDimName, idx] : indices) {
259- APInt constant;
260- if (matchPattern (idx, m_ConstantInt (&constant))) {
261- continue ;
262- }
263- int nBits = layout.getInDimSizeLog2 (inDimName);
264- for (int i = 0 ; i < nBits; i++) {
265- Value bit = b.and_ (idx, b.i32_val (1 << i));
266- Value bit_is_zero = b.icmp_eq (bit, zero);
267- for (auto &[outDimName, outIdx] : outIndices) {
268- int32_t basis = layout.getBasis (inDimName, i, outDimName);
269- if (basis == 0 )
270- continue ;
271- outIdx = b.xor_ (outIdx, b.select (bit_is_zero, zero, b.i32_val (basis)));
272- }
273- }
253+ outIdx = b.xor_ (outIdx, out);
274254 }
275255
276256 return outIndices;
0 commit comments