
Commit 00d5ca7

[NFC] Refactor XOR trick into helper function (#7397)
There are several places in the codebase where we use the trick introduced in PR #4213: calling `applyLinearLayout` once with the lane/warp/block indices and then XORing the result with many values (one per register) obtained from `ll.apply`. We begin by refactoring it out of the `emitIndices` function and will proceed to address the other places in the codebase where it is used. Remarkably, the whole commit was written by GPT-4.5, requiring only very minimal feedback.
1 parent 80449c2 commit 00d5ca7
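
As a minimal standalone sketch of the identity the trick relies on (toy code, not part of the commit; the layout, function names, and basis vectors below are invented for illustration): a linear layout is linear over GF(2), so L(reg, lane, warp, block) = L(reg, 0, 0, 0) xor L(0, lane, warp, block). The (lane, warp, block) half can be emitted once with applyLinearLayout, while the register half is a compile-time constant, so each additional register costs only one XOR with an immediate.

// Standalone toy illustration (not Triton code; all names are made up here).
#include <cassert>
#include <cstdint>

// XOR together one fixed basis vector per set input bit; any such map is
// linear over GF(2), which is the property the real LinearLayout relies on.
static uint32_t xorBasis(uint32_t x, uint32_t seed) {
  uint32_t out = 0;
  for (int bit = 0; bit < 32; ++bit)
    if (x & (1u << bit))
      out ^= (seed * 0x9E3779B9u) >> bit; // arbitrary but fixed basis vectors
  return out;
}

// Toy "linear layout" over (reg, lane, warp) inputs.
static uint32_t toyLayout(uint32_t reg, uint32_t lane, uint32_t warp) {
  return xorBasis(reg, 1) ^ xorBasis(lane, 2) ^ xorBasis(warp, 3);
}

int main() {
  for (uint32_t reg = 0; reg < 8; ++reg)
    for (uint32_t lane = 0; lane < 32; ++lane) {
      uint32_t fused = toyLayout(reg, lane, /*warp=*/1);
      uint32_t base = toyLayout(0, lane, /*warp=*/1);   // hoisted out of the register loop
      uint32_t regPart = toyLayout(reg, 0, /*warp=*/0); // constant per register
      assert(fused == (base ^ regPart));                // the XOR decomposition holds
    }
  return 0;
}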

File tree: 1 file changed (+68 -32 lines)

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 68 additions & 32 deletions
@@ -306,6 +306,53 @@ Value getLaneId(OpBuilder &rewriter, Location loc) {
   return getLaneAndWarpId(rewriter, loc).first;
 }
 
+// Helper function: applies linear layout vectorized over register indices
+SmallVector<SmallVector<std::pair<StringAttr, Value>>>
+applyLinearLayoutVec(Location loc, RewriterBase &rewriter,
+                     const LinearLayout &layout,
+                     ArrayRef<std::pair<StringAttr, Value>> indices,
+                     ArrayRef<uint32_t> registers) {
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+  MLIRContext *ctx = rewriter.getContext();
+
+  StringAttr kRegister = str_attr("register");
+
+  // Precompute the base (with register = 0)
+  SmallVector<std::pair<StringAttr, Value>> indicesWithZeroReg;
+  for (const auto &[attr, val] : indices) {
+    if (attr == kRegister)
+      indicesWithZeroReg.emplace_back(attr, b.i32_val(0));
+    else
+      indicesWithZeroReg.emplace_back(attr, val);
+  }
+
+  auto baseIndices =
+      applyLinearLayout(loc, rewriter, layout, indicesWithZeroReg);
+
+  SmallVector<SmallVector<std::pair<StringAttr, Value>>> ret;
+
+  // Iterate over registers, applying XOR trick
+  for (auto reg : registers) {
+    SmallVector<std::pair<StringAttr, int32_t>> constRegIndices;
+    for (const auto &[attr, val] : indices) {
+      constRegIndices.emplace_back(attr, attr == kRegister ? reg : 0);
+    }
+    auto regIndices = layout.apply(constRegIndices);
+
+    SmallVector<std::pair<StringAttr, Value>> combinedIndices;
+    for (auto [base, regIdx] : llvm::zip(baseIndices, regIndices)) {
+      assert(base.first == regIdx.first);
+      Value combined = b.xor_(base.second, b.i32_val(regIdx.second));
+      combinedIndices.emplace_back(base.first, combined);
+    }
+
+    ret.push_back(combinedIndices);
+  }
+
+  return ret;
+}
+
+// Refactored emitIndices function using applyLinearLayoutVec
 SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset) {
@@ -315,8 +362,6 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
 
   LinearLayout ll = triton::gpu::toLinearLayout(shape, layout);
 
-  // TODO(jlebar): We could add strong typing if we wanted; for now this is
-  // "stringly typed".
   StringAttr kRegister = str_attr("register");
   StringAttr kLane = str_attr("lane");
   StringAttr kWarp = str_attr("warp");
@@ -325,38 +370,29 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
   auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   Value blockId =
       withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0);
+
+  SmallVector<std::pair<StringAttr, Value>> commonIndices = {
+      {kRegister, b.i32_val(0)},
+      {kLane, laneId},
+      {kWarp, warpId},
+      {kBlock, blockId}};
+
+  // Vectorize over registers
+  SmallVector<uint32_t> registerIndices;
+  for (unsigned reg = 0; reg < ll.getInDimSize(kRegister); ++reg)
+    registerIndices.push_back(reg);
+
+  auto vecIndices =
+      applyLinearLayoutVec(loc, rewriter, ll, commonIndices, registerIndices);
+
   unsigned rank = shape.size();
   SmallVector<SmallVector<Value>> ret;
-  // Linear layout function is split in two parts below:
-  // L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0)
-  // idxs = idxsBase xor idxsReg
-  //
-  // L(0, t, w, b) part is the same for all registers,
-  // so we hoist it out of the main register loop in the below.
-  //
-  // This approach produces code with lower register pressure and
-  // less computations, compared to fused L(r,t,w,b) method.
-  auto idxsBase = applyLinearLayout(loc, rewriter, ll,
-                                    {{kRegister, b.i32_val(0)},
-                                     {kLane, laneId},
-                                     {kWarp, warpId},
-                                     {kBlock, blockId}});
-  for (unsigned reg = 0; reg < ll.getInDimSize(str_attr("register")); reg++) {
-    auto idxsReg =
-        ll.apply({{kRegister, reg}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}});
-    SmallVector<std::pair<StringAttr, Value>> idxs;
-    for (auto [idxBase, idxReg] : llvm::zip(idxsBase, idxsReg)) {
-      auto dimName = idxBase.first;
-      assert(dimName == idxReg.first &&
-             "dim names of block+warp+thread and register idx should be equal");
-      auto idx = b.xor_(idxBase.second, b.i32_val(idxReg.second));
-      idxs.emplace_back(dimName, idx);
-    }
-    assert(idxs.size() == rank);
-    for (unsigned k = 0; k < rank; ++k) {
-      assert(idxs[k].first == str_attr("dim" + std::to_string(k)));
-    }
-    ret.push_back(llvm::to_vector(llvm::make_second_range(idxs)));
+  for (auto &indices : vecIndices) {
+    SmallVector<Value> vals;
+    assert(indices.size() == rank);
+    for (auto &idx : indices)
+      vals.push_back(idx.second);
+    ret.push_back(vals);
   }
 
   return ret;
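
Below is a hedged sketch of how one of the other call sites mentioned in the commit message might adopt the helper (illustrative only; `ll`, `loc`, and `rewriter` are assumed to be in scope from the surrounding lowering, and none of this call site is part of this commit). It follows the same pattern as the refactored emitIndices: build the shared lane/warp/block indices once, enumerate the register indices, and let applyLinearLayoutVec emit a single applyLinearLayout plus one XOR per register.

// Hypothetical call site in another GPU-to-LLVM lowering (assumption, not
// code from this commit); only names that appear in the diff above are used.
auto b = TritonLLVMOpBuilder(loc, rewriter);
StringAttr kRegister = str_attr("register");
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");
StringAttr kBlock = str_attr("block");

auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
SmallVector<std::pair<StringAttr, Value>> commonIndices = {
    {kRegister, b.i32_val(0)},
    {kLane, laneId},
    {kWarp, warpId},
    {kBlock, b.i32_val(0)}};

// One entry per register held by this thread.
SmallVector<uint32_t> regs;
for (unsigned reg = 0; reg < ll.getInDimSize(kRegister); ++reg)
  regs.push_back(reg);

// Returns one SmallVector of (dim, Value) pairs per register; the base
// applyLinearLayout is emitted once, and each register adds only an XOR
// with a compile-time constant from ll.apply.
auto perRegIndices =
    applyLinearLayoutVec(loc, rewriter, ll, commonIndices, regs);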
