Skip to content

Commit 89234eb

Browse files
lezcano authored and whitneywhtsang committed
Partially Revert "[LAYOUTS] Enable diagonal iteration unconditionally (#7218)" (#7245)
We are seeing some internal regressions. This reverts commit 336cc1d. (cherry picked from commit c8a711d)
1 parent fe79064 commit 89234eb

File tree

1 file changed

+35
-42
lines changed

1 file changed

+35
-42
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 35 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -151,10 +151,6 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
151151
auto b = TritonLLVMOpBuilder(loc, rewriter);
152152
assert(layout.getNumInDims() == indices.size());
153153
assert(llvm::equal(layout.getInDimNames(), llvm::make_first_range(indices)));
154-
// Trivial layout
155-
if (layout.getNumOutDims() == 0) {
156-
return {};
157-
}
158154

159155
// This function can emit a lot of MLIR code, which ultimately makes
160156
// compilation slow. (We think this shouldn't be the case -- it's not *that*
@@ -168,65 +164,62 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
168164
SmallVector<std::pair<StringAttr, int32_t>> constantIns;
169165
SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
170166
for (auto [inDimName, idx] : indices) {
171-
APInt constant;
172-
if (matchPattern(idx, m_ConstantInt(&constant))) {
173-
constantIns.push_back({inDimName, constant.getSExtValue()});
167+
if (auto constant = idx.getDefiningOp<LLVM::ConstantOp>()) {
168+
constantIns.push_back(
169+
{inDimName, cast<IntegerAttr>(constant.getValue()).getInt()});
174170
} else {
175171
constantIns.push_back({inDimName, 0});
176172
nonConstantIns.push_back({inDimName, idx});
177173
}
178174
}
175+
SmallVector<int32_t> constantComponent =
176+
llvm::to_vector(llvm::make_second_range(layout.apply(constantIns)));
179177

180-
// Compute constant part of the output and wrap it as values
181178
Value zero = b.i32_val(0);
182179
SmallVector<std::pair<StringAttr, Value>> outIndices;
183-
for (auto [outDimName, constant] : layout.apply(constantIns)) {
184-
if (constant == 0)
180+
for (auto [i, outDimName] : llvm::enumerate(layout.getOutDimNames())) {
181+
if (constantComponent[i] == 0)
185182
outIndices.push_back({outDimName, zero});
186183
else
187-
outIndices.push_back({outDimName, b.i32_val(constant)});
188-
}
189-
190-
if (nonConstantIns.size() == 0) {
191-
return outIndices;
184+
outIndices.push_back({outDimName, b.i32_val(constantComponent[i])});
192185
}
193-
194-
// Concatenate input
195-
Value x = b.i32_val(0);
196-
if (nonConstantIns.size() == 1) {
197-
x = nonConstantIns[0].second;
198-
} else {
186+
// Happy path: Only one output.
187+
if (outIndices.size() == 1) {
188+
SmallVector<StringAttr> inDimNames;
189+
// Concatenate input
190+
Value x = b.i32_val(0);
199191
int shift = 0;
200192
for (auto [inDimName, idx] : nonConstantIns) {
193+
inDimNames.push_back(inDimName);
201194
x = b.or_(x, b.shl(idx, b.i32_val(shift)));
202195
shift += layout.getInDimSizeLog2(inDimName);
203196
}
197+
// Flatten ins
198+
auto matrix = layout.sublayout(inDimNames, outIndices[0].first);
199+
matrix = matrix.flattenIns();
200+
auto out = triton::gpu::matrixVectorProd(b, matrix, x);
201+
outIndices[0].second = b.xor_(outIndices[0].second, out);
202+
return outIndices;
204203
}
205204

206-
// Remove constant input dims from the layout and flatten it
207-
auto inDimNames = llvm::to_vector(llvm::make_first_range(nonConstantIns));
208-
auto matrix = layout.sublayout(
209-
inDimNames, llvm::to_vector(llvm::make_first_range(outIndices)));
210-
auto flatMatrix = matrix.flattenIns().flattenOuts();
211-
212-
// Lower the matrix-vector product
213-
auto out = triton::gpu::matrixVectorProd(b, flatMatrix, x);
205+
for (auto [inDimName, idx] : indices) {
206+
if (idx.getDefiningOp<LLVM::ConstantOp>()) {
207+
continue;
208+
}
214209

215-
// Unpack the output
216-
if (matrix.getNumOutDims() == 1) {
217-
outIndices[0].second = b.xor_(outIndices[0].second, out);
218-
} else {
219-
assert(llvm::equal(matrix.getOutDimNames(),
220-
llvm::make_first_range(outIndices)));
221-
int shift = 0;
222-
for (auto &[dimName, outIdx] : outIndices) {
223-
auto outDimSizeLog2 = layout.getOutDimSizeLog2(dimName);
224-
auto mask = (1 << outDimSizeLog2) - 1;
225-
outIdx = b.xor_(outIdx,
226-
b.and_(b.lshr(out, b.i32_val(shift)), b.i32_val(mask)));
227-
shift += outDimSizeLog2;
210+
int nBits = layout.getInDimSizeLog2(inDimName);
211+
for (int i = 0; i < nBits; i++) {
212+
Value bit = b.and_(idx, b.i32_val(1 << i));
213+
Value bit_is_zero = b.icmp_eq(bit, zero);
214+
for (auto &[outDimName, outIdx] : outIndices) {
215+
int32_t basis = layout.getBasis(inDimName, i, outDimName);
216+
if (basis == 0)
217+
continue;
218+
outIdx = b.xor_(outIdx, b.select(bit_is_zero, zero, b.i32_val(basis)));
219+
}
228220
}
229221
}
222+
230223
return outIndices;
231224
}
232225

0 commit comments

Comments
 (0)