Skip to content

Commit c8a711d

Browse files
authored
Partially Revert "[LAYOUTS] Enable diagonal iteration unconditionally (#7218)" (#7245)
We are seeing some internal regressions. This partially reverts commit 336cc1d.
1 parent dd1c3d4 commit c8a711d

File tree

1 file changed

+35
-42
lines changed

1 file changed

+35
-42
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 35 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -152,10 +152,6 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
152152
auto b = TritonLLVMOpBuilder(loc, rewriter);
153153
assert(layout.getNumInDims() == indices.size());
154154
assert(llvm::equal(layout.getInDimNames(), llvm::make_first_range(indices)));
155-
// Trivial layout
156-
if (layout.getNumOutDims() == 0) {
157-
return {};
158-
}
159155

160156
// This function can emit a lot of MLIR code, which ultimately makes
161157
// compilation slow. (We think this shouldn't be the case -- it's not *that*
@@ -169,65 +165,62 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
169165
SmallVector<std::pair<StringAttr, int32_t>> constantIns;
170166
SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
171167
for (auto [inDimName, idx] : indices) {
172-
APInt constant;
173-
if (matchPattern(idx, m_ConstantInt(&constant))) {
174-
constantIns.push_back({inDimName, constant.getSExtValue()});
168+
if (auto constant = idx.getDefiningOp<LLVM::ConstantOp>()) {
169+
constantIns.push_back(
170+
{inDimName, cast<IntegerAttr>(constant.getValue()).getInt()});
175171
} else {
176172
constantIns.push_back({inDimName, 0});
177173
nonConstantIns.push_back({inDimName, idx});
178174
}
179175
}
176+
SmallVector<int32_t> constantComponent =
177+
llvm::to_vector(llvm::make_second_range(layout.apply(constantIns)));
180178

181-
// Compute constant part of the output and wrap it as values
182179
Value zero = b.i32_val(0);
183180
SmallVector<std::pair<StringAttr, Value>> outIndices;
184-
for (auto [outDimName, constant] : layout.apply(constantIns)) {
185-
if (constant == 0)
181+
for (auto [i, outDimName] : llvm::enumerate(layout.getOutDimNames())) {
182+
if (constantComponent[i] == 0)
186183
outIndices.push_back({outDimName, zero});
187184
else
188-
outIndices.push_back({outDimName, b.i32_val(constant)});
189-
}
190-
191-
if (nonConstantIns.size() == 0) {
192-
return outIndices;
185+
outIndices.push_back({outDimName, b.i32_val(constantComponent[i])});
193186
}
194-
195-
// Concatenate input
196-
Value x = b.i32_val(0);
197-
if (nonConstantIns.size() == 1) {
198-
x = nonConstantIns[0].second;
199-
} else {
187+
// Happy path: Only one output.
188+
if (outIndices.size() == 1) {
189+
SmallVector<StringAttr> inDimNames;
190+
// Concatenate input
191+
Value x = b.i32_val(0);
200192
int shift = 0;
201193
for (auto [inDimName, idx] : nonConstantIns) {
194+
inDimNames.push_back(inDimName);
202195
x = b.or_(x, b.shl(idx, b.i32_val(shift)));
203196
shift += layout.getInDimSizeLog2(inDimName);
204197
}
198+
// Flatten ins
199+
auto matrix = layout.sublayout(inDimNames, outIndices[0].first);
200+
matrix = matrix.flattenIns();
201+
auto out = triton::gpu::matrixVectorProd(b, matrix, x);
202+
outIndices[0].second = b.xor_(outIndices[0].second, out);
203+
return outIndices;
205204
}
206205

207-
// Remove constant input dims from the layout and flatten it
208-
auto inDimNames = llvm::to_vector(llvm::make_first_range(nonConstantIns));
209-
auto matrix = layout.sublayout(
210-
inDimNames, llvm::to_vector(llvm::make_first_range(outIndices)));
211-
auto flatMatrix = matrix.flattenIns().flattenOuts();
212-
213-
// Lower the matrix-vector product
214-
auto out = triton::gpu::matrixVectorProd(b, flatMatrix, x);
206+
for (auto [inDimName, idx] : indices) {
207+
if (idx.getDefiningOp<LLVM::ConstantOp>()) {
208+
continue;
209+
}
215210

216-
// Unpack the output
217-
if (matrix.getNumOutDims() == 1) {
218-
outIndices[0].second = b.xor_(outIndices[0].second, out);
219-
} else {
220-
assert(llvm::equal(matrix.getOutDimNames(),
221-
llvm::make_first_range(outIndices)));
222-
int shift = 0;
223-
for (auto &[dimName, outIdx] : outIndices) {
224-
auto outDimSizeLog2 = layout.getOutDimSizeLog2(dimName);
225-
auto mask = (1 << outDimSizeLog2) - 1;
226-
outIdx = b.xor_(outIdx,
227-
b.and_(b.lshr(out, b.i32_val(shift)), b.i32_val(mask)));
228-
shift += outDimSizeLog2;
211+
int nBits = layout.getInDimSizeLog2(inDimName);
212+
for (int i = 0; i < nBits; i++) {
213+
Value bit = b.and_(idx, b.i32_val(1 << i));
214+
Value bit_is_zero = b.icmp_eq(bit, zero);
215+
for (auto &[outDimName, outIdx] : outIndices) {
216+
int32_t basis = layout.getBasis(inDimName, i, outDimName);
217+
if (basis == 0)
218+
continue;
219+
outIdx = b.xor_(outIdx, b.select(bit_is_zero, zero, b.i32_val(basis)));
220+
}
229221
}
230222
}
223+
231224
return outIndices;
232225
}
233226

0 commit comments

Comments
 (0)