Skip to content

Commit f0d855e

Browse files
committed
multi-thread barrier arrival
1 parent ce6394b commit f0d855e

File tree

14 files changed

+313
-50
lines changed

14 files changed

+313
-50
lines changed

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,8 @@ def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier"> {
280280
let arguments = (ins
281281
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>, MemWrite<SharedMemory>]>:$alloc,
282282
I32Attr:$count,
283-
Optional<I1>:$pred
283+
Optional<I1>:$pred,
284+
UnitAttr:$perThread
284285
);
285286

286287
let assemblyFormat = [{
@@ -289,7 +290,10 @@ def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier"> {
289290

290291
let builders = [
291292
OpBuilder<(ins "Value":$alloc, "uint32_t":$count), [{
292-
return build($_builder, $_state, alloc, count, /*pred=*/Value());
293+
return build($_builder, $_state, alloc, count, /*pred=*/Value(), /*perThread=*/false);
294+
}]>,
295+
OpBuilder<(ins "Value":$alloc, "uint32_t":$count, "bool":$perThread), [{
296+
return build($_builder, $_state, alloc, count, /*pred=*/Value(), perThread);
293297
}]>
294298
];
295299

lib/Analysis/Membar.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,15 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
250250
}
251251
// If this op may be signalling other threads asynchronously, make sure
252252
// all shared memory transactions are complete beforehand.
253-
if (isa<triton::nvidia_gpu::ArriveBarrierOp>(op)) {
254-
Interval<size_t> allIntervals(0, std::numeric_limits<size_t>::max());
255-
curBlockInfo.syncWriteIntervals[allIntervals].insert(op);
256-
curBlockInfo.syncReadIntervals[allIntervals].insert(op);
253+
// For perThread arrives, each thread's own program order guarantees its
254+
// SMEM ops complete before its arrive, and the mbarrier accumulates all
255+
// arrivals before releasing the waiter, so no CTA-wide fence is needed.
256+
if (auto arriveOp = dyn_cast<triton::nvidia_gpu::ArriveBarrierOp>(op)) {
257+
if (!arriveOp.getPerThread()) {
258+
Interval<size_t> allIntervals(0, std::numeric_limits<size_t>::max());
259+
curBlockInfo.syncWriteIntervals[allIntervals].insert(op);
260+
curBlockInfo.syncReadIntervals[allIntervals].insert(op);
261+
}
257262
}
258263
scratchBufferId = allocation->getBufferId(op);
259264
}

python/test/unit/language/test_tlx.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,7 +1302,7 @@ def test_async_dot_blackwell_2cta_tma(device):
13021302
with pytest.raises(Exception) as e:
13031303
run_async_dot_blackwell_2cta_tma(device, False, 128)
13041304
assert isinstance(e.value, triton.CompilationError), "expecting a compilation error"
1305-
assert 'only supports M=128 per CTA for pair-CTA mma' in e.value.error_message
1305+
assert "only supports M=128 per CTA for pair-CTA mma" in e.value.error_message
13061306

13071307

13081308
def run_async_dot_blackwell_2cta_tma(device, A_TMEM, SAMPLE_M):
@@ -2703,6 +2703,84 @@ def test_wait_arrive_ws(BLOCK_SIZE, device):
27032703
and (ttgir.count("default {") == 1) and (ttgir.count("partition0") == 1)), f"TTGIR {ttgir}"
27042704

27052705

2706+
@triton.jit
2707+
def tlx_square_warp_barrier(
2708+
x_ptr,
2709+
z_ptr,
2710+
n_elements,
2711+
BLOCK_SIZE: tl.constexpr,
2712+
NUM_WARPS: tl.constexpr,
2713+
):
2714+
"""
2715+
Test warp barrier: all threads arrive independently (no leader pattern).
2716+
Uses alloc_warp_barrier instead of alloc_barriers.
2717+
"""
2718+
pid = tl.program_id(axis=0)
2719+
block_start = pid * BLOCK_SIZE
2720+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
2721+
mask = offsets < n_elements
2722+
2723+
bars = tlx.alloc_warp_barrier(num_barriers=1, num_warps=NUM_WARPS)
2724+
bar = tlx.local_view(bars, 0)
2725+
2726+
x = tl.load(x_ptr + offsets, mask=mask)
2727+
2728+
p = 0
2729+
tlx.barrier_arrive(bar=bar)
2730+
tlx.barrier_wait(bar=bar, phase=p)
2731+
2732+
z = x * x
2733+
2734+
p = p ^ 1
2735+
tlx.barrier_arrive(bar=bar)
2736+
tlx.barrier_wait(bar=bar, phase=p)
2737+
2738+
tl.store(z_ptr + offsets, z, mask=mask)
2739+
2740+
p = p ^ 1
2741+
tlx.barrier_arrive(bar=bar)
2742+
tlx.barrier_wait(bar=bar, phase=0)
2743+
2744+
2745+
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
2746+
@pytest.mark.parametrize("BLOCK_SIZE", [(1024)])
2747+
@pytest.mark.parametrize("num_warps", [4])
2748+
def test_alloc_warp_barrier(BLOCK_SIZE, num_warps, device):
2749+
torch.manual_seed(0)
2750+
size = 98432
2751+
x = torch.rand(size, device=device)
2752+
z = torch.empty_like(x)
2753+
n_elements = x.numel()
2754+
2755+
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )
2756+
kernel = tlx_square_warp_barrier[grid](
2757+
x,
2758+
z,
2759+
n_elements,
2760+
BLOCK_SIZE,
2761+
num_warps,
2762+
num_warps=num_warps,
2763+
)
2764+
2765+
z_ref = x * x
2766+
torch.testing.assert_close(z, z_ref, check_dtype=False)
2767+
2768+
# Verify IR uses arrive_barrier with perThread attribute
2769+
ttgir = kernel.asm["ttgir"]
2770+
assert ttgir.count("ttng.init_barrier") == 1, f"Expected 1 init_barrier in TTGIR:\n{ttgir}"
2771+
assert ttgir.count("ttng.arrive_barrier") == 3, f"Expected 3 arrive_barrier in TTGIR:\n{ttgir}"
2772+
assert ttgir.count("perThread") == 3, f"Expected 3 perThread attrs in TTGIR:\n{ttgir}"
2773+
assert ttgir.count("ttng.wait_barrier") == 3, f"Expected 3 wait_barrier in TTGIR:\n{ttgir}"
2774+
2775+
# Verify LLIR: perThread arrives use per-thread lowering (no leader predicate)
2776+
llir = kernel.asm["llir"]
2777+
# Per-thread arrive emits unpredicated: mbarrier.arrive.shared::cta.b64 _, [$0]
2778+
assert "mbarrier.arrive.shared::cta.b64 _, [$0]" in llir, (
2779+
f"Expected unpredicated per-thread mbarrier.arrive in LLIR:\n{llir}")
2780+
# Leader pattern would emit predicated: @$0 mbarrier.arrive
2781+
assert "@$0 mbarrier.arrive" not in llir, f"Unexpected leader-predicated mbarrier.arrive in LLIR:\n{llir}"
2782+
2783+
27062784
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
27072785
def test_barrier_live_range(device):
27082786

@@ -6382,13 +6460,13 @@ def bulk_copy_kernel(
63826460
ttgir = kernel.asm["ttgir"]
63836461
assert "ttg.async_copy_global_to_local" in ttgir, "Expected async_copy_global_to_local in TTGIR"
63846462
assert "useBulk = true" in ttgir, "Expected useBulk = true in TTGIR"
6385-
assert "ttng.async_store" in ttgir, ("Expected async_store in TTGIR")
6463+
assert "ttng.async_store" in ttgir, "Expected async_store in TTGIR"
63866464

63876465
# Verify PTX contains the bulk copy instructions
63886466
ptx = kernel.asm["ptx"]
63896467
assert "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes" in ptx, (
63906468
"Expected cp.async.bulk gmem->smem in PTX")
6391-
assert "cp.async.bulk.global.shared::cta.bulk_group" in ptx, ("Expected cp.async.bulk smem->gmem in PTX")
6469+
assert "cp.async.bulk.global.shared::cta.bulk_group" in ptx, "Expected cp.async.bulk smem->gmem in PTX"
63926470

63936471
# Verify correctness
63946472
torch.testing.assert_close(src, dst)

test/Analysis/test-membar-ttng.mlir

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,52 @@ module attributes {"ttg.num-warps" = 4 : i32} {
194194
tt.return
195195
}
196196
}
197+
198+
// -----
199+
200+
// Verify that a perThread arrive after a shared memory write does NOT get a
201+
// gpu.barrier inserted before it. The perThread attribute opts out of the
202+
// CTA-wide fence because each thread's program order guarantees its own SMEM
203+
// ops complete before its arrive.
204+
205+
#shared_pt = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
206+
#blocked_pt = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
207+
#A_SHARED_pt = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
208+
209+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 1024 : i32} {
210+
// CHECK-LABEL: @no_barrier_before_perthread_arrive
211+
tt.func @no_barrier_before_perthread_arrive(%arg: tensor<32x16xf16, #blocked_pt>) {
212+
%alloc = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED_pt, #ttg.shared_memory, mutable>
213+
%barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_pt, #ttg.shared_memory, mutable>
214+
// CHECK: ttg.local_store
215+
// CHECK-NEXT: ttng.arrive_barrier
216+
// CHECK-NOT: gpu.barrier
217+
// CHECK: tt.return
218+
ttg.local_store %arg, %alloc : tensor<32x16xf16, #blocked_pt> -> !ttg.memdesc<32x16xf16, #A_SHARED_pt, #ttg.shared_memory, mutable>
219+
ttng.arrive_barrier %barrier, 1 {perThread} : !ttg.memdesc<1xi64, #shared_pt, #ttg.shared_memory, mutable>
220+
tt.return
221+
}
222+
}
223+
224+
// -----
225+
226+
// Verify that a regular (non-perThread) arrive after a shared memory write
227+
// DOES get a gpu.barrier inserted before it (existing behavior preserved).
228+
229+
#shared_reg = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
230+
#blocked_reg = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
231+
#A_SHARED_reg = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
232+
233+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 1024 : i32} {
234+
// CHECK-LABEL: @barrier_before_regular_arrive
235+
tt.func @barrier_before_regular_arrive(%arg: tensor<32x16xf16, #blocked_reg>) {
236+
%alloc = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED_reg, #ttg.shared_memory, mutable>
237+
%barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_reg, #ttg.shared_memory, mutable>
238+
// CHECK: ttg.local_store
239+
// CHECK-NEXT: gpu.barrier
240+
// CHECK-NEXT: ttng.arrive_barrier
241+
ttg.local_store %arg, %alloc : tensor<32x16xf16, #blocked_reg> -> !ttg.memdesc<32x16xf16, #A_SHARED_reg, #ttg.shared_memory, mutable>
242+
ttng.arrive_barrier %barrier, 1 : !ttg.memdesc<1xi64, #shared_reg, #ttg.shared_memory, mutable>
243+
tt.return
244+
}
245+
}

test/Conversion/tritonnvidiagpu_to_llvm.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
6868
tt.return
6969
}
7070

71+
// CHECK-LABEL: arrive_barrier_per_thread
72+
tt.func @arrive_barrier_per_thread(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>) {
73+
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
74+
// CHECK-NOT: llvm.icmp "eq"
75+
// CHECK: "mbarrier.arrive.shared::cta.b64 _, [$0], 2;", "r" %arg0
76+
ttng.arrive_barrier %alloc, 2 {perThread} : !ttg.memdesc<1xi64, #shared0, #smem>
77+
tt.return
78+
}
79+
7180
// CHECK-LABEL: arrive_barrier_named
7281
tt.func @arrive_barrier_named(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) {
7382
%c9_i32 = arith.constant 9 : i32

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -254,40 +254,69 @@ struct ArriveBarrierOpConversion
254254
LogicalResult
255255
matchAndRewrite(triton::nvidia_gpu::ArriveBarrierOp op, OpAdaptor adaptor,
256256
ConversionPatternRewriter &rewriter) const override {
257+
bool isPerThread = op.getPerThread();
258+
257259
bool isRemoteBarrier = false;
258260
if (auto barType = dyn_cast<ttg::MemDescType>(op.getAlloc().getType())) {
259261
isRemoteBarrier =
260262
isa<ttng::SharedClusterMemorySpaceAttr>(barType.getMemorySpace());
261263
}
262264

263-
// TODO: Add phase result as needed.
264-
std::stringstream ptxAsm;
265-
ptxAsm << "@$0 mbarrier.arrive.shared::";
266-
if (isRemoteBarrier)
267-
ptxAsm << "cluster";
268-
else
269-
ptxAsm << "cta";
270-
ptxAsm << ".b64 _, [$1]";
271-
if (op.getCount() > 1) {
272-
ptxAsm << ", " << op.getCount();
273-
}
274-
ptxAsm << ";";
275-
276-
TritonLLVMOpBuilder b(op.getLoc(), rewriter);
277-
Value id = getThreadId(rewriter, op.getLoc());
278-
Value pred = b.icmp_eq(id, b.i32_val(0));
279-
if (op.getPred())
280-
pred = b.and_(pred, adaptor.getPred());
265+
if (isPerThread) {
266+
// Warp arrive: every thread arrives independently, no leader pattern.
267+
bool hasPred = !!op.getPred();
268+
std::stringstream ptxAsm;
269+
if (hasPred) {
270+
ptxAsm << "@$0 ";
271+
}
272+
ptxAsm << "mbarrier.arrive.shared::cta.b64 _, ["
273+
<< (hasPred ? "$1" : "$0") << "]";
274+
if (op.getCount() > 1) {
275+
ptxAsm << ", " << op.getCount();
276+
}
277+
ptxAsm << ";";
281278

282-
PTXBuilder ptxBuilder;
283-
SmallVector<PTXBuilder::Operand *, 2> operands = {
284-
ptxBuilder.newOperand(pred, "b"),
285-
ptxBuilder.newOperand(adaptor.getAlloc(), "r")};
279+
PTXBuilder ptxBuilder;
280+
SmallVector<PTXBuilder::Operand *, 2> operands;
281+
if (hasPred) {
282+
operands.push_back(ptxBuilder.newOperand(adaptor.getPred(), "b"));
283+
}
284+
operands.push_back(ptxBuilder.newOperand(adaptor.getAlloc(), "r"));
286285

287-
auto arriveOp = *ptxBuilder.create<>(ptxAsm.str());
288-
arriveOp(operands, /*onlyAttachMLIRArgs=*/true);
289-
auto voidTy = void_ty(getContext());
290-
ptxBuilder.launch(rewriter, op.getLoc(), voidTy);
286+
auto arriveOp = *ptxBuilder.create<>(ptxAsm.str());
287+
arriveOp(operands, /*onlyAttachMLIRArgs=*/true);
288+
auto voidTy = void_ty(getContext());
289+
ptxBuilder.launch(rewriter, op.getLoc(), voidTy);
290+
} else {
291+
// Leader pattern: only thread 0 arrives.
292+
std::stringstream ptxAsm;
293+
ptxAsm << "@$0 mbarrier.arrive.shared::";
294+
if (isRemoteBarrier)
295+
ptxAsm << "cluster";
296+
else
297+
ptxAsm << "cta";
298+
ptxAsm << ".b64 _, [$1]";
299+
if (op.getCount() > 1) {
300+
ptxAsm << ", " << op.getCount();
301+
}
302+
ptxAsm << ";";
303+
304+
TritonLLVMOpBuilder b(op.getLoc(), rewriter);
305+
Value id = getThreadId(rewriter, op.getLoc());
306+
Value pred = b.icmp_eq(id, b.i32_val(0));
307+
if (op.getPred())
308+
pred = b.and_(pred, adaptor.getPred());
309+
310+
PTXBuilder ptxBuilder;
311+
SmallVector<PTXBuilder::Operand *, 2> operands = {
312+
ptxBuilder.newOperand(pred, "b"),
313+
ptxBuilder.newOperand(adaptor.getAlloc(), "r")};
314+
315+
auto arriveOp = *ptxBuilder.create<>(ptxAsm.str());
316+
arriveOp(operands, /*onlyAttachMLIRArgs=*/true);
317+
auto voidTy = void_ty(getContext());
318+
ptxBuilder.launch(rewriter, op.getLoc(), voidTy);
319+
}
291320

292321
rewriter.eraseOp(op);
293322
return success();

third_party/tlx/dialect/triton_tlx.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,12 @@ void init_triton_tlx_ir(py::module &&m) {
320320
[](TritonOpBuilder &self, Value mbarrerLoc, int arriveCount) -> void {
321321
self.create<ttng::ArriveBarrierOp>(mbarrerLoc, arriveCount);
322322
})
323+
.def("create_warp_barrier_arrive",
324+
[](TritonOpBuilder &self, Value mbarrierLoc,
325+
int arriveCount) -> void {
326+
self.create<ttng::ArriveBarrierOp>(mbarrierLoc, arriveCount,
327+
/*perThread=*/true);
328+
})
323329
.def("create_named_barrier_wait",
324330
[](TritonOpBuilder &self, Value barrier, Value numThreads) -> void {
325331
self.create<ttng::NamedBarrierWaitOp>(barrier, numThreads);

third_party/tlx/language/tlx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .async_task_utils import async_task, async_tasks
33
from .barrier import (
44
alloc_barriers,
5+
alloc_warp_barrier,
56
barrier_arrive,
67
barrier_expect_bytes,
78
barrier_wait,
@@ -138,6 +139,7 @@
138139
# barriers
139140
"cluster_barrier",
140141
"alloc_barriers",
142+
"alloc_warp_barrier",
141143
"barrier_expect_bytes",
142144
"barrier_wait",
143145
"barrier_arrive",

0 commit comments

Comments
 (0)