From a52ab67ec1d46cd96bd55ad953565397b851b113 Mon Sep 17 00:00:00 2001 From: tissue030 Date: Thu, 5 Mar 2026 01:39:07 -0800 Subject: [PATCH] [TLX] Add `alloc_warp_barrier` for multi-thread barrier arrival (#1031) Summary: X-link: https://github.com/meta-pytorch/tritonbench/pull/922 Add `alloc_warp_barrier` to TLX, enabling barrier arrival where every thread arrives independently instead of using the default leader-thread pattern (`__syncwarp` + thread 0 arrive). This eliminates unnecessary warp synchronization and improves performance when there is warp divergence. Infrastructure changes: - `ArriveBarrierOp` in `TritonNvidiaGPUOps.td`: add `perThread` UnitAttr and new builder overload - `BarrierOpToLLVM.cpp`: when `perThread` is set, emit `mbarrier.arrive` without the leader-thread predicate (no `threadIdx == 0` check), so all threads arrive independently - `barrier.py`: add `alloc_warp_barrier(num_barriers, num_warps, num_arrivals)` which sets `arrive_count = num_warps * 32 * num_arrivals` and marks the barrier as a warp barrier; `barrier_arrive` dispatches to `create_warp_barrier_arrive` for warp barriers - `types.py` / `mem_ops.py`: propagate `is_warp_barrier` flag through `mbarrier` / `mbarrier_type` and `local_view` - `triton_tlx.cc`: expose `create_warp_barrier_arrive` to the dialect builder Test Plan: ``` pytest python/test/unit/language/test_tlx.py::test_alloc_warp_barrier -xvs pytest third_party/tlx/tutorials/testing/test_correctness.py::test_blackwell_gemm_ws_warp_barrier pytest third_party/tlx/tutorials/testing/test_correctness.py::test_blackwell_gemm_clc_warp_barrier pytest third_party/tlx/tutorials/testing/test_correctness.py::test_hopper_gemm_ws_warp_barrier ``` Performance: - Blackwell ``` rm -rf ~/.triton/cache 2>/dev/null; CUDA_VISIBLE_DEVICES=2 bash ~/fbsource/fbcode/ads_mkl/benchmarks/denoise.sh python third_party/tlx/tutorials/testing/test_blackwell_gemm_perf.py --version ws ws_warp_barrier clc clc_warp_barrier Running benchmarks for: ['ws', 
'ws_warp_barrier', 'clc', 'clc_warp_barrier'] (dtype=fp16) matmul-performance-fp16: M N K cuBLAS ws ws_warp_barrier clc clc_warp_barrier 0 2048.0 2048.0 2048.0 836.247528 453.055644 453.055644 597.851777 598.518283 1 4096.0 4096.0 4096.0 1126.992231 1108.092714 1108.378675 1040.447562 1039.943710 2 8192.0 8192.0 8192.0 1133.648070 1110.420373 1109.022579 959.528030 962.134251 ``` - Hopper ``` matmul-performance-fp16: M N K cuBLAS ws ws_warp_barrier 0 2048.0 2048.0 2048.0 581.029130 488.064468 489.399203 1 4096.0 4096.0 4096.0 619.317567 561.213562 564.310512 2 8192.0 8192.0 8192.0 587.325877 555.362764 549.685437 ``` Autotune Test - Blackwell tlx_matmul_ws ``` TRITON_PRINT_AUTOTUNING=1 CUDA_VISIBLE_DEVICES=6 ~/fbsource/fbcode/triton/scripts/denoise.sh buck2 run mode/opt -m ovr_config//triton:beta -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 //pytorch/tritonbench:run_lite -- --op gemm --only tlx_matmul_ws,tlx_matmul_ws_warp_barrier --metrics accuracy,tflops --force --layout tt --rep 3000 --sleep 1.0 --input-loader 'pytorch/tritonbench/tritonbench/data/input_configs/fb/ads_omnifm_v5/gemm.json' --metrics accuracy,tflops,speedup ``` output: https://fburl.com/everpaste/yaro1kke tlx_matmul_clc ``` TRITON_PRINT_AUTOTUNING=1 CUDA_VISIBLE_DEVICES=6 ~/fbsource/fbcode/triton/scripts/denoise.sh buck2 run mode/opt -m ovr_config//triton:beta -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 //pytorch/tritonbench:run_lite -- --op gemm --only tlx_matmul_clc,tlx_matmul_clc_warp_barrier --metrics accuracy,tflops --force --layout tt --rep 3000 --sleep 1.0 --input-loader 'pytorch/tritonbench/tritonbench/data/input_configs/fb/ads_omnifm_v5/gemm.json' --metrics accuracy,tflops,speedup ``` output: https://fburl.com/everpaste/drvps4oz - Hopper Reviewed By: htyu Differential Revision: D95130632 Pulled By: tissue3 --- .../TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td | 8 +- 
lib/Analysis/Membar.cpp | 58 ++++++---- python/test/unit/language/test_tlx.py | 103 +++++++++++++++++- test/Analysis/test-membar-ttng.mlir | 49 +++++++++ test/Conversion/tritonnvidiagpu_to_llvm.mlir | 9 ++ .../TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp | 81 +++++++++----- third_party/tlx/dialect/triton_tlx.cc | 6 + third_party/tlx/language/tlx/__init__.py | 2 + third_party/tlx/language/tlx/barrier.py | 47 +++++++- third_party/tlx/language/tlx/mem_ops.py | 3 +- third_party/tlx/language/tlx/types.py | 10 +- .../tlx/tutorials/blackwell_gemm_clc.py | 17 ++- .../tlx/tutorials/blackwell_gemm_ws.py | 19 +++- third_party/tlx/tutorials/hopper_gemm_ws.py | 17 ++- .../testing/test_blackwell_gemm_perf.py | 15 ++- .../tlx/tutorials/testing/test_correctness.py | 29 ++++- .../testing/test_hopper_gemm_perf.py | 10 +- 17 files changed, 406 insertions(+), 77 deletions(-) diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index 7ebe335353..ee6a8e7094 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -280,7 +280,8 @@ def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier"> { let arguments = (ins Arg, MemWrite]>:$alloc, I32Attr:$count, - Optional:$pred + Optional:$pred, + UnitAttr:$perThread ); let assemblyFormat = [{ @@ -289,7 +290,10 @@ def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier"> { let builders = [ OpBuilder<(ins "Value":$alloc, "uint32_t":$count), [{ - return build($_builder, $_state, alloc, count, /*pred=*/Value()); + return build($_builder, $_state, alloc, count, /*pred=*/Value(), /*perThread=*/false); + }]>, + OpBuilder<(ins "Value":$alloc, "uint32_t":$count, "bool":$perThread), [{ + return build($_builder, $_state, alloc, count, /*pred=*/Value(), perThread); }]> ]; diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index 4c1cf09f6c..fc0987b277 100644 --- 
a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -228,32 +228,46 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, curBlockInfo = funcBlockInfoMap->lookup(callee); } else { // Intra-function dependencies - if (auto memoryEffectOpInterface = dyn_cast(op)) { - // Explicit buffer - SmallVector> - effectInstances; - memoryEffectOpInterface.getEffects(effectInstances); - for (auto effectInstance : effectInstances) { - if (auto value = effectInstance.getValue()) { - for (auto bufferId : allocation->getBufferIds(value)) { - if (bufferId != Allocation::InvalidBufferId) { - auto interval = allocation->getAllocatedInterval(bufferId); - interval = narrowIntervalForSubview(value, interval); - if (isa(effectInstance.getEffect())) - curBlockInfo.syncWriteIntervals[interval].insert(op); - else if (isa(effectInstance.getEffect())) - curBlockInfo.syncReadIntervals[interval].insert(op); + // + // For perThread ArriveBarrierOp, skip all SMEM hazard tracking. + // mbarrier.arrive has release semantics and mbarrier.wait has acquire + // semantics, so no CTA-wide bar.sync is needed before a perThread arrive. + // Each thread's program order guarantees its own SMEM ops are visible + // before its arrive, and the mbarrier accumulates all arrivals before + // releasing the waiter. 
+ bool isPerThreadArrive = false; + if (auto arriveOp = dyn_cast(op)) + isPerThreadArrive = arriveOp.getPerThread(); + + if (!isPerThreadArrive) { + if (auto memoryEffectOpInterface = + dyn_cast(op)) { + // Explicit buffer + SmallVector> + effectInstances; + memoryEffectOpInterface.getEffects(effectInstances); + for (auto effectInstance : effectInstances) { + if (auto value = effectInstance.getValue()) { + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId != Allocation::InvalidBufferId) { + auto interval = allocation->getAllocatedInterval(bufferId); + interval = narrowIntervalForSubview(value, interval); + if (isa(effectInstance.getEffect())) + curBlockInfo.syncWriteIntervals[interval].insert(op); + else if (isa(effectInstance.getEffect())) + curBlockInfo.syncReadIntervals[interval].insert(op); + } } } } } - } - // If this op is may be signalling other threads asynchronously, make sure - // all shared memory transactions are complete beforehand. - if (isa(op)) { - Interval allIntervals(0, std::numeric_limits::max()); - curBlockInfo.syncWriteIntervals[allIntervals].insert(op); - curBlockInfo.syncReadIntervals[allIntervals].insert(op); + // If this op may be signalling other threads asynchronously, make sure + // all shared memory transactions are complete beforehand. 
+ if (isa(op)) { + Interval allIntervals(0, std::numeric_limits::max()); + curBlockInfo.syncWriteIntervals[allIntervals].insert(op); + curBlockInfo.syncReadIntervals[allIntervals].insert(op); + } } scratchBufferId = allocation->getBufferId(op); } diff --git a/python/test/unit/language/test_tlx.py b/python/test/unit/language/test_tlx.py index 9bc5118681..6706468876 100644 --- a/python/test/unit/language/test_tlx.py +++ b/python/test/unit/language/test_tlx.py @@ -1302,7 +1302,7 @@ def test_async_dot_blackwell_2cta_tma(device): with pytest.raises(Exception) as e: run_async_dot_blackwell_2cta_tma(device, False, 128) assert isinstance(e.value, triton.CompilationError), "expecting a compilation error" - assert 'only supports M=128 per CTA for pair-CTA mma' in e.value.error_message + assert "only supports M=128 per CTA for pair-CTA mma" in e.value.error_message def run_async_dot_blackwell_2cta_tma(device, A_TMEM, SAMPLE_M): @@ -2703,6 +2703,103 @@ def test_wait_arrive_ws(BLOCK_SIZE, device): and (ttgir.count("default {") == 1) and (ttgir.count("partition0") == 1)), f"TTGIR {ttgir}" +@triton.jit +def tlx_square_warp_barrier( + x_ptr, + z_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + NUM_WARPS: tl.constexpr, +): + """ + Warp-specialized kernel demonstrating perThread barrier arrives with SMEM. + Producer loads global → stores SMEM → arrives (perThread, no bar.sync). + Consumer waits → loads SMEM → computes z=x*x → stores global → arrives. + + This mirrors the GEMM epilogue pattern where local_load from shared memory + is followed by barrier_arrive to signal the buffer is consumed. 
+ """ + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + + # Warp barriers: each thread arrives independently (no leader sync) + bars = tlx.alloc_warp_barrier(num_barriers=2, num_warps=NUM_WARPS) + b0 = tlx.local_view(bars, 0) + b1 = tlx.local_view(bars, 1) + + # Shared memory buffer for producer-consumer data transfer + buf = tlx.local_alloc((BLOCK_SIZE, ), tl.float32, 1) + smem = tlx.local_view(buf, 0) + + phase = 0 + with tlx.async_tasks(): + with tlx.async_task("default"): + tlx.barrier_wait(bar=b1, phase=phase ^ 1) + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Producer: load from global, store to SMEM + x = tl.load(x_ptr + offsets, mask=mask) + tlx.local_store(smem, x) + # KEY PATTERN: SMEM write → perThread arrive (no bar.sync) + tlx.barrier_arrive(bar=b0) + + with tlx.async_task(num_warps=4): + tlx.barrier_wait(bar=b0, phase=phase) + + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + # Consumer: load from SMEM, compute, store to global + data = tlx.local_load(smem) + z = data * data + tl.store(z_ptr + offsets, z, mask=mask) + # KEY PATTERN: SMEM read → perThread arrive (no bar.sync) + tlx.barrier_arrive(bar=b0) + + +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer") +@pytest.mark.parametrize("BLOCK_SIZE", [(1024)]) +@pytest.mark.parametrize("num_warps", [4]) +def test_alloc_warp_barrier(BLOCK_SIZE, num_warps, device): + torch.manual_seed(0) + size = 98432 + x = torch.rand(size, device=device) + z = torch.empty_like(x) + n_elements = x.numel() + + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) + kernel = tlx_square_warp_barrier[grid]( + x, + z, + n_elements, + BLOCK_SIZE, + num_warps, + num_warps=num_warps, + ) + + z_ref = x * x + torch.testing.assert_close(z, z_ref, check_dtype=False) + + # Verify TTGIR: warp-specialized with perThread arrives + ttgir = kernel.asm["ttgir"] + assert "perThread" in ttgir, f"Expected perThread 
attrs in TTGIR:\n{ttgir}" + assert "ttng.arrive_barrier" in ttgir, f"Expected arrive_barrier in TTGIR:\n{ttgir}" + + # Verify LLIR: perThread arrives use per-thread lowering (no leader predicate) + llir = kernel.asm["llir"] + # Per-thread arrive emits unpredicated: mbarrier.arrive.shared::cta.b64 _, [$0] + assert "mbarrier.arrive.shared::cta.b64 _, [$0]" in llir, ( + f"Expected unpredicated per-thread mbarrier.arrive in LLIR:\n{llir}") + # Leader pattern would emit predicated: @$0 mbarrier.arrive + assert "@$0 mbarrier.arrive" not in llir, f"Unexpected leader-predicated mbarrier.arrive in LLIR:\n{llir}" + # No bar.sync immediately before mbarrier.arrive (membar pass should skip + # perThread arrives for both full-range and per-buffer SMEM hazards). + # Other bar.sync may exist (e.g. before wait_barrier) — that's fine. + + assert not re.search(r"barrier\.cta\.sync.*\n.*mbarrier\.arrive", + llir), (f"Unexpected bar.sync before mbarrier.arrive in LLIR:\n{llir}") + + @pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer") def test_barrier_live_range(device): @@ -6382,13 +6479,13 @@ def bulk_copy_kernel( ttgir = kernel.asm["ttgir"] assert "ttg.async_copy_global_to_local" in ttgir, "Expected async_copy_global_to_local in TTGIR" assert "useBulk = true" in ttgir, "Expected useBulk = true in TTGIR" - assert "ttng.async_store" in ttgir, ("Expected async_store in TTGIR") + assert "ttng.async_store" in ttgir, "Expected async_store in TTGIR" # Verify PTX contains the bulk copy instructions ptx = kernel.asm["ptx"] assert "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes" in ptx, ( "Expected cp.async.bulk gmem->smem in PTX") - assert "cp.async.bulk.global.shared::cta.bulk_group" in ptx, ("Expected cp.async.bulk smem->gmem in PTX") + assert "cp.async.bulk.global.shared::cta.bulk_group" in ptx, "Expected cp.async.bulk smem->gmem in PTX" # Verify correctness torch.testing.assert_close(src, dst) diff --git a/test/Analysis/test-membar-ttng.mlir 
b/test/Analysis/test-membar-ttng.mlir index 8afacae130..d67559d4ab 100644 --- a/test/Analysis/test-membar-ttng.mlir +++ b/test/Analysis/test-membar-ttng.mlir @@ -194,3 +194,52 @@ module attributes {"ttg.num-warps" = 4 : i32} { tt.return } } + +// ----- + +// Verify that a perThread arrive after a shared memory write does NOT get a +// gpu.barrier inserted before it. The perThread attribute opts out of the +// CTA-wide fence because each thread's program order guarantees its own SMEM +// ops complete before its arrive. + +#shared_pt = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#blocked_pt = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#A_SHARED_pt = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 1024 : i32} { +// CHECK-LABEL: @no_barrier_before_perthread_arrive +tt.func @no_barrier_before_perthread_arrive(%arg: tensor<32x16xf16, #blocked_pt>) { + %alloc = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED_pt, #ttg.shared_memory, mutable> + %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_pt, #ttg.shared_memory, mutable> + // CHECK: ttg.local_store + // CHECK-NEXT: ttng.arrive_barrier + // CHECK-NOT: gpu.barrier + // CHECK: tt.return + ttg.local_store %arg, %alloc : tensor<32x16xf16, #blocked_pt> -> !ttg.memdesc<32x16xf16, #A_SHARED_pt, #ttg.shared_memory, mutable> + ttng.arrive_barrier %barrier, 1 {perThread} : !ttg.memdesc<1xi64, #shared_pt, #ttg.shared_memory, mutable> + tt.return +} +} + +// ----- + +// Verify that a regular (non-perThread) arrive after a shared memory write +// DOES get a gpu.barrier inserted before it (existing behavior preserved). 
+ +#shared_reg = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#blocked_reg = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#A_SHARED_reg = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 1024 : i32} { +// CHECK-LABEL: @barrier_before_regular_arrive +tt.func @barrier_before_regular_arrive(%arg: tensor<32x16xf16, #blocked_reg>) { + %alloc = ttg.local_alloc : () -> !ttg.memdesc<32x16xf16, #A_SHARED_reg, #ttg.shared_memory, mutable> + %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared_reg, #ttg.shared_memory, mutable> + // CHECK: ttg.local_store + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: ttng.arrive_barrier + ttg.local_store %arg, %alloc : tensor<32x16xf16, #blocked_reg> -> !ttg.memdesc<32x16xf16, #A_SHARED_reg, #ttg.shared_memory, mutable> + ttng.arrive_barrier %barrier, 1 : !ttg.memdesc<1xi64, #shared_reg, #ttg.shared_memory, mutable> + tt.return +} +} diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index c36a6368e0..87ea1ce7e7 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -68,6 +68,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.return } + // CHECK-LABEL: arrive_barrier_per_thread + tt.func @arrive_barrier_per_thread(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>) { + // CHECK-NOT: nvvm.read.ptx.sreg.tid.x + // CHECK-NOT: llvm.icmp "eq" + // CHECK: "mbarrier.arrive.shared::cta.b64 _, [$0], 2;", "r" %arg0 + ttng.arrive_barrier %alloc, 2 {perThread} : !ttg.memdesc<1xi64, #shared0, #smem> + tt.return + } + // CHECK-LABEL: arrive_barrier_named tt.func @arrive_barrier_named(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) { %c9_i32 = arith.constant 9 : i32 diff --git 
a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp index a7e8d10747..008e43852a 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp @@ -254,40 +254,69 @@ struct ArriveBarrierOpConversion LogicalResult matchAndRewrite(triton::nvidia_gpu::ArriveBarrierOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + bool isPerThread = op.getPerThread(); + bool isRemoteBarrier = false; if (auto barType = dyn_cast(op.getAlloc().getType())) { isRemoteBarrier = isa(barType.getMemorySpace()); } - // TODO: Add phase result as needed. - std::stringstream ptxAsm; - ptxAsm << "@$0 mbarrier.arrive.shared::"; - if (isRemoteBarrier) - ptxAsm << "cluster"; - else - ptxAsm << "cta"; - ptxAsm << ".b64 _, [$1]"; - if (op.getCount() > 1) { - ptxAsm << ", " << op.getCount(); - } - ptxAsm << ";"; - - TritonLLVMOpBuilder b(op.getLoc(), rewriter); - Value id = getThreadId(rewriter, op.getLoc()); - Value pred = b.icmp_eq(id, b.i32_val(0)); - if (op.getPred()) - pred = b.and_(pred, adaptor.getPred()); + if (isPerThread) { + // Warp arrive: every thread arrives independently, no leader pattern. The state space follows the barrier's memory space, same as the leader path below. + bool hasPred = !!op.getPred(); + std::stringstream ptxAsm; + if (hasPred) { + ptxAsm << "@$0 "; + } + ptxAsm << "mbarrier.arrive.shared::" << (isRemoteBarrier ? "cluster" : "cta") << ".b64 _, [" + << (hasPred ? 
"$1" : "$0") << "]"; + if (op.getCount() > 1) { + ptxAsm << ", " << op.getCount(); + } + ptxAsm << ";"; - PTXBuilder ptxBuilder; - SmallVector operands = { - ptxBuilder.newOperand(pred, "b"), - ptxBuilder.newOperand(adaptor.getAlloc(), "r")}; + PTXBuilder ptxBuilder; + SmallVector operands; + if (hasPred) { + operands.push_back(ptxBuilder.newOperand(adaptor.getPred(), "b")); + } + operands.push_back(ptxBuilder.newOperand(adaptor.getAlloc(), "r")); - auto arriveOp = *ptxBuilder.create<>(ptxAsm.str()); - arriveOp(operands, /*onlyAttachMLIRArgs=*/true); - auto voidTy = void_ty(getContext()); - ptxBuilder.launch(rewriter, op.getLoc(), voidTy); + auto arriveOp = *ptxBuilder.create<>(ptxAsm.str()); + arriveOp(operands, /*onlyAttachMLIRArgs=*/true); + auto voidTy = void_ty(getContext()); + ptxBuilder.launch(rewriter, op.getLoc(), voidTy); + } else { + // Leader pattern: only thread 0 arrives. + std::stringstream ptxAsm; + ptxAsm << "@$0 mbarrier.arrive.shared::"; + if (isRemoteBarrier) + ptxAsm << "cluster"; + else + ptxAsm << "cta"; + ptxAsm << ".b64 _, [$1]"; + if (op.getCount() > 1) { + ptxAsm << ", " << op.getCount(); + } + ptxAsm << ";"; + + TritonLLVMOpBuilder b(op.getLoc(), rewriter); + Value id = getThreadId(rewriter, op.getLoc()); + Value pred = b.icmp_eq(id, b.i32_val(0)); + if (op.getPred()) + pred = b.and_(pred, adaptor.getPred()); + + PTXBuilder ptxBuilder; + SmallVector operands = { + ptxBuilder.newOperand(pred, "b"), + ptxBuilder.newOperand(adaptor.getAlloc(), "r")}; + + auto arriveOp = *ptxBuilder.create<>(ptxAsm.str()); + arriveOp(operands, /*onlyAttachMLIRArgs=*/true); + auto voidTy = void_ty(getContext()); + ptxBuilder.launch(rewriter, op.getLoc(), voidTy); + } rewriter.eraseOp(op); return success(); diff --git a/third_party/tlx/dialect/triton_tlx.cc b/third_party/tlx/dialect/triton_tlx.cc index 5ce0404313..63ed2a12c5 100644 --- a/third_party/tlx/dialect/triton_tlx.cc +++ b/third_party/tlx/dialect/triton_tlx.cc @@ -320,6 +320,12 @@ void 
init_triton_tlx_ir(py::module &&m) { [](TritonOpBuilder &self, Value mbarrerLoc, int arriveCount) -> void { self.create(mbarrerLoc, arriveCount); }) + .def("create_warp_barrier_arrive", + [](TritonOpBuilder &self, Value mbarrierLoc, + int arriveCount) -> void { + self.create(mbarrierLoc, arriveCount, + /*perThread=*/true); + }) .def("create_named_barrier_wait", [](TritonOpBuilder &self, Value barrier, Value numThreads) -> void { self.create(barrier, numThreads); diff --git a/third_party/tlx/language/tlx/__init__.py b/third_party/tlx/language/tlx/__init__.py index 6ba2a753c1..1607173f08 100644 --- a/third_party/tlx/language/tlx/__init__.py +++ b/third_party/tlx/language/tlx/__init__.py @@ -2,6 +2,7 @@ from .async_task_utils import async_task, async_tasks from .barrier import ( alloc_barriers, + alloc_warp_barrier, barrier_arrive, barrier_expect_bytes, barrier_wait, @@ -138,6 +139,7 @@ # barriers "cluster_barrier", "alloc_barriers", + "alloc_warp_barrier", "barrier_expect_bytes", "barrier_wait", "barrier_arrive", diff --git a/third_party/tlx/language/tlx/barrier.py b/third_party/tlx/language/tlx/barrier.py index bd673374b6..807195f64b 100644 --- a/third_party/tlx/language/tlx/barrier.py +++ b/third_party/tlx/language/tlx/barrier.py @@ -40,6 +40,47 @@ def alloc_barriers( ) +@tl.builtin +def alloc_warp_barrier( + num_barriers: tl.constexpr, + num_warps: tl.constexpr = tl.constexpr(1), + num_arrivals: tl.constexpr = tl.constexpr(1), + _semantic=None, +) -> tlx.mbarrier: + """ + Allocates warp barriers where all threads arrive independently. + + Unlike alloc_barriers (where a single leader thread signals the arrive after + a warp sync), warp barriers expect every thread to arrive individually. This + removes the need for thread synchronization before the arrive, reducing + unnecessary syncs and improving performance when there is warp divergence. + + Input: + - `num_barriers`: The number of barriers to allocate. 
+ - `num_warps`: The number of warps whose threads will arrive at the barrier. + - `num_arrivals`: The number of times barrier_arrive is called per phase. + The total arrive count is num_warps * 32 * num_arrivals. + """ + + arrive_count = num_warps.value * 32 * num_arrivals.value + layout = tlx.swizzled_shared_layout_encoding.make_default(rank=1) + layout_handle = _semantic.builder.make_swizzled_shared_encoding_attr( + layout.vectorSize, + layout.perPhase, + layout.maxPhase, + layout.order, + layout.numCTAsPerCGA, + layout.numCTASplit, + layout.numCTAOrder, + ) + return tlx.mbarrier( + _semantic.builder.create_alloc_barriers(num_barriers.value, arrive_count, layout_handle), + num_barriers, + layout, + is_warp_barrier=True, + ) + + @tl.builtin def barrier_expect_bytes( bar: tlx.mbarrier, @@ -114,7 +155,11 @@ def barrier_arrive( if remote_cta_rank is not None: bar = remote_view(bar, remote_cta_rank, _semantic=_semantic) - _semantic.builder.create_barrier_arrive(bar.handle, arrive_count.value) + + if getattr(bar, 'is_warp_barrier', False): + _semantic.builder.create_warp_barrier_arrive(bar.handle, arrive_count.value) + else: + _semantic.builder.create_barrier_arrive(bar.handle, arrive_count.value) @tl.builtin diff --git a/third_party/tlx/language/tlx/mem_ops.py b/third_party/tlx/language/tlx/mem_ops.py index e368d89d6e..54a557455b 100644 --- a/third_party/tlx/language/tlx/mem_ops.py +++ b/third_party/tlx/language/tlx/mem_ops.py @@ -265,7 +265,8 @@ def local_view( buffer_idx = _semantic._convert_elem_to_ir_value(buffer_idx, require_i64=False) view_handle = _semantic.builder.create_memdesc_subview(local_allocated_buffers.handle, buffer_idx) if isinstance(local_allocated_buffers, tlx.mbarrier): - return tlx.mbarrier(view_handle, 0, local_allocated_buffers.type.layout) + return tlx.mbarrier(view_handle, 0, local_allocated_buffers.type.layout, + is_warp_barrier=local_allocated_buffers.is_warp_barrier) elif isinstance(local_allocated_buffers, tlx.clc_response): return 
tlx.clc_response(view_handle, 0, local_allocated_buffers.type.layout) else: diff --git a/third_party/tlx/language/tlx/types.py b/third_party/tlx/language/tlx/types.py index 2586d0b03c..34a1caadd9 100644 --- a/third_party/tlx/language/tlx/types.py +++ b/third_party/tlx/language/tlx/types.py @@ -863,12 +863,14 @@ def __init__( num: int, layout: Optional[swizzled_shared_layout_encoding], storage: storage_kind = storage_kind.smem, + is_warp_barrier: bool = False, ): assert storage == storage_kind.smem or storage == storage_kind.smemCluster, ( "mbarrier requires storage to be smem or smemCluster") self.handle = handle - self.type = mbarrier_type(num, layout, storage) + self.type = mbarrier_type(num, layout, storage, is_warp_barrier) self.num = num + self.is_warp_barrier = is_warp_barrier def _flatten_ir(self, handles) -> None: handles.append(self.handle) @@ -883,11 +885,13 @@ def _unflatten_ir(self, handles, cursor): class mbarrier_type(buffered_tensor_type): - def __init__(self, num: int, layout: Optional[swizzled_shared_layout_encoding], storage): + def __init__(self, num: int, layout: Optional[swizzled_shared_layout_encoding], storage, + is_warp_barrier: bool = False): super().__init__(tl.int64, [1], num, storage, layout) + self.is_warp_barrier = is_warp_barrier def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[mbarrier, int]: - value = mbarrier(handles[cursor], self.num, self.layout, self.storage) + value = mbarrier(handles[cursor], self.num, self.layout, self.storage, is_warp_barrier=self.is_warp_barrier) return value, cursor + 1 def to_ir(self, builder: ir.builder) -> None: diff --git a/third_party/tlx/tutorials/blackwell_gemm_clc.py b/third_party/tlx/tutorials/blackwell_gemm_clc.py index b78c68795f..7ea0ee071d 100644 --- a/third_party/tlx/tutorials/blackwell_gemm_clc.py +++ b/third_party/tlx/tutorials/blackwell_gemm_clc.py @@ -70,6 +70,7 @@ def matmul_kernel_tma_ws_blackwell_clc(a_desc, b_desc, c_desc, M, N, K, BLOCK_SI NUM_SMS: tl.constexpr, 
# NUM_CLC_STAGES: tl.constexpr, # EPILOGUE_SUBTILE: tl.constexpr, # + USE_WARP_BARRIER: tl.constexpr = False, # ): # allocate NUM_SMEM_BUFFERS buffers buffers_A = tlx.local_alloc((BLOCK_SIZE_M, BLOCK_SIZE_K), tlx.dtype_of(a_desc), NUM_SMEM_BUFFERS) @@ -80,8 +81,12 @@ def matmul_kernel_tma_ws_blackwell_clc(a_desc, b_desc, c_desc, M, N, K, BLOCK_SI # allocate barriers smem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1) smem_full_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1) - tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1) - tmem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1) + if USE_WARP_BARRIER: + tmem_full_bars = tlx.alloc_warp_barrier(num_barriers=NUM_TMEM_BUFFERS, num_warps=1) + tmem_empty_bars = tlx.alloc_warp_barrier(num_barriers=NUM_TMEM_BUFFERS, num_warps=4) + else: + tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1) + tmem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1) clc_context = tlx.clc_create_context(num_consumers=3) @@ -241,7 +246,7 @@ def matmul_kernel_tma_ws_blackwell_clc(a_desc, b_desc, c_desc, M, N, K, BLOCK_SI clc_phase_consumer ^= 1 -def matmul(a, b, config=None): +def matmul(a, b, config=None, use_warp_barrier=False): """Matrix multiplication using TLX GEMM kernel.""" # Check constraints. 
assert a.shape[1] == b.shape[0], "Incompatible dimensions" @@ -277,6 +282,7 @@ def matmul(a, b, config=None): K, NUM_SMS=NUM_SMS, NUM_CLC_STAGES=1, + USE_WARP_BARRIER=use_warp_barrier, **config, ) else: @@ -290,6 +296,11 @@ def matmul(a, b, config=None): K, NUM_SMS=NUM_SMS, NUM_CLC_STAGES=1, + USE_WARP_BARRIER=use_warp_barrier, ) return c + + +def matmul_warp_barrier(a, b, config=None): + return matmul(a, b, config=config, use_warp_barrier=True) diff --git a/third_party/tlx/tutorials/blackwell_gemm_ws.py b/third_party/tlx/tutorials/blackwell_gemm_ws.py index 11324d2d03..2802353ab8 100644 --- a/third_party/tlx/tutorials/blackwell_gemm_ws.py +++ b/third_party/tlx/tutorials/blackwell_gemm_ws.py @@ -980,6 +980,7 @@ def matmul_kernel_tma_ws_blackwell( SPLIT_K: tl.constexpr, INTERLEAVE_EPILOGUE: tl.constexpr, NUM_SMS: tl.constexpr, + USE_WARP_BARRIER: tl.constexpr = False, ): # allocate NUM_SMEM_BUFFERS buffers BLOCK_M_SPLIT: tl.constexpr = BLOCK_SIZE_M // NUM_MMA_GROUPS @@ -1026,8 +1027,14 @@ def matmul_kernel_tma_ws_blackwell( A_smem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS * NUM_MMA_GROUPS, arrive_count=1) B_smem_full_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1) # NUM_TMEM_BUFFERS (overlaps MMA and epilogue) - tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS * NUM_MMA_GROUPS, arrive_count=1) - tmem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS * NUM_MMA_GROUPS, arrive_count=EPILOGUE_SUBTILE) + if USE_WARP_BARRIER: + tmem_full_bars = tlx.alloc_warp_barrier(num_barriers=NUM_TMEM_BUFFERS * NUM_MMA_GROUPS, num_warps=1) + tmem_empty_bars = tlx.alloc_warp_barrier(num_barriers=NUM_TMEM_BUFFERS * NUM_MMA_GROUPS, num_warps=4, + num_arrivals=EPILOGUE_SUBTILE) + else: + tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS * NUM_MMA_GROUPS, arrive_count=1) + tmem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS * NUM_MMA_GROUPS, + arrive_count=EPILOGUE_SUBTILE) with 
tlx.async_tasks(): with tlx.async_task("default"): # epilogue consumer @@ -1197,7 +1204,7 @@ def matmul_kernel_tma_ws_blackwell( tile_id += NUM_SMS -def matmul(a, b, config=None, use_heuristic=False): +def matmul(a, b, config=None, use_heuristic=False, use_warp_barrier=False): """Matrix multiplication using TLX GEMM kernel. Args: @@ -1261,6 +1268,7 @@ def matmul(a, b, config=None, use_heuristic=False): N, K, NUM_SMS=NUM_SMS, + USE_WARP_BARRIER=use_warp_barrier, ctas_per_cga=ctas_per_cga, **config, ) @@ -1284,5 +1292,10 @@ def grid(META): N, K, NUM_SMS=NUM_SMS, + USE_WARP_BARRIER=use_warp_barrier, ) return c + + +def matmul_warp_barrier(a, b, config=None, use_heuristic=True): + return matmul(a, b, config=config, use_heuristic=use_heuristic, use_warp_barrier=True) diff --git a/third_party/tlx/tutorials/hopper_gemm_ws.py b/third_party/tlx/tutorials/hopper_gemm_ws.py index 1d4f70cd81..cdb69f65fa 100644 --- a/third_party/tlx/tutorials/hopper_gemm_ws.py +++ b/third_party/tlx/tutorials/hopper_gemm_ws.py @@ -69,6 +69,7 @@ def matmul_kernel_tlx_ws(a_desc, b_desc, c_desc, # NUM_MMA_WARPS: tl.constexpr, # NUM_MMA_GROUPS: tl.constexpr, # EPILOGUE_SUBTILE: tl.constexpr, # + USE_WARP_BARRIER: tl.constexpr = False, # ): # Descriptor BLOCK_M_SPLIT: tl.constexpr = BM // NUM_MMA_GROUPS @@ -82,9 +83,13 @@ def matmul_kernel_tlx_ws(a_desc, b_desc, c_desc, # # Need NUM_STAGES sets of mbarriers for A and B # where each set contains two for A and one for B. # Do the above for both empty states and full states respectively. 
- bars_empty_a = tlx.alloc_barriers(num_barriers=NUM_STAGES * NUM_MMA_GROUPS, arrive_count=1) + if USE_WARP_BARRIER: + bars_empty_a = tlx.alloc_warp_barrier(num_barriers=NUM_STAGES * NUM_MMA_GROUPS, num_warps=4) + bars_empty_b = tlx.alloc_warp_barrier(num_barriers=NUM_STAGES, num_warps=4, num_arrivals=NUM_MMA_GROUPS) + else: + bars_empty_a = tlx.alloc_barriers(num_barriers=NUM_STAGES * NUM_MMA_GROUPS, arrive_count=1) + bars_empty_b = tlx.alloc_barriers(num_barriers=NUM_STAGES, arrive_count=NUM_MMA_GROUPS) bars_full_a = tlx.alloc_barriers(num_barriers=NUM_STAGES * NUM_MMA_GROUPS, arrive_count=1) - bars_empty_b = tlx.alloc_barriers(num_barriers=NUM_STAGES, arrive_count=NUM_MMA_GROUPS) bars_full_b = tlx.alloc_barriers(num_barriers=NUM_STAGES, arrive_count=1) # Warp specilization @@ -198,7 +203,7 @@ def matmul_kernel_tlx_ws(a_desc, b_desc, c_desc, # c_desc.store([offset_cm, offset_bn], acc.to(tlx.dtype_of(c_desc))) # noqa -def matmul(a, b, config=None): +def matmul(a, b, config=None, use_warp_barrier=False): """Matrix multiplication using TLX GEMM kernel.""" # Check constraints. 
assert a.shape[1] == b.shape[0], "Illegal dimensions of input operands" @@ -253,6 +258,7 @@ def matmul(a, b, config=None): M, N, K, + USE_WARP_BARRIER=use_warp_barrier, **config, ) else: @@ -265,5 +271,10 @@ def matmul(a, b, config=None): M, N, K, + USE_WARP_BARRIER=use_warp_barrier, ) return c + + +def matmul_warp_barrier(a, b, config=None): + return matmul(a, b, config=config, use_warp_barrier=True) diff --git a/third_party/tlx/tutorials/testing/test_blackwell_gemm_perf.py b/third_party/tlx/tutorials/testing/test_blackwell_gemm_perf.py index 214f992ce0..76e5fc400e 100644 --- a/third_party/tlx/tutorials/testing/test_blackwell_gemm_perf.py +++ b/third_party/tlx/tutorials/testing/test_blackwell_gemm_perf.py @@ -5,9 +5,13 @@ import triton from triton.language.extra.tlx.tutorials.blackwell_gemm_ws import ( - matmul as _tlx_matmul_ws, ) + matmul as _tlx_matmul_ws, + matmul_warp_barrier as _tlx_matmul_ws_warp_barrier, +) from triton.language.extra.tlx.tutorials.blackwell_gemm_clc import ( - matmul as _tlx_matmul_clc, ) + matmul as _tlx_matmul_clc, + matmul_warp_barrier as _tlx_matmul_clc_warp_barrier, +) from triton.language.extra.tlx.tutorials.blackwell_gemm_pipelined import ( matmul as _tlx_matmul_pipelined, ) from triton.language.extra.tlx.tutorials.blackwell_gemm_2cta import ( @@ -20,7 +24,9 @@ # Registry of available matmul implementations MATMUL_METHODS = { "ws": _tlx_matmul_ws, + "ws_warp_barrier": _tlx_matmul_ws_warp_barrier, "clc": _tlx_matmul_clc, + "clc_warp_barrier": _tlx_matmul_clc_warp_barrier, "pipelined": _tlx_matmul_pipelined, "2cta": _tlx_matmul_2cta, } @@ -73,8 +79,9 @@ def benchmark(M, N, K, provider): parser.add_argument( "--version", type=str, + nargs="+", choices=list(MATMUL_METHODS.keys()), - help=f"Run only the specified version. Choices: {list(MATMUL_METHODS.keys())}", + help=f"Run only the specified version(s). 
Choices: {list(MATMUL_METHODS.keys())}", ) parser.add_argument( "--dtype", @@ -88,7 +95,7 @@ def benchmark(M, N, K, provider): dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}[args.dtype] if is_blackwell(): - versions = [args.version] if args.version else list(MATMUL_METHODS.keys()) + versions = args.version if args.version else list(MATMUL_METHODS.keys()) print(f"Running benchmarks for: {versions} (dtype={args.dtype})") benchmark = create_benchmark(versions, dtype=dtype) benchmark.run(print_data=True) diff --git a/third_party/tlx/tutorials/testing/test_correctness.py b/third_party/tlx/tutorials/testing/test_correctness.py index aa7ddd28be..cf738f298b 100644 --- a/third_party/tlx/tutorials/testing/test_correctness.py +++ b/third_party/tlx/tutorials/testing/test_correctness.py @@ -5,9 +5,13 @@ import triton from triton.language.extra.tlx.tutorials.blackwell_gemm_ws import ( - matmul as _blackwell_gemm_ws, ) + matmul as _blackwell_gemm_ws, + matmul_warp_barrier as _blackwell_gemm_ws_warp_barrier, +) from triton.language.extra.tlx.tutorials.blackwell_gemm_clc import ( - matmul as _blackwell_gemm_clc, ) + matmul as _blackwell_gemm_clc, + matmul_warp_barrier as _blackwell_gemm_clc_warp_barrier, +) from triton.language.extra.tlx.tutorials.blackwell_gemm_pipelined import ( matmul as _blackwell_gemm_pipelined, ) from triton.language.extra.tlx.tutorials.blackwell_gemm_2cta import ( @@ -27,7 +31,9 @@ from triton.language.extra.tlx.tutorials.hopper_gemm_pipelined import ( matmul as _hopper_gemm_pipelined, ) from triton.language.extra.tlx.tutorials.hopper_gemm_ws import ( - matmul as _hopper_gemm_ws, ) + matmul as _hopper_gemm_ws, + matmul_warp_barrier as _hopper_gemm_ws_warp_barrier, +) from triton.language.extra.tlx.tutorials.hopper_fa_ws_pipelined_pingpong_persistent import ( attention as _hopper_fa_ws_pipelined_pingpong_persistent, ) from triton.language.extra.tlx.tutorials.hopper_fa_ws_pipelined_pingpong import ( @@ -231,6 +237,18 @@ def 
test_blackwell_gemm_clc(dtype): Gemm.run_test(_blackwell_gemm_clc, Gemm.CONFIGS["blackwell_gemm_clc"], dtype=dtype) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) +@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell GPU") +def test_blackwell_gemm_ws_warp_barrier(dtype): + Gemm.run_test(_blackwell_gemm_ws_warp_barrier, Gemm.CONFIGS["blackwell_gemm_ws"], dtype=dtype) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) +@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell GPU") +def test_blackwell_gemm_clc_warp_barrier(dtype): + Gemm.run_test(_blackwell_gemm_clc_warp_barrier, Gemm.CONFIGS["blackwell_gemm_clc"], dtype=dtype) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell GPU") def test_blackwell_gemm_pipelined(dtype): @@ -341,6 +359,11 @@ def test_hopper_gemm_ws(): Gemm.run_test(_hopper_gemm_ws, Gemm.CONFIGS["hopper_gemm_ws"]) +@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper GPU") +def test_hopper_gemm_ws_warp_barrier(): + Gemm.run_test(_hopper_gemm_ws_warp_barrier, Gemm.CONFIGS["hopper_gemm_ws"]) + + # ============================================================================= # Hopper Flash Attention Tests # ============================================================================= diff --git a/third_party/tlx/tutorials/testing/test_hopper_gemm_perf.py b/third_party/tlx/tutorials/testing/test_hopper_gemm_perf.py index d40b8f5832..960161c561 100644 --- a/third_party/tlx/tutorials/testing/test_hopper_gemm_perf.py +++ b/third_party/tlx/tutorials/testing/test_hopper_gemm_perf.py @@ -7,7 +7,9 @@ from triton.language.extra.tlx.tutorials.hopper_gemm_pipelined import ( matmul as _matmul_pipelined, ) from triton.language.extra.tlx.tutorials.hopper_gemm_ws import ( - matmul as _matmul_ws, ) + matmul as _matmul_ws, + matmul_warp_barrier 
as _matmul_ws_warp_barrier, +) from triton._internal_testing import is_hopper @@ -16,6 +18,7 @@ MATMUL_METHODS = { "pipelined": _matmul_pipelined, "ws": _matmul_ws, + "ws_warp_barrier": _matmul_ws_warp_barrier, } ref_lib = "cuBLAS" @@ -65,13 +68,14 @@ def benchmark(M, N, K, provider): parser.add_argument( "--version", type=str, + nargs="+", choices=list(MATMUL_METHODS.keys()), - help=f"Run only the specified version. Choices: {list(MATMUL_METHODS.keys())}", + help=f"Run only the specified version(s). Choices: {list(MATMUL_METHODS.keys())}", ) args = parser.parse_args() if is_hopper(): - versions = [args.version] if args.version else list(MATMUL_METHODS.keys()) + versions = args.version if args.version else list(MATMUL_METHODS.keys()) print(f"Running benchmarks for: {versions}") benchmark = create_benchmark(versions) benchmark.run(print_data=True)