Skip to content

Commit 4a119a3

Browse files
apgoucherlezcano
andauthored
[NFC] Partially Revert "Partially Revert "[LAYOUTS] Enable diagonal iteration unconditionally (#7218)" (#7245)" (#7299)
This restores some of the code cleanliness improvements from 7218 without affecting functionality. Largely this is to keep a smaller bisectable diff to study the benchmark impacts of relanding 7218. There is a slight functional change here -- we have the more liberal pattern-matching approach introduced by 7218 -- but I've tested on the internal benchmarks and it doesn't affect perf. Co-authored-by: Mario Lezcano Casado <[email protected]>
1 parent 863e42f commit 4a119a3

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
154154
auto b = TritonLLVMOpBuilder(loc, rewriter);
155155
assert(layout.getNumInDims() == indices.size());
156156
assert(llvm::equal(layout.getInDimNames(), llvm::make_first_range(indices)));
157+
// Trivial layout
158+
if (layout.getNumOutDims() == 0) {
159+
return {};
160+
}
157161

158162
// This function can emit a lot of MLIR code, which ultimately makes
159163
// compilation slow. (We think this shouldn't be the case -- it's not *that*
@@ -167,25 +171,29 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
167171
SmallVector<std::pair<StringAttr, int32_t>> constantIns;
168172
SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
169173
for (auto [inDimName, idx] : indices) {
170-
if (auto constant = idx.getDefiningOp<LLVM::ConstantOp>()) {
171-
constantIns.push_back(
172-
{inDimName, cast<IntegerAttr>(constant.getValue()).getInt()});
174+
APInt constant;
175+
if (matchPattern(idx, m_ConstantInt(&constant))) {
176+
constantIns.push_back({inDimName, constant.getSExtValue()});
173177
} else {
174178
constantIns.push_back({inDimName, 0});
175179
nonConstantIns.push_back({inDimName, idx});
176180
}
177181
}
178-
SmallVector<int32_t> constantComponent =
179-
llvm::to_vector(llvm::make_second_range(layout.apply(constantIns)));
180182

183+
// Compute constant part of the output and wrap it as values
181184
Value zero = b.i32_val(0);
182185
SmallVector<std::pair<StringAttr, Value>> outIndices;
183-
for (auto [i, outDimName] : llvm::enumerate(layout.getOutDimNames())) {
184-
if (constantComponent[i] == 0)
186+
for (auto [outDimName, constant] : layout.apply(constantIns)) {
187+
if (constant == 0)
185188
outIndices.push_back({outDimName, zero});
186189
else
187-
outIndices.push_back({outDimName, b.i32_val(constantComponent[i])});
190+
outIndices.push_back({outDimName, b.i32_val(constant)});
191+
}
192+
193+
if (nonConstantIns.size() == 0) {
194+
return outIndices;
188195
}
196+
189197
// Happy path: Only one output.
190198
if (outIndices.size() == 1) {
191199
SmallVector<StringAttr> inDimNames;

0 commit comments

Comments
 (0)