
Add alloc_warp_barrier for multi-thread barrier arrival #1031

Open
tissue3 wants to merge 1 commit into facebookexperimental:main from tissue3:pr1031

@tissue3 (Contributor) commented Mar 3, 2026

Summary:
Add alloc_warp_barrier to TLX, enabling barrier arrival in which every thread arrives independently instead of the default leader-thread pattern (__syncwarp + thread 0 arrive). This eliminates unnecessary warp synchronization and can improve performance when there is warp divergence.
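To make the arrival-count difference concrete, here is a toy, host-side Python model of an mbarrier's pending-arrival counter. It is not TLX, PTX, or hardware code: `ToyMBarrier`, `leader_thread_arrive`, and `per_thread_arrive` are hypothetical names, and the synchronization that precedes the leader's arrive is not modeled. It illustrates why a barrier that every thread arrives on must expect one arrival per thread.

```python
class ToyMBarrier:
    """Toy model of an mbarrier's arrival counter (not real TLX/PTX code)."""

    def __init__(self, arrive_count):
        self.arrive_count = arrive_count
        self.pending = arrive_count  # arrivals still needed in the current phase
        self.phase = 0               # parity bit, flips when a phase completes
        self.completed = 0           # number of phases completed so far

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:
            # Phase completes: waiters are released and the count resets.
            self.phase ^= 1
            self.completed += 1
            self.pending = self.arrive_count


def leader_thread_arrive(bar, num_threads):
    # Default pattern: threads sync first (not modeled), then only thread 0 arrives.
    for tid in range(num_threads):
        if tid == 0:
            bar.arrive()


def per_thread_arrive(bar, num_threads):
    # Warp-barrier pattern: no sync and no predicate; every thread arrives.
    for _tid in range(num_threads):
        bar.arrive()


leader = ToyMBarrier(arrive_count=1)
leader_thread_arrive(leader, num_threads=128)

warp_bar = ToyMBarrier(arrive_count=128)  # e.g. 4 warps * 32 threads
per_thread_arrive(warp_bar, num_threads=128)

# Mismatch: per-thread arrivals against a barrier sized for a single arrival.
mismatched = ToyMBarrier(arrive_count=1)
per_thread_arrive(mismatched, num_threads=128)

print(leader.completed, warp_bar.completed, mismatched.completed)  # 1 1 128
```

With a mismatched count, every per-thread arrival completes a phase on its own, which is why the warp-barrier allocation scales arrive_count by the number of arriving threads.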

Infrastructure changes:

  • ArriveBarrierOp in TritonNvidiaGPUOps.td: add perThread UnitAttr and new builder overload
  • BarrierOpToLLVM.cpp: when perThread is set, emit mbarrier.arrive without the leader-thread predicate (no threadIdx == 0 check), so all threads arrive independently
  • barrier.py: add alloc_warp_barrier(num_barriers, num_warps, num_arrivals) which sets arrive_count = num_warps * 32 * num_arrivals and marks the barrier as a warp barrier; barrier_arrive dispatches to create_warp_barrier_arrive for warp barriers
  • types.py / mem_ops.py: propagate is_warp_barrier flag through mbarrier / mbarrier_type and local_view
  • triton_tlx.cc: expose create_warp_barrier_arrive to the dialect builder
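A minimal Python sketch of the bookkeeping described in the barrier.py bullet, assuming `num_arrivals` defaults to 1 (the actual default is not stated here). `BarrierDesc`, `alloc_warp_barrier_desc`, and `barrier_arrive_builder` are hypothetical stand-ins, and `default_arrive` is a placeholder rather than the real builder name; only the arrive_count formula and the dispatch on `is_warp_barrier` come from the summary.

```python
from dataclasses import dataclass

THREADS_PER_WARP = 32


@dataclass
class BarrierDesc:
    """Hypothetical stand-in for the metadata carried by the mbarrier type."""
    num_barriers: int
    arrive_count: int
    is_warp_barrier: bool = False


def alloc_warp_barrier_desc(num_barriers, num_warps, num_arrivals=1):
    # Formula from the summary: arrive_count = num_warps * 32 * num_arrivals,
    # i.e. every thread of every participating warp arrives num_arrivals times.
    arrive_count = num_warps * THREADS_PER_WARP * num_arrivals
    return BarrierDesc(num_barriers, arrive_count, is_warp_barrier=True)


def barrier_arrive_builder(desc):
    # Dispatch described for barrier_arrive: warp barriers go through
    # create_warp_barrier_arrive; "default_arrive" is a placeholder name.
    if desc.is_warp_barrier:
        return "create_warp_barrier_arrive"
    return "default_arrive"


warp_bars = alloc_warp_barrier_desc(num_barriers=2, num_warps=4)
print(warp_bars.arrive_count, barrier_arrive_builder(warp_bars))  # 128 create_warp_barrier_arrive
```

For example, a barrier for a 4-warp partition with one arrival per thread expects 4 * 32 * 1 = 128 arrivals per phase.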

Test Plan:

pytest third_party/tlx/tutorials/testing/test_correctness.py::test_blackwell_gemm_ws
pytest python/test/unit/language/test_tlx.py::test_alloc_warp_barrier -xvs

Performance:

rm -rf ~/.triton/cache 2>/dev/null; CUDA_VISIBLE_DEVICES=2 bash ~/fbsource/fbcode/ads_mkl/benchmarks/denoise.sh python third_party/tlx/tutorials/testing/test_blackwell_gemm_perf.py  --version ws ws_warp_barrier clc clc_warp_barrier

Running benchmarks for: ['ws', 'ws_warp_barrier', 'clc', 'clc_warp_barrier'] (dtype=fp16)
matmul-performance-fp16:
        M       N       K       cuBLAS           ws  ws_warp_barrier          clc  clc_warp_barrier
0  2048.0  2048.0  2048.0   837.552136   453.055644       453.055644   598.518283        581.658636
1  4096.0  4096.0  4096.0  1109.237446  1090.923891      1090.923891  1040.447562       1040.447562
2  8192.0  8192.0  8192.0  1117.317214  1110.384480      1109.201629   829.104263        828.504511

GEMM shows no significant difference between alloc_barriers and alloc_warp_barrier; the results are within noise.
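One way to quantify "within noise" is the relative change of each warp-barrier column against its baseline. The numbers below are copied verbatim from the table above; units are whatever the benchmark script reports.

```python
# Columns: ws, ws_warp_barrier, clc, clc_warp_barrier (M = N = K = key).
results = {
    2048: (453.055644, 453.055644, 598.518283, 581.658636),
    4096: (1090.923891, 1090.923891, 1040.447562, 1040.447562),
    8192: (1110.384480, 1109.201629, 829.104263, 828.504511),
}


def rel_change(base, new):
    """Percent change of `new` relative to `base`."""
    return 100.0 * (new - base) / base


for size, (ws, ws_wb, clc, clc_wb) in results.items():
    print(f"{size}: ws {rel_change(ws, ws_wb):+.2f}%, clc {rel_change(clc, clc_wb):+.2f}%")
```

Every case lands between 0.00% and about -0.11%, except clc at 2048, which is roughly -2.8%; repeated runs would be needed to confirm whether that particular gap is noise.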

@meta-cla bot added the CLA Signed label Mar 3, 2026
@tissue3 changed the title from "multi-thread barrier arrival" to "Add alloc_warp_barrier for multi-thread barrier arrival" Mar 4, 2026
meta-codesync bot commented Mar 4, 2026

@tissue3 has imported this pull request. If you are a Meta employee, you can view this in D95130632.

@njriasan (Contributor) left a comment

@tissue3 This is a great initial implementation. However, I'm concerned that we aren't actually eliminating the synchronization; we are only eliminating the predicate.

Also my intuition is that this may have different impacts per shape depending on which partition is the bottleneck. Once we have confirmed the sync no longer occurs, let's expand this to a broader set of model shapes and run with TritonBench.

tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1)
tmem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1)
if USE_WARP_BARRIER:
    tmem_full_bars = tlx.alloc_warp_barrier(num_barriers=NUM_TMEM_BUFFERS, num_warps=1)
Contributor

If, after verifying that the sync looks as expected, we see a gain but only a minor one (say < 1%), we might want to try applying this to all of the barriers in the kernel. That will of course take more work, since it means supporting TMA as well.

ttgir = kernel.asm["ttgir"]
assert ttgir.count("ttng.init_barrier") == 1, f"Expected 1 init_barrier in TTGIR:\n{ttgir}"
assert ttgir.count("ttng.arrive_barrier") == 3, f"Expected 3 arrive_barrier in TTGIR:\n{ttgir}"
assert ttgir.count("perThread") == 3, f"Expected 3 perThread attrs in TTGIR:\n{ttgir}"
Contributor

I think these tests need to dump the llir to verify the lowering pattern is as expected.

Contributor

Yeah, or PTX? Especially to verify that the barrier sync isn't there.
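Following the suggestion to check the lowered code, here is a sketch of a helper that summarizes the PTX text of a compiled kernel: it counts CTA-wide `bar.sync` / `barrier.sync` instructions and splits `mbarrier.arrive` into predicated and unpredicated forms. `summarize_ptx` is hypothetical, and the regexes assume one instruction per line, as in PTX emitted by LLVM's NVPTX backend.

```python
import re

# CTA-wide barrier; LLVM may print either spelling. bar.warp.sync is not matched.
_CTA_SYNC = re.compile(r"^[ \t]*(bar|barrier)\.sync\b", re.MULTILINE)
# mbarrier.arrive*, optionally guarded by a predicate such as "@%p1".
_ARRIVE = re.compile(r"^[ \t]*(@!?%\w+[ \t]+)?mbarrier\.arrive\S*", re.MULTILINE)


def summarize_ptx(ptx: str) -> dict:
    """Count CTA-wide syncs and predicated vs. unpredicated mbarrier arrives."""
    arrives = [m.group(0) for m in _ARRIVE.finditer(ptx)]
    predicated = sum(1 for a in arrives if a.lstrip().startswith("@"))
    return {
        "cta_syncs": len(_CTA_SYNC.findall(ptx)),
        "arrives": len(arrives),
        "predicated_arrives": predicated,
    }
```

In a test, one could call `summarize_ptx(kernel.asm["ptx"])` for both the alloc_barriers and alloc_warp_barrier variants and assert that the warp-barrier variant has no predicated arrives and fewer CTA-wide syncs.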

}

// TODO: Add phase result as needed.
std::stringstream ptxAsm;
Contributor

Can we generate a lit test for the associated IR after lowering? I think we need to inspect the code that the lowering prepares for LLVM.

If I understand correctly, the barrier is only half the issue; we also need to ensure that the synchronization before it is removed. This is an example of the full IR from an AutoWS GEMM. It is not 1:1, but it should contain a pattern similar to the one we care about (and I have it available).

Here is the relevant code for the section that goes TMEM_LOAD -> barrier arrive:

nvvm.tcgen05.wait <load> loc(#loc53)
nvvm.barrier0 loc(#loc53)
%accumulator_1092 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc53)
%accumulator_1093 = llvm.sub %accumulator_1092, %102 : i32 loc(#loc53)
%accumulator_1094 = llvm.and %accumulator_1093, %94 : i32 loc(#loc53)
%accumulator_1095 = llvm.icmp "eq" %accumulator_1094, %109 : i32 loc(#loc53)
%accumulator_1096 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 mbarrier.arrive.shared::cta.b64 _, [$1];", "b,r" %accumulator_1095, %121 : (i1, !llvm.struct<(ptr<3>, i32)>) -> !llvm.void loc(#loc53)

As I understand it, this nvvm.barrier0 is responsible for the sync. So even though your example will eliminate %accumulator_1092 through %accumulator_1095 and simplify %accumulator_1096, to achieve the desired impact we will need to remove this barrier (or possibly reduce it to a sync between warps rather than between threads).
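The four instructions before the arrive implement leader election. Below is a minimal sketch of the same arithmetic, under the assumption (not stated in the IR excerpt) that `%102` is the partition's first thread id, `%94` is a lane mask, and `%109` is 0; `is_leader` and `leaders` are hypothetical helpers. The per-thread lowering drops this computation, but the `nvvm.barrier0` before it is a separate instruction.

```python
def is_leader(tid_x, base, mask, leader_id=0):
    # Same arithmetic as the IR: icmp eq (and (sub tid.x, base), mask), leader_id
    return ((tid_x - base) & mask) == leader_id


def leaders(tids, base, mask, leader_id=0):
    """Threads for which the arrive predicate would be true."""
    return [t for t in tids if is_leader(t, base, mask, leader_id)]


partition = range(128, 256)  # hypothetical 4-warp partition starting at thread 128
print(leaders(partition, base=128, mask=127))  # [128]: one leader per partition
print(leaders(partition, base=128, mask=31))   # [128, 160, 192, 224]: one per warp
```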

cc: @htyu in case you have looked at this in more detail.

Contributor

I guess the membar pass is the culprit. It's designed to protect shared memory accesses with bar.sync to make sure all threads see exactly the same data. Barriers could be an exception.

@tissue3 (Contributor, Author), Mar 4, 2026

@njriasan @htyu Thanks for reviewing! I tried removing that nvvm.barrier0 but still got only a trivial perf difference:

matmul-performance-fp16:
        M       N       K       cuBLAS           ws  ws_warp_barrier          clc  clc_warp_barrier
0  2048.0  2048.0  2048.0   838.860821   453.055644       453.055644   623.543467        622.098396
1  4096.0  4096.0  4096.0  1109.237446  1090.923891      1090.923891  1040.195516       1040.447562
2  8192.0  8192.0  8192.0  1114.996704  1106.807677      1104.601605   920.826955        921.753365

Is it because I only enabled per-thread sync, not per buffer?

Contributor

Neutral perf is fine. Maybe other cases can benefit from this.

Can you check whether some bar.sync ops go away in the PTX?

@htyu (Contributor) commented Mar 4, 2026

Thanks for working on this!

Quoting @njriasan: "Also my intuition is that this may have different impacts per shape depending on which partition is the bottleneck. Once we have confirmed the sync no longer occurs, let's expand this to a broader set of model shapes and run with TritonBench."

Right, the MemWrite<SharedMemory> attribute on the arrive op may be the source of the bar.sync. Check membar.cpp. We may need to tweak that pass to relax it for the perThread=true case.
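A toy Python model of the kind of relaxation being discussed: a simplified hazard analysis over a single shared-memory resource that inserts `bar.sync` before an access forming a read-after-write, write-after-write, or write-after-read hazard, and, when relaxed, stops treating a per-thread arrive as a shared-memory write. This is not the algorithm in membar.cpp, which tracks individual buffers and intervals; `Op` and `insert_barriers` are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Op:
    name: str
    reads_smem: bool = False
    writes_smem: bool = False
    per_thread_arrive: bool = False  # e.g. an arrive_barrier with perThread set


def insert_barriers(ops, relax_per_thread=True):
    """Toy barrier insertion over one shared-memory resource."""
    out = []
    pending_write = pending_read = False  # unsynchronized accesses so far
    for op in ops:
        # Relaxation: a per-thread arrive is not treated as an SMEM write.
        writes = op.writes_smem and not (relax_per_thread and op.per_thread_arrive)
        reads = op.reads_smem
        # RAW, WAW, and WAR hazards require a CTA-wide barrier first.
        if (reads and pending_write) or (writes and (pending_write or pending_read)):
            out.append("bar.sync")
            pending_write = pending_read = False
        out.append(op.name)
        pending_write = pending_write or writes
        pending_read = pending_read or reads
    return out


program = [
    Op("smem_store", writes_smem=True),
    Op("arrive", writes_smem=True, per_thread_arrive=True),
]
print(insert_barriers(program, relax_per_thread=False))  # ['smem_store', 'bar.sync', 'arrive']
print(insert_barriers(program, relax_per_thread=True))   # ['smem_store', 'arrive']
```

With the relaxation, the arrive no longer forms a write-after-write hazard with the preceding store, so no CTA-wide barrier is inserted before it.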

I also wonder whether a validation pass makes sense, checking that the num_warps declared in alloc_warp_barrier matches the num_warps of the region that actually performs the arrive.
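One possible shape for such a check, as a Python sketch: compare the arrivals a region produces per phase against the count the barrier was sized for. `check_warp_barrier_arrive` is hypothetical and assumes every thread in the region arrives `num_arrivals` times per phase.

```python
THREADS_PER_WARP = 32


def check_warp_barrier_arrive(declared_num_warps, region_num_warps, num_arrivals=1):
    """Return the expected per-phase arrival count, or raise on a mismatch."""
    expected = declared_num_warps * THREADS_PER_WARP * num_arrivals
    produced = region_num_warps * THREADS_PER_WARP * num_arrivals
    if produced < expected:
        raise ValueError(
            f"region produces {produced} arrivals but the barrier expects "
            f"{expected}: one round of arrivals would not complete the phase")
    if produced > expected:
        raise ValueError(
            f"region produces {produced} arrivals but the barrier expects "
            f"{expected}: extra arrivals would spill into later phases")
    return expected


print(check_warp_barrier_arrive(declared_num_warps=4, region_num_warps=4))  # 128
```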

// For perThread arrives, each thread's own program order guarantees its
// SMEM ops complete before its arrive, and the mbarrier accumulates all
// arrivals before releasing the waiter, so no CTA-wide fence is needed.
if (auto arriveOp = dyn_cast<triton::nvidia_gpu::ArriveBarrierOp>(op)) {
Contributor

Looks good. Does it affect performance?


Labels: CLA Signed

3 participants