[TLX] Add alloc_warp_barrier for multi-thread barrier arrival #1031
tissue3 wants to merge 1 commit into facebookexperimental:main
Conversation
njriasan left a comment
@tissue3 This is a great initial implementation. However, I'm concerned that we aren't actually eliminating the synchronization; we are only eliminating the predicate.
Also, my intuition is that this may have different impacts per shape depending on which partition is the bottleneck. Once we have confirmed the sync no longer occurs, let's expand this to a broader set of model shapes and run with TritonBench.
tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1)
tmem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1)
if USE_WARP_BARRIER:
    tmem_full_bars = tlx.alloc_warp_barrier(num_barriers=NUM_TMEM_BUFFERS, num_warps=1)
If, after verifying the sync looks as expected, we see only a minor gain (say < 1%), we might want to try applying this to all of the barriers in the kernel. That will of course require more work to support TMA.
ttgir = kernel.asm["ttgir"]
assert ttgir.count("ttng.init_barrier") == 1, f"Expected 1 init_barrier in TTGIR:\n{ttgir}"
assert ttgir.count("ttng.arrive_barrier") == 3, f"Expected 3 arrive_barrier in TTGIR:\n{ttgir}"
assert ttgir.count("perThread") == 3, f"Expected 3 perThread attrs in TTGIR:\n{ttgir}"
I think these tests need to dump the llir to verify the lowering pattern is as expected.
Yeah, or ptx? Especially to verify the barrier sync isn't there.
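A minimal sketch of the PTX-level check suggested here. The helper and the sample strings are illustrative assumptions; in the real test the input would come from `kernel.asm["ptx"]`, and the fake PTX fragments below are hand-written, not actual compiler output.

```python
def barrier_sync_before_arrive(ptx: str) -> bool:
    """Return True if a bar.sync appears before the first mbarrier.arrive.

    A crude textual check: the perThread lowering should produce an
    mbarrier.arrive that is not preceded by a CTA-wide bar.sync.
    """
    arrive_pos = ptx.find("mbarrier.arrive")
    if arrive_pos == -1:
        return False  # no arrive at all; nothing to flag
    return "bar.sync" in ptx[:arrive_pos]


# Hand-written stand-ins for the two lowering patterns being compared.
fake_ptx_predicated = "bar.sync 0;\n@%p1 mbarrier.arrive.shared::cta.b64 _, [%r1];"
fake_ptx_perthread = "mbarrier.arrive.shared::cta.b64 _, [%r1];"

print(barrier_sync_before_arrive(fake_ptx_predicated))  # True
print(barrier_sync_before_arrive(fake_ptx_perthread))   # False
```

A real test would scope the search to the relevant region of the PTX rather than the whole module, but the shape of the assertion is the same.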
}

// TODO: Add phase result as needed.
std::stringstream ptxAsm;
Can we generate a lit test for the associated IR after lowering? I think we need to inspect the code lowering prepared for LLVM.
If I understand correctly, the barrier is only half the issue; we also need to ensure the synchronization before it is removed. This is an example of the full IR from AutoWS GEMM. Not 1:1, but it should contain a similar pattern that we care about (and I have it available).
This is my relevant code for the section that goes TMEM_LOAD -> barrier arrive:
nvvm.tcgen05.wait <load> loc(#loc53)
nvvm.barrier0 loc(#loc53)
%accumulator_1092 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc53)
%accumulator_1093 = llvm.sub %accumulator_1092, %102 : i32 loc(#loc53)
%accumulator_1094 = llvm.and %accumulator_1093, %94 : i32 loc(#loc53)
%accumulator_1095 = llvm.icmp "eq" %accumulator_1094, %109 : i32 loc(#loc53)
%accumulator_1096 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 mbarrier.arrive.shared::cta.b64 _, [$1];", "b,r" %accumulator_1095, %121 : (i1, !llvm.struct<(ptr<3>, i32)>) -> !llvm.void loc(#loc53)
As I understand it, this nvvm.barrier0 is responsible for the sync. So even though your example will eliminate accumulator_1092-accumulator_1095 and simplify accumulator_1096, to achieve the desired impact we will need to remove this barrier (or possibly reduce it to a sync between warps rather than between threads).
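To make the two arrival schemes concrete, here is a toy Python model of mbarrier counting semantics. This is not TLX or PTX; the class and function names are invented for illustration. It shows why a warp barrier initialized with arrive_count = num_warps * 32 completes a phase from independent per-thread arrivals exactly like an arrive_count = 1 barrier does from a single leader-thread arrive.

```python
WARP_SIZE = 32

class MBarrier:
    """Minimal mbarrier model: a phase completes once `arrive_count`
    arrivals have landed, then the barrier re-arms for the next phase."""
    def __init__(self, arrive_count):
        self.arrive_count = arrive_count
        self.pending = arrive_count
        self.phase = 0

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:       # phase completes, barrier re-arms
            self.phase ^= 1
            self.pending = self.arrive_count

def leader_thread_arrive(bar):
    # Default pattern: warp-level sync (the cost this PR avoids),
    # then only thread 0 issues the arrive.
    bar.arrive()

def per_thread_arrive(bar, num_warps):
    # perThread pattern: every thread arrives independently, with no
    # predicate and no warp sync; the mbarrier accumulates the arrivals.
    for _ in range(num_warps * WARP_SIZE):
        bar.arrive()

# Both schemes flip the phase exactly once per logical arrival event.
b_leader = MBarrier(arrive_count=1)
leader_thread_arrive(b_leader)

b_warp = MBarrier(arrive_count=1 * WARP_SIZE)  # alloc_warp_barrier(num_warps=1)
per_thread_arrive(b_warp, num_warps=1)

print(b_leader.phase, b_warp.phase)  # 1 1
```

The model ignores ordering and memory effects entirely; it only captures the counting argument, which is why removing the sync still needs the membar-pass reasoning discussed above.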
cc: @htyu in case you have looked at this in more detail.
I guess the membar pass is the culprit. It's designed to protect shared memory access with bar.sync to make sure all threads see exactly the same data. Barriers could be an exception.
@njriasan @htyu Thanks for reviewing! I tried removing that nvvm.barrier but still got a trivial perf difference:
matmul-performance-fp16:
M N K cuBLAS ws ws_warp_barrier clc clc_warp_barrier
0 2048.0 2048.0 2048.0 838.860821 453.055644 453.055644 623.543467 622.098396
1 4096.0 4096.0 4096.0 1109.237446 1090.923891 1090.923891 1040.195516 1040.447562
2 8192.0 8192.0 8192.0 1114.996704 1106.807677 1104.601605 920.826955 921.753365
Is it because I only enabled per-thread sync, not per buffer?
It's fine with neutral perf. Maybe other cases can benefit from this.
Can you check if some bar.sync ops go away on the PTX?
In the current implementation, the sync before mbarrier.arrive.shared is removed: https://fburl.com/phabricator/lz64ljio, but the other syncs are not.
Thanks for working on this!
Right. I also wonder if some validation pass against ...
lib/Analysis/Membar.cpp
Outdated
// For perThread arrives, each thread's own program order guarantees its
// SMEM ops complete before its arrive, and the mbarrier accumulates all
// arrivals before releasing the waiter, so no CTA-wide fence is needed.
if (auto arriveOp = dyn_cast<triton::nvidia_gpu::ArriveBarrierOp>(op)) {
Looks good. Does it affect performance?
nit: please prefix the diff title with "[TLX]" to ease rebasing and cherry-picking.
njriasan left a comment
LGTM! A couple suggestions for possible followups depending on what we measure (or don't). Thanks for accomplishing this so quickly!
if (auto arriveOp = dyn_cast<triton::nvidia_gpu::ArriveBarrierOp>(op))
  isPerThreadArrive = arriveOp.getPerThread();

if (!isPerThreadArrive) {
Let's make sure we run an accuracy test as well. I'm not sure whether the general case needs to be conservative: with multiple warps there is a risk that the same warp can contribute multiple arrivals. However, this practically shouldn't happen, so worst case we can add a TODO to test.
I did do some accuracy tests (see my test plan), but I am not sure whether that is enough. Please let me know if there is anything beyond these tests to check:
pytest third_party/tlx/tutorials/testing/test_correctness.py::test_blackwell_gemm_ws_warp_barrier
pytest third_party/tlx/tutorials/testing/test_correctness.py::test_blackwell_gemm_clc_warp_barrier
tmem_empty_bars = tlx.alloc_warp_barrier(num_barriers=NUM_TMEM_BUFFERS * NUM_MMA_GROUPS, num_warps=4,
                                         num_arrivals=EPILOGUE_SUBTILE)
else:
    tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS * NUM_MMA_GROUPS, arrive_count=1)
Maybe we should try disabling it for tmem_full_bars? That has the MMA operation signaling it, so it's probably not impacted by this code.
Alternatively, before we definitively say "perf neutral", maybe we should try migrating every barrier to use this (although that might require more work for MMA/TMA), since only one thread is potentially issuing the instruction.
Adding warp barriers for smem is a bit tricky since it runs into hardware constraints. Here is some of Claude's analysis:
Can we convert to warp barriers?
smem_empty_bars: No
The tcgen05.commit instruction uses .mbarrier::arrive::one — this is baked into the hardware. The MMA engine always contributes exactly 1 arrival when it finishes reading SMEM. There's no PTX variant that contributes num_warps * 32 arrivals. This is fundamentally incompatible with warp barriers.
smem_full_bars: Theoretically possible but messy
The mbarrier.arrive.expect_tx instruction inherently does two things atomically:
- Sets the expected TX byte count
- Counts as 1 arrival
There's no PTX instruction to set TX bytes without also arriving. So if we wanted a warp barrier:
- Init with arrive_count = 33 (1 from expect_tx + 32 from warp arrive, assuming num_warps=1)
- Thread 0: barrier_expect_bytes → mbarrier.arrive.expect_tx (1 arrive + set TX)
- All 32 threads: barrier_arrive as warp barrier → mbarrier.arrive (32 arrives)
- Total: 33 ✓
But this requires:
- A custom arrive_count that doesn't fit the alloc_warp_barrier API (the num_warps * 32 formula doesn't account for the extra expect_tx arrive)
- Adding an explicit barrier_arrive call after barrier_expect_bytes in the kernel
- Making sure the is_warp_barrier attribute is set despite not using alloc_warp_barrier
Let me know your thoughts and suggestions!
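The arrival arithmetic above can be checked with a small hypothetical helper. `warp_barrier_arrive_count` is not a TLX API; it just mirrors the `arrive_count = num_warps * 32 * num_arrivals` formula from this PR, plus the one extra arrival that `mbarrier.arrive.expect_tx` contributes.

```python
WARP_SIZE = 32  # threads per warp

def warp_barrier_arrive_count(num_warps, num_arrivals=1, with_expect_tx=False):
    """Expected mbarrier arrive_count for a warp barrier.

    Base formula from alloc_warp_barrier; expect_tx adds one arrival
    because the instruction sets TX bytes and arrives atomically.
    """
    count = num_warps * WARP_SIZE * num_arrivals
    return count + (1 if with_expect_tx else 0)

print(warp_barrier_arrive_count(1))                       # 32
print(warp_barrier_arrive_count(1, with_expect_tx=True))  # 33, matching the analysis
print(warp_barrier_arrive_count(4, num_arrivals=2))       # 256
```

This makes the mismatch explicit: the `with_expect_tx` case cannot be expressed through the plain `num_warps * 32 * num_arrivals` formula that `alloc_warp_barrier` encodes.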
…ebookexperimental#1031)

Summary:
X-link: meta-pytorch/tritonbench#922

Add `alloc_warp_barrier` to TLX, enabling barrier arrival where every thread arrives independently instead of using the default leader-thread pattern (`__syncwarp` + thread 0 arrive). This eliminates unnecessary warp synchronization and improves performance when there is warp divergence.

Infrastructure changes:
- `ArriveBarrierOp` in `TritonNvidiaGPUOps.td`: add `perThread` UnitAttr and a new builder overload
- `BarrierOpToLLVM.cpp`: when `perThread` is set, emit `mbarrier.arrive` without the leader-thread predicate (no `threadIdx == 0` check), so all threads arrive independently
- `barrier.py`: add `alloc_warp_barrier(num_barriers, num_warps, num_arrivals)`, which sets `arrive_count = num_warps * 32 * num_arrivals` and marks the barrier as a warp barrier; `barrier_arrive` dispatches to `create_warp_barrier_arrive` for warp barriers
- `types.py` / `mem_ops.py`: propagate the `is_warp_barrier` flag through `mbarrier` / `mbarrier_type` and `local_view`
- `triton_tlx.cc`: expose `create_warp_barrier_arrive` to the dialect builder

Test Plan:
```
pytest python/test/unit/language/test_tlx.py::test_alloc_warp_barrier -xvs
pytest third_party/tlx/tutorials/testing/test_correctness.py::test_blackwell_gemm_ws_warp_barrier
pytest third_party/tlx/tutorials/testing/test_correctness.py::test_blackwell_gemm_clc_warp_barrier
pytest third_party/tlx/tutorials/testing/test_correctness.py::test_hopper_gemm_ws_warp_barrier
```

Performance:
- Blackwell
```
rm -rf ~/.triton/cache 2>/dev/null; CUDA_VISIBLE_DEVICES=2 bash ~/fbsource/fbcode/ads_mkl/benchmarks/denoise.sh python third_party/tlx/tutorials/testing/test_blackwell_gemm_perf.py --version ws ws_warp_barrier clc clc_warp_barrier
Running benchmarks for: ['ws', 'ws_warp_barrier', 'clc', 'clc_warp_barrier'] (dtype=fp16)
matmul-performance-fp16:
        M       N       K       cuBLAS           ws  ws_warp_barrier          clc  clc_warp_barrier
0  2048.0  2048.0  2048.0   836.247528   453.055644       453.055644   597.851777        598.518283
1  4096.0  4096.0  4096.0  1126.992231  1108.092714      1108.378675  1040.447562       1039.943710
2  8192.0  8192.0  8192.0  1133.648070  1110.420373      1109.022579   959.528030        962.134251
```
- Hopper
```
matmul-performance-fp16:
        M       N       K      cuBLAS          ws  ws_warp_barrier
0  2048.0  2048.0  2048.0  581.029130  488.064468       489.399203
1  4096.0  4096.0  4096.0  619.317567  561.213562       564.310512
2  8192.0  8192.0  8192.0  587.325877  555.362764       549.685437
```

Autotune Test:
- Blackwell

tlx_matmul_ws
```
TRITON_PRINT_AUTOTUNING=1 CUDA_VISIBLE_DEVICES=6 ~/fbsource/fbcode/triton/scripts/denoise.sh buck2 run mode/opt -m ovr_config//triton:beta -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 //pytorch/tritonbench:run_lite -- --op gemm --only tlx_matmul_ws,tlx_matmul_ws_warp_barrier --metrics accuracy,tflops --force --layout tt --rep 3000 --sleep 1.0 --input-loader 'pytorch/tritonbench/tritonbench/data/input_configs/fb/ads_omnifm_v5/gemm.json' --metrics accuracy,tflops,speedup
```
output: https://fburl.com/everpaste/yaro1kke

tlx_matmul_clc
```
TRITON_PRINT_AUTOTUNING=1 CUDA_VISIBLE_DEVICES=6 ~/fbsource/fbcode/triton/scripts/denoise.sh buck2 run mode/opt -m ovr_config//triton:beta -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 //pytorch/tritonbench:run_lite -- --op gemm --only tlx_matmul_clc,tlx_matmul_clc_warp_barrier --metrics accuracy,tflops --force --layout tt --rep 3000 --sleep 1.0 --input-loader 'pytorch/tritonbench/tritonbench/data/input_configs/fb/ads_omnifm_v5/gemm.json' --metrics accuracy,tflops,speedup
```
output: https://fburl.com/everpaste/drvps4oz

- Hopper

Reviewed By: htyu
Differential Revision: D95130632
Pulled By: tissue3
Summary: See facebookexperimental/triton#1031. This adds autotune to the benchmark for that change.
Differential Revision: D95353897