Skip to content

Commit daca739

Browse files
Adjust LIT tests for SimplifyApplyLinearLayout changes (#4771)
Fixes #4662. Reland 47297f5 and adjust LIT tests accordingly. The following tests have been updated: TRITON :: TritonIntelGPU/blockptr_store.mlir; TRITON :: TritonIntelGPU/tritonintlgpu-nested-layout.mlir; TRITON :: Conversion/intel/dot_layout_offset.mlir; TRITON :: Conversion/intel/dpas_to_block_layout_convert.mlir.
2 parents 6bc75fc + 089c4bf commit daca739

File tree

5 files changed

+293
-303
lines changed

5 files changed

+293
-303
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -236,41 +236,21 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
236236
return outIndices;
237237
}
238238

239-
// Happy path: Only one output.
240-
if (outIndices.size() == 1) {
241-
SmallVector<StringAttr> inDimNames;
242-
// Concatenate input
243-
Value x = b.i32_val(0);
244-
int shift = 0;
245-
for (auto [inDimName, idx] : nonConstantIns) {
246-
inDimNames.push_back(inDimName);
247-
x = b.or_(x, b.shl(idx, b.i32_val(shift)));
248-
shift += layout.getInDimSizeLog2(inDimName);
249-
}
250-
// Flatten ins
251-
auto matrix = layout.sublayout(inDimNames, outIndices[0].first);
252-
matrix = matrix.flattenIns();
239+
SmallVector<StringAttr> inDimNames;
240+
// Concatenate input
241+
Value x = b.i32_val(0);
242+
int shift = 0;
243+
for (auto [inDimName, idx] : nonConstantIns) {
244+
inDimNames.push_back(inDimName);
245+
x = b.or_(x, b.shl(idx, b.i32_val(shift)));
246+
shift += layout.getInDimSizeLog2(inDimName);
247+
}
248+
249+
for (auto &[outDimName, outIdx] : outIndices) {
250+
// Apply flattened sublayout for this output
251+
auto matrix = layout.sublayout(inDimNames, outDimName).flattenIns();
253252
auto out = triton::gpu::matrixVectorProd(b, matrix, x);
254-
outIndices[0].second = b.xor_(outIndices[0].second, out);
255-
return outIndices;
256-
}
257-
258-
for (auto [inDimName, idx] : indices) {
259-
APInt constant;
260-
if (matchPattern(idx, m_ConstantInt(&constant))) {
261-
continue;
262-
}
263-
int nBits = layout.getInDimSizeLog2(inDimName);
264-
for (int i = 0; i < nBits; i++) {
265-
Value bit = b.and_(idx, b.i32_val(1 << i));
266-
Value bit_is_zero = b.icmp_eq(bit, zero);
267-
for (auto &[outDimName, outIdx] : outIndices) {
268-
int32_t basis = layout.getBasis(inDimName, i, outDimName);
269-
if (basis == 0)
270-
continue;
271-
outIdx = b.xor_(outIdx, b.select(bit_is_zero, zero, b.i32_val(basis)));
272-
}
273-
}
253+
outIdx = b.xor_(outIdx, out);
274254
}
275255

276256
return outIndices;

0 commit comments

Comments
 (0)