Add alloc_warp_barrier for multi-thread barrier arrival #1031
tissue3 wants to merge 1 commit into facebookexperimental:main
njriasan left a comment
@tissue3 This is a great initial implementation. However, I'm concerned that we aren't actually eliminating the synchronization; we are only eliminating the predicate.
Also, my intuition is that this may have different impacts per shape depending on which partition is the bottleneck. Once we have confirmed the sync no longer occurs, let's expand this to a broader set of model shapes and run with TritonBench.
tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1)
tmem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1)
if USE_WARP_BARRIER:
    tmem_full_bars = tlx.alloc_warp_barrier(num_barriers=NUM_TMEM_BUFFERS, num_warps=1)
If after verifying the sync looks as expected we see a gain, but only a minor one (say < 1%), we might want to try applying this to all of the barriers in the kernel. That will of course be more work to support TMA.
ttgir = kernel.asm["ttgir"]
assert ttgir.count("ttng.init_barrier") == 1, f"Expected 1 init_barrier in TTGIR:\n{ttgir}"
assert ttgir.count("ttng.arrive_barrier") == 3, f"Expected 3 arrive_barrier in TTGIR:\n{ttgir}"
assert ttgir.count("perThread") == 3, f"Expected 3 perThread attrs in TTGIR:\n{ttgir}"
I think these tests need to dump the llir to verify the lowering pattern is as expected.
Yeah, or ptx? Especially to verify the barrier sync isn't there.
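As a sketch of what such a PTX-level check could look like, here is a small plain-Python helper (the helper name and the fake PTX string are illustrative, not part of this PR); the real test would feed it `kernel.asm["ptx"]` instead:

```python
def count_sync_pattern(ptx: str) -> dict:
    """Count the ops of interest in a PTX dump (hypothetical helper)."""
    return {
        # CTA-wide sync we hope to see eliminated for per-thread arrives
        "bar.sync": ptx.count("bar.sync"),
        # the per-thread arrives themselves should still be present
        "mbarrier.arrive": ptx.count("mbarrier.arrive"),
    }

# Usage sketch with a fake PTX fragment standing in for kernel.asm["ptx"]:
fake_ptx = (
    "mbarrier.arrive.shared::cta.b64 _, [%r1];\n"
    "mbarrier.arrive.shared::cta.b64 _, [%r2];\n"
)
counts = count_sync_pattern(fake_ptx)
assert counts["bar.sync"] == 0
assert counts["mbarrier.arrive"] == 2
```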
| } | ||
|
|
||
| // TODO: Add phase result as needed. | ||
| std::stringstream ptxAsm; |
Can we generate a lit test for the associated IR after lowering? I think we need to inspect the code lowering prepared for LLVM.
If I understand correctly, the barrier is only half the issue; we also need to ensure the synchronization before it is removed. This is an example of the full IR from AutoWS GEMM. It is not 1:1, but it should contain a similar pattern that we care about (and I have it available).
This is my relevant code for the section that goes TMEM_LOAD -> barrier arrive:
nvvm.tcgen05.wait <load> loc(#loc53)
nvvm.barrier0 loc(#loc53)
%accumulator_1092 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc53)
%accumulator_1093 = llvm.sub %accumulator_1092, %102 : i32 loc(#loc53)
%accumulator_1094 = llvm.and %accumulator_1093, %94 : i32 loc(#loc53)
%accumulator_1095 = llvm.icmp "eq" %accumulator_1094, %109 : i32 loc(#loc53)
%accumulator_1096 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 mbarrier.arrive.shared::cta.b64 _, [$1];", "b,r" %accumulator_1095, %121 : (i1, !llvm.struct<(ptr<3>, i32)>) -> !llvm.void loc(#loc53)
As I understand it, this nvvm.barrier0 is responsible for the sync. So even though your example will eliminate accumulator_1092 through accumulator_1095 and simplify accumulator_1096, to achieve the desired impact we will need to remove this barrier (or possibly reduce it to a sync between warps rather than between threads).
cc: @htyu in case you have looked at this in more detail.
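To make the two arrival disciplines concrete, here is a hypothetical Python model (not TLX/Triton code; all names are illustrative). The leader-thread pattern needs a warp-wide sync before thread 0 arrives once, while the per-thread pattern lets the mbarrier's arrive count do the accumulation, with no sync:

```python
class MBarrier:
    """Toy mbarrier: flips its phase once `arrive_count` arrivals land."""
    def __init__(self, arrive_count: int):
        self.arrive_count = arrive_count
        self.pending = arrive_count
        self.phase = 0

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1
            self.pending = self.arrive_count

WARP_SIZE = 32

def leader_arrive(bar, thread_ids):
    # Leader pattern: all threads must sync first (the bar.sync /
    # __syncwarp in question), then only thread 0 arrives once.
    # ... warp-wide sync point would go here ...
    for tid in thread_ids:
        if tid == 0:
            bar.arrive()

def per_thread_arrive(bar, thread_ids):
    # Per-thread pattern: every thread arrives independently; the
    # barrier's arrive count accumulates them, so no warp sync is needed.
    for tid in thread_ids:
        bar.arrive()

# Both disciplines flip the phase exactly once per warp:
leader_bar = MBarrier(arrive_count=1)
leader_arrive(leader_bar, range(WARP_SIZE))
per_thread_bar = MBarrier(arrive_count=WARP_SIZE)
per_thread_arrive(per_thread_bar, range(WARP_SIZE))
assert leader_bar.phase == per_thread_bar.phase == 1
```

This also illustrates the concern above: dropping the `tid == 0` predicate changes who arrives, but the sync point in the leader pattern is a separate instruction that must be removed on its own.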
I guess the membar pass is the culprit. It's designed to protect shared memory accesses with bar.sync to make sure all threads see exactly the same data. Barriers could be an exception.
@njriasan @htyu Thanks for reviewing! I tried removing that nvvm.barrier but still got a trivial perf difference:
matmul-performance-fp16:
M N K cuBLAS ws ws_warp_barrier clc clc_warp_barrier
0 2048.0 2048.0 2048.0 838.860821 453.055644 453.055644 623.543467 622.098396
1 4096.0 4096.0 4096.0 1109.237446 1090.923891 1090.923891 1040.195516 1040.447562
2 8192.0 8192.0 8192.0 1114.996704 1106.807677 1104.601605 920.826955 921.753365
Is it because I only enabled per-thread sync, not per buffer?
It's fine with neutral perf. Maybe other cases can benefit from this.
Can you check if some bar.sync ops go away in the PTX?
Thanks for working on this!
Right. I also wonder if some validation pass against the generated PTX could help.
// For perThread arrives, each thread's own program order guarantees its
// SMEM ops complete before its arrive, and the mbarrier accumulates all
// arrivals before releasing the waiter, so no CTA-wide fence is needed.
if (auto arriveOp = dyn_cast<triton::nvidia_gpu::ArriveBarrierOp>(op)) {
Looks good. Does it affect performance?
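The rule in the comment above can be restated as a tiny Python sketch (the real membar pass is C++; this only models the decision "skip the CTA-wide fence for perThread arrives", and the function name is hypothetical):

```python
def needs_cta_fence(op_name: str, per_thread: bool) -> bool:
    # Membar-style rule (sketch): ops touching shared memory normally get a
    # CTA-wide bar.sync inserted before them, but a perThread mbarrier
    # arrive does not: each thread's program order covers its own SMEM ops,
    # and the mbarrier accumulates arrivals before releasing the waiter.
    if op_name == "ttng.arrive_barrier" and per_thread:
        return False
    return True

assert needs_cta_fence("ttng.arrive_barrier", per_thread=True) is False
assert needs_cta_fence("ttng.arrive_barrier", per_thread=False) is True
```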
Summary:

Add `alloc_warp_barrier` to TLX, enabling barrier arrival where every thread arrives independently instead of using the default leader-thread pattern (`__syncwarp` + thread 0 arrive). This eliminates unnecessary warp synchronization and improves performance when there is warp divergence.

Infrastructure changes:

- `ArriveBarrierOp` in `TritonNvidiaGPUOps.td`: add `perThread` UnitAttr and a new builder overload
- `BarrierOpToLLVM.cpp`: when `perThread` is set, emit `mbarrier.arrive` without the leader-thread predicate (no `threadIdx == 0` check), so all threads arrive independently
- `barrier.py`: add `alloc_warp_barrier(num_barriers, num_warps, num_arrivals)`, which sets `arrive_count = num_warps * 32 * num_arrivals` and marks the barrier as a warp barrier; `barrier_arrive` dispatches to `create_warp_barrier_arrive` for warp barriers
- `types.py` / `mem_ops.py`: propagate the `is_warp_barrier` flag through `mbarrier` / `mbarrier_type` and `local_view`
- `triton_tlx.cc`: expose `create_warp_barrier_arrive` to the dialect builder

Test Plan:
Performance:

GEMM shows no significant difference between `alloc_barriers` and `alloc_warp_barrier`; results are within noise.
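For reference, the arrive-count arithmetic described in the summary can be sketched in plain Python (the function name is illustrative, not the real TLX API; only the `arrive_count = num_warps * 32 * num_arrivals` rule comes from the PR description):

```python
WARP_SIZE = 32

def warp_barrier_arrive_count(num_warps: int, num_arrivals: int = 1) -> int:
    # With per-thread arrival, every thread of every participating warp
    # arrives independently, so the mbarrier must expect 32 arrivals per
    # warp per logical arrival.
    return num_warps * WARP_SIZE * num_arrivals

# e.g. one warp, one arrival per buffer handoff:
assert warp_barrier_arrive_count(1) == 32
# four warps, two arrivals each:
assert warp_barrier_arrive_count(4, 2) == 256
```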