
Commit de8e715

[AMD] Optimize gfx9 wave id code generation (#8601)
On GFX9, this PR lifts the computation of `wave_id` to the function entry and additionally emits `llvm.amdgcn.readfirstlane`. This gives us optimized code generation inside loops.
1 parent 7025305 commit de8e715
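
For reference, a minimal sketch of the wave-id sequence that is now emitted once at the function entry, in LLVM/ROCDL dialect IR. This is assembled from the CHECK lines of the new test below; the SSA value names are illustrative, and with a single warp per CTA the wave id folds to the constant 0:

    %tid  = rocdl.workitem.id.x : i32
    %c63  = llvm.mlir.constant(63 : i32) : i32
    %lane = llvm.and %tid, %c63 : i32             // lane id = workitem id modulo the wave size
    %c64  = llvm.mlir.constant(64 : i32) : i32    // wave size on GFX9
    %wave = llvm.mlir.constant(0 : i32) : i32     // wave id; folds to 0 with a single warp per CTA
    %wave_sgpr = llvm.call_intrinsic "llvm.amdgcn.readfirstlane"(%wave) : (i32) -> i32

Hoisting this sequence, and in particular the convergent `readfirstlane`, out of loop bodies is what gives the improved code generation inside loops mentioned above.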

File tree: 2 files changed (+79 / -9 lines)

test/Conversion/amd/buffer_load_to_local_to_llvm.mlir

Lines changed: 40 additions & 5 deletions
@@ -187,11 +187,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
                        %arg2: !ttg.memdesc<64xf32, #shared, #smem, mutable>) {
     %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
     // The first constant 0 skips the LDS offset which is also 0
-    // COMMON: llvm.getelementptr
-    // COMMON: llvm.mlir.constant(0 : i32) : i32
-    // COMMON: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32
-    // COMMON: llvm.mlir.constant(0 : i32) : i32
-    // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]]
+    // COMMON: %[[VOFFSET:.*]] = llvm.select
+    // COMMON-NEXT: %[[IMM0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // COMMON-NEXT: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // COMMON-NEXT: %[[IMM1:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // COMMON-NEXT: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, %[[VOFFSET]], %[[IMM1]], %[[IMM0]], %[[aux_ca]]
     %1 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = ca into %arg2: <f32>[tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable>
     // COMMON: llvm.getelementptr
     // COMMON: %[[aux_cg:.*]] = llvm.mlir.constant(3 : i32) : i32
@@ -328,3 +328,38 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  // COMMON-LABEL: buffer_load_to_local_wave_id
+  tt.func public @buffer_load_to_local_wave_id(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                       %arg2: !ttg.memdesc<64xf32, #shared, #smem, mutable>, %arg3: i32) {
+    // COMMON: %0 = rocdl.workitem.id.x : i32
+    // COMMON-NEXT: %1 = llvm.mlir.constant(63 : i32) : i32
+    // COMMON-NEXT: %2 = llvm.and %0, %1 : i32
+    // COMMON-NEXT: %3 = llvm.mlir.constant(64 : i32) : i32
+    // COMMON-NEXT: %4 = llvm.mlir.constant(0 : i32) : i32
+    // COMMON-NEXT: %5 = llvm.call_intrinsic "llvm.amdgcn.readfirstlane"(%4) : (i32) -> i32
+    // COMMON-NEXT: %6 = rocdl.workitem.id.x : i32
+    // COMMON-NEXT: %7 = llvm.mlir.constant(63 : i32) : i32
+    // COMMON-NEXT: %8 = llvm.and %6, %7 : i32
+    // COMMON-NEXT: %9 = llvm.mlir.constant(64 : i32) : i32
+    // COMMON-NEXT: %10 = llvm.mlir.constant(0 : i32) : i32
+    // COMMON-NEXT: %11 = llvm.call_intrinsic "llvm.amdgcn.readfirstlane"(%10) : (i32) -> i32
+
+    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
+    %1 = amdgpu.buffer_load_to_local %arg0[%0] into %arg2: <f32>[tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable>
+    %c0_i32 = arith.constant 0 : i32
+    %cond = llvm.icmp "eq" %arg3, %c0_i32 : i32
+    cf.cond_br %cond, ^bb1, ^bb2
+  ^bb1:
+    amdgpu.buffer_load_to_local %arg0[%0] into %arg2: <f32>[tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable>
+    cf.br ^bb1
+  ^bb2:
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 39 additions & 4 deletions
@@ -482,7 +482,7 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
   void lowerDirectToLDSLoad(
       RewriterBase &rewriter, Location loc, RankedTensorType srcTy,
       MemDescType dstTy, SmallVector<Value> loadVals, Value llDst,
-      Type resElemTy, unsigned vec,
+      Type resElemTy, unsigned vec, triton::AMD::ISAFamily isaFamily,
       std::function<SmallVector<Value>(RewriterBase &, Location,
                                        ArrayRef<Value>, Value, int, VectorType)>
           lowerInst) const {
@@ -511,7 +511,40 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
         LLVM::getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
     auto affineOffset = smemObj.getShmemOffset(loc, rewriter, dstTy);
     auto maskSpanAffineOffset = SharedMemoryObject::getMaskSpanOffsets(dstTy);
-    auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
+
+    Value laneId, warpId;
+    if (ISAFamily::CDNA3 == isaFamily || ISAFamily::CDNA4 == isaFamily) {
+      // On GFX9, there is no dedicated hardware instruction to read `wave_id`.
+      // The value is instead computed from `workitem.id.x`. Per the GFX9 ABI,
+      // `workitem.id.x` is initialized in a vector register, and vector
+      // instructions are generated for IR operations that depend on `wave_id`.
+      //
+      // A `v_readfirstlane` instruction is inserted at the end of these vector
+      // sequences to transfer the value from a vector register to a scalar
+      // register, initializing `$m0`.
+
+      // When this sequence occurs inside a loop, the MachineLICM pass does not
+      // hoist it because `v_readfirstlane` is convergent. Since both
+      // `workitem.id.x` and `wave_id` are constant at runtime, their
+      // computation can be safely hoisted to the function entry block.
+      auto insertPt = rewriter.saveInsertionPoint();
+      Operation *parentOp = insertPt.getBlock()->getParentOp();
+      while (!isa<LLVM::LLVMFuncOp>(parentOp)) {
+        parentOp = parentOp->getParentOp();
+      }
+
+      auto funcOp = cast<LLVM::LLVMFuncOp>(parentOp);
+      rewriter.setInsertionPointToStart(&funcOp.getBody().front());
+
+      std::tie(laneId, warpId) = getLaneAndWarpId(rewriter, loc);
+      auto call = LLVM::createLLVMIntrinsicCallOp(
+          rewriter, loc, "llvm.amdgcn.readfirstlane", {i32_ty}, {warpId});
+      warpId = call.getResult(0);
+      rewriter.restoreInsertionPoint(insertPt);
+    } else {
+      std::tie(laneId, warpId) = getLaneAndWarpId(rewriter, loc);
+    }
+
     auto calcPaddedOffset = [&](Value smemOffset) {
       TritonLLVMOpBuilder b(loc, rewriter);
       auto bitwidth = dstTy.getElementTypeBitWidth();
@@ -873,7 +906,8 @@ struct BufferLoadToLocalOpConversion
     };
 
     lowerDirectToLDSLoad(rewriter, loc, ptrType, flatDstTy, loadVals, llDst,
-                         resElemTy, vec, emitBufferLoadLds);
+                         resElemTy, vec, targetInfo.getISAFamily(),
+                         emitBufferLoadLds);
 
     // Drop the result token.
     Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(),
@@ -999,7 +1033,8 @@ struct AsyncCopyGlobalToLocalOpConversion
     };
 
     lowerDirectToLDSLoad(rewriter, loc, srcTy, flatDstTy, loadVals, llDst,
-                         resElemTy, vec, emitGlobalLoadLds);
+                         resElemTy, vec, targetInfo.getISAFamily(),
+                         emitGlobalLoadLds);
 
     // Drop the result token.
     Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(),
