Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier"> {
let arguments = (ins
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>, MemWrite<SharedMemory>]>:$alloc,
I32Attr:$count,
Optional<I1>:$pred
Optional<I1>:$pred,
UnitAttr:$perThread
);

let assemblyFormat = [{
Expand All @@ -289,7 +290,10 @@ def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier"> {

let builders = [
OpBuilder<(ins "Value":$alloc, "uint32_t":$count), [{
return build($_builder, $_state, alloc, count, /*pred=*/Value());
return build($_builder, $_state, alloc, count, /*pred=*/Value(), /*perThread=*/false);
}]>,
OpBuilder<(ins "Value":$alloc, "uint32_t":$count, "bool":$perThread), [{
return build($_builder, $_state, alloc, count, /*pred=*/Value(), perThread);
}]>
];

Expand Down
58 changes: 36 additions & 22 deletions lib/Analysis/Membar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,32 +228,46 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
curBlockInfo = funcBlockInfoMap->lookup(callee);
} else {
// Intra-function dependencies
if (auto memoryEffectOpInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
// Explicit buffer
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>>
effectInstances;
memoryEffectOpInterface.getEffects(effectInstances);
for (auto effectInstance : effectInstances) {
if (auto value = effectInstance.getValue()) {
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
auto interval = allocation->getAllocatedInterval(bufferId);
interval = narrowIntervalForSubview(value, interval);
if (isa<MemoryEffects::Write>(effectInstance.getEffect()))
curBlockInfo.syncWriteIntervals[interval].insert(op);
else if (isa<MemoryEffects::Read>(effectInstance.getEffect()))
curBlockInfo.syncReadIntervals[interval].insert(op);
//
// For perThread ArriveBarrierOp, skip all SMEM hazard tracking.
// mbarrier.arrive has release semantics and mbarrier.wait has acquire
// semantics, so no CTA-wide bar.sync is needed before a perThread arrive.
// Each thread's program order guarantees its own SMEM ops are visible
// before its arrive, and the mbarrier accumulates all arrivals before
// releasing the waiter.
bool isPerThreadArrive = false;
if (auto arriveOp = dyn_cast<triton::nvidia_gpu::ArriveBarrierOp>(op))
isPerThreadArrive = arriveOp.getPerThread();

if (!isPerThreadArrive) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make sure we run an accuracy test as well. I'm not sure if the general case needs to be conservative and ensure with multiple warps there is a risk that the same warp can contribute multiple arrivals. However, this practically shouldn't happen, so worst case we can add a TODO to test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did do some accuracy test (see my test plan), but I am not sure if that is enough. Please let me know if there is anything other than these tests to check on

pytest third_party/tlx/tutorials/testing/test_correctness.py::test_blackwell_gemm_ws_warp_barrier
pytest third_party/tlx/tutorials/testing/test_correctness.py::test_blackwell_gemm_clc_warp_barrier

if (auto memoryEffectOpInterface =
dyn_cast<MemoryEffectOpInterface>(op)) {
// Explicit buffer
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>>
effectInstances;
memoryEffectOpInterface.getEffects(effectInstances);
for (auto effectInstance : effectInstances) {
if (auto value = effectInstance.getValue()) {
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
auto interval = allocation->getAllocatedInterval(bufferId);
interval = narrowIntervalForSubview(value, interval);
if (isa<MemoryEffects::Write>(effectInstance.getEffect()))
curBlockInfo.syncWriteIntervals[interval].insert(op);
else if (isa<MemoryEffects::Read>(effectInstance.getEffect()))
curBlockInfo.syncReadIntervals[interval].insert(op);
}
}
}
}
}
}
// If this op is may be signalling other threads asynchronously, make sure
// all shared memory transactions are complete beforehand.
if (isa<triton::nvidia_gpu::ArriveBarrierOp>(op)) {
Interval<size_t> allIntervals(0, std::numeric_limits<size_t>::max());
curBlockInfo.syncWriteIntervals[allIntervals].insert(op);
curBlockInfo.syncReadIntervals[allIntervals].insert(op);
// If this op may be signalling other threads asynchronously, make sure
// all shared memory transactions are complete beforehand.
if (isa<triton::nvidia_gpu::ArriveBarrierOp>(op)) {
Interval<size_t> allIntervals(0, std::numeric_limits<size_t>::max());
curBlockInfo.syncWriteIntervals[allIntervals].insert(op);
curBlockInfo.syncReadIntervals[allIntervals].insert(op);
}
}
scratchBufferId = allocation->getBufferId(op);
}
Expand Down
103 changes: 100 additions & 3 deletions python/test/unit/language/test_tlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,7 +1302,7 @@ def test_async_dot_blackwell_2cta_tma(device):
with pytest.raises(Exception) as e:
run_async_dot_blackwell_2cta_tma(device, False, 128)
assert isinstance(e.value, triton.CompilationError), "expecting a compilation error"
assert 'only supports M=128 per CTA for pair-CTA mma' in e.value.error_message
assert "only supports M=128 per CTA for pair-CTA mma" in e.value.error_message


def run_async_dot_blackwell_2cta_tma(device, A_TMEM, SAMPLE_M):
Expand Down Expand Up @@ -2703,6 +2703,103 @@ def test_wait_arrive_ws(BLOCK_SIZE, device):
and (ttgir.count("default {") == 1) and (ttgir.count("partition0") == 1)), f"TTGIR {ttgir}"


@triton.jit
def tlx_square_warp_barrier(
x_ptr,
z_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
NUM_WARPS: tl.constexpr,
):
"""
Warp-specialized kernel demonstrating perThread barrier arrives with SMEM.
Producer loads global → stores SMEM → arrives (perThread, no bar.sync).
Consumer waits → loads SMEM → computes z=x*x → stores global → arrives.

This mirrors the GEMM epilogue pattern where local_load from shared memory
is followed by barrier_arrive to signal the buffer is consumed.
"""
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE

# Warp barriers: each thread arrives independently (no leader sync)
bars = tlx.alloc_warp_barrier(num_barriers=2, num_warps=NUM_WARPS)
b0 = tlx.local_view(bars, 0)
b1 = tlx.local_view(bars, 1)

# Shared memory buffer for producer-consumer data transfer
buf = tlx.local_alloc((BLOCK_SIZE, ), tl.float32, 1)
smem = tlx.local_view(buf, 0)

phase = 0
with tlx.async_tasks():
with tlx.async_task("default"):
tlx.barrier_wait(bar=b1, phase=phase ^ 1)
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

# Producer: load from global, store to SMEM
x = tl.load(x_ptr + offsets, mask=mask)
tlx.local_store(smem, x)
# KEY PATTERN: SMEM write → perThread arrive (no bar.sync)
tlx.barrier_arrive(bar=b0)

with tlx.async_task(num_warps=4):
tlx.barrier_wait(bar=b0, phase=phase)

offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# Consumer: load from SMEM, compute, store to global
data = tlx.local_load(smem)
z = data * data
tl.store(z_ptr + offsets, z, mask=mask)
# KEY PATTERN: SMEM read → perThread arrive (no bar.sync)
tlx.barrier_arrive(bar=b0)


@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
@pytest.mark.parametrize("BLOCK_SIZE", [(1024)])
@pytest.mark.parametrize("num_warps", [4])
def test_alloc_warp_barrier(BLOCK_SIZE, num_warps, device):
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device=device)
z = torch.empty_like(x)
n_elements = x.numel()

grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )
kernel = tlx_square_warp_barrier[grid](
x,
z,
n_elements,
BLOCK_SIZE,
num_warps,
num_warps=num_warps,
)

z_ref = x * x
torch.testing.assert_close(z, z_ref, check_dtype=False)

# Verify TTGIR: warp-specialized with perThread arrives
ttgir = kernel.asm["ttgir"]
assert "perThread" in ttgir, f"Expected perThread attrs in TTGIR:\n{ttgir}"
assert "ttng.arrive_barrier" in ttgir, f"Expected arrive_barrier in TTGIR:\n{ttgir}"

# Verify LLIR: perThread arrives use per-thread lowering (no leader predicate)
llir = kernel.asm["llir"]
# Per-thread arrive emits unpredicated: mbarrier.arrive.shared::cta.b64 _, [$0]
assert "mbarrier.arrive.shared::cta.b64 _, [$0]" in llir, (
f"Expected unpredicated per-thread mbarrier.arrive in LLIR:\n{llir}")
# Leader pattern would emit predicated: @$0 mbarrier.arrive
assert "@$0 mbarrier.arrive" not in llir, f"Unexpected leader-predicated mbarrier.arrive in LLIR:\n{llir}"
# No bar.sync immediately before mbarrier.arrive (membar pass should skip
# perThread arrives for both full-range and per-buffer SMEM hazards).
# Other bar.sync may exist (e.g. before wait_barrier) — that's fine.

assert not re.search(r"barrier\.cta\.sync.*\n.*mbarrier\.arrive",
llir), (f"Unexpected bar.sync before mbarrier.arrive in LLIR:\n{llir}")


@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
def test_barrier_live_range(device):

Expand Down Expand Up @@ -6382,13 +6479,13 @@ def bulk_copy_kernel(
ttgir = kernel.asm["ttgir"]
assert "ttg.async_copy_global_to_local" in ttgir, "Expected async_copy_global_to_local in TTGIR"
assert "useBulk = true" in ttgir, "Expected useBulk = true in TTGIR"
assert "ttng.async_store" in ttgir, ("Expected async_store in TTGIR")
assert "ttng.async_store" in ttgir, "Expected async_store in TTGIR"

# Verify PTX contains the bulk copy instructions
ptx = kernel.asm["ptx"]
assert "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes" in ptx, (
"Expected cp.async.bulk gmem->smem in PTX")
assert "cp.async.bulk.global.shared::cta.bulk_group" in ptx, ("Expected cp.async.bulk smem->gmem in PTX")
assert "cp.async.bulk.global.shared::cta.bulk_group" in ptx, "Expected cp.async.bulk smem->gmem in PTX"

# Verify correctness
torch.testing.assert_close(src, dst)
Expand Down
49 changes: 49 additions & 0 deletions test/Analysis/test-membar-ttng.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,52 @@ module attributes {"ttg.num-warps" = 4 : i32} {
tt.return
}
}

// -----

// Verify that a perThread arrive after a shared memory write does NOT get a
// gpu.barrier inserted before it. The perThread attribute opts out of the
// CTA-wide fence because each thread's program order guarantees its own SMEM
// ops complete before its arrive.

#shared_pt = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#blocked_pt = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#A_SHARED_pt = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 1024 : i32} {
// CHECK-LABEL: @no_barrier_before_perthread_arrive
tt.func @no_barrier_before_perthread_arrive(%arg: tensor<32x16xf16, #blocked_pt>) {
%alloc = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED_pt, #ttg.shared_memory, mutable>
%barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_pt, #ttg.shared_memory, mutable>
// CHECK: ttg.local_store
// CHECK-NEXT: ttng.arrive_barrier
// CHECK-NOT: gpu.barrier
// CHECK: tt.return
ttg.local_store %arg, %alloc : tensor<32x16xf16, #blocked_pt> -> !ttg.memdesc<32x16xf16, #A_SHARED_pt, #ttg.shared_memory, mutable>
ttng.arrive_barrier %barrier, 1 {perThread} : !ttg.memdesc<1xi64, #shared_pt, #ttg.shared_memory, mutable>
tt.return
}
}

// -----

// Verify that a regular (non-perThread) arrive after a shared memory write
// DOES get a gpu.barrier inserted before it (existing behavior preserved).

#shared_reg = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#blocked_reg = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#A_SHARED_reg = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 1024 : i32} {
// CHECK-LABEL: @barrier_before_regular_arrive
tt.func @barrier_before_regular_arrive(%arg: tensor<32x16xf16, #blocked_reg>) {
%alloc = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED_reg, #ttg.shared_memory, mutable>
%barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_reg, #ttg.shared_memory, mutable>
// CHECK: ttg.local_store
// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: ttng.arrive_barrier
ttg.local_store %arg, %alloc : tensor<32x16xf16, #blocked_reg> -> !ttg.memdesc<32x16xf16, #A_SHARED_reg, #ttg.shared_memory, mutable>
ttng.arrive_barrier %barrier, 1 : !ttg.memdesc<1xi64, #shared_reg, #ttg.shared_memory, mutable>
tt.return
}
}
9 changes: 9 additions & 0 deletions test/Conversion/tritonnvidiagpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
tt.return
}

// CHECK-LABEL: arrive_barrier_per_thread
tt.func @arrive_barrier_per_thread(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>) {
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
// CHECK-NOT: llvm.icmp "eq"
// CHECK: "mbarrier.arrive.shared::cta.b64 _, [$0], 2;", "r" %arg0
ttng.arrive_barrier %alloc, 2 {perThread} : !ttg.memdesc<1xi64, #shared0, #smem>
tt.return
}

// CHECK-LABEL: arrive_barrier_named
tt.func @arrive_barrier_named(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) {
%c9_i32 = arith.constant 9 : i32
Expand Down
81 changes: 55 additions & 26 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,40 +254,69 @@ struct ArriveBarrierOpConversion
LogicalResult
matchAndRewrite(triton::nvidia_gpu::ArriveBarrierOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool isPerThread = op.getPerThread();

bool isRemoteBarrier = false;
if (auto barType = dyn_cast<ttg::MemDescType>(op.getAlloc().getType())) {
isRemoteBarrier =
isa<ttng::SharedClusterMemorySpaceAttr>(barType.getMemorySpace());
}

// TODO: Add phase result as needed.
std::stringstream ptxAsm;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we generate a lit test for the associated IR after lowering? I think we need to inspect the code lowering prepared for LLVM.

If I understand correctly the barrier is only half the issue and we also need to ensure the synchronization before it is removed. This is the an example full IR from AutoWS GEMM. Not 1:1 but should contain a similar pattern that we care about (and I have it available) .

This is my relevant code for the section that goes TMEM_LOAD -> barrier arrive:

nvvm.tcgen05.wait <load> loc(#loc53)
nvvm.barrier0 loc(#loc53)
%accumulator_1092 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc53)
%accumulator_1093 = llvm.sub %accumulator_1092, %102 : i32 loc(#loc53)
%accumulator_1094 = llvm.and %accumulator_1093, %94 : i32 loc(#loc53)
%accumulator_1095 = llvm.icmp "eq" %accumulator_1094, %109 : i32 loc(#loc53)
%accumulator_1096 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 mbarrier.arrive.shared::cta.b64 _, [$1];", "b,r" %accumulator_1095, %121 : (i1, !llvm.struct<(ptr<3>, i32)>) -> !llvm.void loc(#loc53)

As I understand it, this nvvm.barrier0 is responsible for the sync. So even though you example will eliminate accumulator_1092-accumulator_1095 and simplify accumulator_1096, to achieve the desired impact we will need to remove this barrier (or possibly reduce it to a sync between warps and not between threads).

cc: @htyu in case you have looked at this in more detail.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the membar pass is the culprit. It's designed to protect shared memory access with bar.sync to makes sure all threads see exactly same data. Barriers could be an exception.

Copy link
Contributor Author

@tissue3 tissue3 Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@njriasan @htyu Thanks for reviewing! I tried removing that nvvm.barrier but still got trivial perf difference

matmul-performance-fp16:
        M       N       K       cuBLAS           ws  ws_warp_barrier          clc  clc_warp_barrier
0  2048.0  2048.0  2048.0   838.860821   453.055644       453.055644   623.543467        622.098396
1  4096.0  4096.0  4096.0  1109.237446  1090.923891      1090.923891  1040.195516       1040.447562
2  8192.0  8192.0  8192.0  1114.996704  1106.807677      1104.601605   920.826955        921.753365

Is it because I only enabled per-thread sync, not per buffer?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine with neutral perf. Maybe other cases can benefit from this.

Can you check if some bar.sync ops go away on the PTX?

Copy link
Contributor Author

@tissue3 tissue3 Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine with neutral perf. Maybe other cases can benefit from this.

Can you check if some bar.sync ops go away on the PTX?

The current implementation is the sync before mbarrier.arrive.shared is removed: https://fburl.com/phabricator/lz64ljio, but not other syncs

ptxAsm << "@$0 mbarrier.arrive.shared::";
if (isRemoteBarrier)
ptxAsm << "cluster";
else
ptxAsm << "cta";
ptxAsm << ".b64 _, [$1]";
if (op.getCount() > 1) {
ptxAsm << ", " << op.getCount();
}
ptxAsm << ";";

TritonLLVMOpBuilder b(op.getLoc(), rewriter);
Value id = getThreadId(rewriter, op.getLoc());
Value pred = b.icmp_eq(id, b.i32_val(0));
if (op.getPred())
pred = b.and_(pred, adaptor.getPred());
if (isPerThread) {
// Warp arrive: every thread arrives independently, no leader pattern.
bool hasPred = !!op.getPred();
std::stringstream ptxAsm;
if (hasPred) {
ptxAsm << "@$0 ";
}
ptxAsm << "mbarrier.arrive.shared::cta.b64 _, ["
<< (hasPred ? "$1" : "$0") << "]";
if (op.getCount() > 1) {
ptxAsm << ", " << op.getCount();
}
ptxAsm << ";";

PTXBuilder ptxBuilder;
SmallVector<PTXBuilder::Operand *, 2> operands = {
ptxBuilder.newOperand(pred, "b"),
ptxBuilder.newOperand(adaptor.getAlloc(), "r")};
PTXBuilder ptxBuilder;
SmallVector<PTXBuilder::Operand *, 2> operands;
if (hasPred) {
operands.push_back(ptxBuilder.newOperand(adaptor.getPred(), "b"));
}
operands.push_back(ptxBuilder.newOperand(adaptor.getAlloc(), "r"));

auto arriveOp = *ptxBuilder.create<>(ptxAsm.str());
arriveOp(operands, /*onlyAttachMLIRArgs=*/true);
auto voidTy = void_ty(getContext());
ptxBuilder.launch(rewriter, op.getLoc(), voidTy);
auto arriveOp = *ptxBuilder.create<>(ptxAsm.str());
arriveOp(operands, /*onlyAttachMLIRArgs=*/true);
auto voidTy = void_ty(getContext());
ptxBuilder.launch(rewriter, op.getLoc(), voidTy);
} else {
// Leader pattern: only thread 0 arrives.
std::stringstream ptxAsm;
ptxAsm << "@$0 mbarrier.arrive.shared::";
if (isRemoteBarrier)
ptxAsm << "cluster";
else
ptxAsm << "cta";
ptxAsm << ".b64 _, [$1]";
if (op.getCount() > 1) {
ptxAsm << ", " << op.getCount();
}
ptxAsm << ";";

TritonLLVMOpBuilder b(op.getLoc(), rewriter);
Value id = getThreadId(rewriter, op.getLoc());
Value pred = b.icmp_eq(id, b.i32_val(0));
if (op.getPred())
pred = b.and_(pred, adaptor.getPred());

PTXBuilder ptxBuilder;
SmallVector<PTXBuilder::Operand *, 2> operands = {
ptxBuilder.newOperand(pred, "b"),
ptxBuilder.newOperand(adaptor.getAlloc(), "r")};

auto arriveOp = *ptxBuilder.create<>(ptxAsm.str());
arriveOp(operands, /*onlyAttachMLIRArgs=*/true);
auto voidTy = void_ty(getContext());
ptxBuilder.launch(rewriter, op.getLoc(), voidTy);
}

rewriter.eraseOp(op);
return success();
Expand Down
6 changes: 6 additions & 0 deletions third_party/tlx/dialect/triton_tlx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,12 @@ void init_triton_tlx_ir(py::module &&m) {
[](TritonOpBuilder &self, Value mbarrerLoc, int arriveCount) -> void {
self.create<ttng::ArriveBarrierOp>(mbarrerLoc, arriveCount);
})
.def("create_warp_barrier_arrive",
[](TritonOpBuilder &self, Value mbarrierLoc,
int arriveCount) -> void {
self.create<ttng::ArriveBarrierOp>(mbarrierLoc, arriveCount,
/*perThread=*/true);
})
.def("create_named_barrier_wait",
[](TritonOpBuilder &self, Value barrier, Value numThreads) -> void {
self.create<ttng::NamedBarrierWaitOp>(barrier, numThreads);
Expand Down
Loading
Loading