Skip to content

Commit 345c633

Browse files
authored
[Backend] Check if wait_barrier is a constant true (#7508)
When the `wait_barrier` predicate ends up being true, either generated from Gluon or for some other optimization, generate the non-predicated version of the instruction, which is slightly more efficient due to 1 less branch. I observed a speedup of like 3 TFLOPS on the attention kernel because of this. It was stable and reproducible and not just noise.
1 parent 4a8277b commit 345c633

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

test/Conversion/tritonnvidiagpu_to_llvm.mlir

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1717
#smem = #ttg.shared_memory
1818
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1919
// CHECK-LABEL: wait_barrier
20-
tt.func @wait_barrier(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %phase: i32) {
20+
tt.func @wait_barrier(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %phase: i32, %pred: i1) {
2121
// CHECK: waitLoop:
2222
// CHECK: mbarrier.try_wait.parity.shared.b64
2323
// CHECK: @!complete bra.uni waitLoop
24+
// CHECK-NOT: skipWait
25+
// CHECK: %{{[0-9]+}}, %arg1 :
2426
ttng.wait_barrier %alloc, %phase : !ttg.memdesc<1xi64, #shared0, #smem>
27+
%true = arith.constant true
28+
29+
// CHECK: waitLoop:
30+
// CHECK: mbarrier.try_wait.parity.shared.b64
31+
// CHECK: @!complete bra.uni waitLoop
32+
// CHECK-NOT: skipWait
33+
// CHECK: %{{[0-9]+}}, %arg1 :
34+
ttng.wait_barrier %alloc, %phase, %true : !ttg.memdesc<1xi64, #shared0, #smem>
35+
36+
// CHECK: @!$2 bra.uni skipWait
37+
// CHECK: waitLoop:
38+
// CHECK: mbarrier.try_wait.parity.shared.b64
39+
// CHECK: @!complete bra.uni waitLoop
40+
// CHECK: skipWait:
41+
// CHECK: %{{[0-9]+}}, %arg1, %arg2 :
42+
ttng.wait_barrier %alloc, %phase, %pred : !ttg.memdesc<1xi64, #shared0, #smem>
2543
tt.return
2644
}
2745

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ struct WaitBarrierOpConversion
160160
typeConverter->convertType(op.getAlloc().getType().getElementType()),
161161
rewriter);
162162
auto loc = op.getLoc();
163-
bool predicated = adaptor.getPred() != nullptr;
163+
bool predicated =
164+
adaptor.getPred() && !matchPattern(op.getPred(), m_NonZero());
164165
std::string ptx;
165166
if (targetInfo->getComputeCapability() < 90) {
166167
if (!predicated) {

0 commit comments

Comments
 (0)