Skip to content

Commit 89c0b0a

Browse files
[AMD] Introduce amdgpu.cond_barrier (#5360)
condBarrierOp sets a barrier instruction only when the given argument is true. This provides a way to synchronize a subset of the threads in a block, deliberately diverging the threads' execution sequences while still keeping them in sync. However, the user must guarantee that all threads converge at the end by calling condBarrierOp(true) with the remaining threads. Conceptually, this is similar to having a barrier inside an if statement. This op allows us to avoid blocking the whole block when suitable, which helps scheduling. --------- Co-authored-by: Lei Zhang <[email protected]>
1 parent 5da85b1 commit 89c0b0a

File tree

3 files changed

+79
-0
lines changed

3 files changed

+79
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s
2+
3+
// Lit test for the lowering of amdgpu.cond_barrier: each cond_barrier in the
// input is expected to become a conditional branch into a block that holds a
// single rocdl.s.barrier and then branches on to the continuation block.
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
4+
tt.func @conditional_barrier() {
5+
// CHECK-LABEL: llvm.func @conditional_barrier
6+
7+
// NOTE(review): the two icmp patterns below match the operand names %3 and
// %1 literally, so they depend on the exact SSA numbering of the lowered
// output.
// CHECK: %[[CMP0:.+]] = llvm.icmp "ne" %3, %1 : i32
8+
// CHECK: %[[CMP1:.+]] = llvm.icmp "eq" %3, %1 : i32
9+
// CHECK: llvm.cond_br %[[CMP0]], ^bb1, ^bb2
10+
// CHECK: ^bb1:
11+
// CHECK: rocdl.s.barrier
12+
// CHECK: llvm.br ^bb2
13+
// CHECK: ^bb2:
14+
// CHECK: llvm.add
15+
// CHECK: llvm.cond_br %[[CMP1]], ^bb3, ^bb4
16+
// CHECK: ^bb3:
17+
// CHECK: rocdl.s.barrier
18+
// CHECK: llvm.br ^bb4
19+
// CHECK: ^bb4:
20+
// CHECK: llvm.return
21+
22+
// Input IR. %1 is the work-item id divided by 256; %2 is true when that
// quotient is nonzero and %3 is true when it is zero, so the two
// cond_barrier ops below are guarded by complementary predicates.
%c256_i32 = arith.constant 256 : i32
23+
%c0_i32 = arith.constant 0 : i32
24+
%0 = rocdl.workitem.id.x : i32
25+
%1 = arith.divsi %0, %c256_i32 : i32
26+
%2 = arith.cmpi ne, %1, %c0_i32 : i32
27+
%3 = arith.cmpi eq, %1, %c0_i32 : i32
28+
amdgpu.cond_barrier %2
29+
// The add between the two barriers gives the second block split something
// to anchor on (matched above as llvm.add).
%4 = arith.addi %0, %c256_i32 : i32
30+
amdgpu.cond_barrier %3
31+
tt.return
32+
}
33+
}

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,23 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
152152
let assemblyFormat = [{ attr-dict }];
153153
}
154154

155+
// amdgpu.cond_barrier: takes a single i1 operand ($pred) and declares no
// results. See the description below for the intended usage contract.
def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier">,
156+
Arguments<(ins I1:$pred)> {
157+
let summary = "Conditionally set barriers to synchronize partial threads in a block";
158+
159+
let description = [{
160+
condBarrierOp sets barrier instruction only when the given argument is true.
161+
This provides a way to synchronize partial threads in a block, deliberately
162+
diverges the execution sequences. However, user should guarantee all threads
163+
converge at the end by calling condBarrierOp(true) with the remaining threads.
164+
Conceptually, this is similar to having an execution barrier inside an if statement.
165+
This op allows us to avoid blocking the whole block when suitable to help scheduling.
166+
NB. This doesn't set any memory fence.
167+
}];
168+
169+
// Textual form is the predicate followed by an optional attribute dictionary,
// e.g. `amdgpu.cond_barrier %pred`.
let assemblyFormat = "$pred attr-dict";
170+
}
171+
155172
//
156173
// AMD Buffer operations.
157174
//

third_party/amd/lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
12
#include "PatternTritonGPUOpToLLVM.h"
23
#include "Utility.h"
4+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
35

46
using namespace mlir;
57

@@ -25,10 +27,37 @@ struct GetNumProgramsOpConversion
2527
}
2628
};
2729

30+
// Lowers triton::amdgpu::CondBarrierOp to explicit control flow: a conditional
// branch on the op's predicate into a new block that contains a single
// ROCDL::SBarrierOp, followed by an unconditional branch to the continuation.
struct CondBarrierOpConversion
31+
: public ConvertOpToLLVMPattern<triton::amdgpu::CondBarrierOp> {
32+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
33+
34+
// Rewrites `op` in place and always returns success().
//   op       - the cond_barrier being lowered; erased at the end.
//   adaptor  - supplies the converted predicate via getPred().
//   rewriter - used for all block/op creation.
LogicalResult
35+
matchAndRewrite(triton::amdgpu::CondBarrierOp op, OpAdaptor adaptor,
36+
ConversionPatternRewriter &rewriter) const override {
37+
Location loc = op->getLoc();
38+
// Split the block at the rewriter's current insertion point; everything from
// that point on moves into afterCondBarBlock.
// NOTE(review): this relies on the insertion point being at `op` when the
// pattern is invoked - confirm against the conversion driver.
Block *currentBlock = rewriter.getInsertionBlock();
39+
Block *afterCondBarBlock =
40+
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
41+
// New empty block for the barrier, created relative to afterCondBarBlock.
Block *trueBlock = rewriter.createBlock(afterCondBarBlock);
42+
// Terminate the first half: go to trueBlock when the predicate holds,
// otherwise skip straight to the continuation.
rewriter.setInsertionPointToEnd(currentBlock);
43+
rewriter.create<LLVM::CondBrOp>(loc, adaptor.getPred(), trueBlock,
44+
afterCondBarBlock);
45+
46+
// Barrier block: emit the barrier, then rejoin the continuation.
47+
rewriter.setInsertionPointToStart(trueBlock);
48+
rewriter.create<ROCDL::SBarrierOp>(loc);
49+
rewriter.create<LLVM::BrOp>(loc, afterCondBarBlock);
50+
51+
// The op has no replacement value; it is removed outright.
rewriter.eraseOp(op);
52+
return success();
53+
}
54+
};
55+
2856
} // namespace
2957

3058
// Adds the conversion patterns defined in this file to `patterns`.
// Each pattern is constructed with the same `typeConverter` and `benefit`.
void mlir::triton::AMD::populateSPMDOpToLLVMPattern(
3159
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
3260
PatternBenefit benefit) {
3361
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
62+
// Lowering for amdgpu.cond_barrier.
patterns.add<CondBarrierOpConversion>(typeConverter, benefit);
3463
}

0 commit comments

Comments
 (0)