
Commit da23473

pchen7e2 authored and meta-codesync[bot] committed
[6/N][TLX-2cta] Codegen for remote barrier arrive (#647)
Summary: When the barrier object is in cluster SMEM space, we should lower to the arrive instruction with the `shared::cluster` state space. The input barrier is typically the result of a `mapa` op, which landed earlier in the stack.

Test plan:

```
% make test-lit
ninja -C /data/users/pchen7e4/triton/build/cmake.linux-x86_64-cpython-3.11 check-triton-lit-tests
ninja: Entering directory `/data/users/pchen7e4/triton/build/cmake.linux-x86_64-cpython-3.11'
[0/1] Running the triton regression tests

Testing Time: 9.11s

Total Discovered Tests: 208
  Passed           : 207 (99.52%)
  Expectedly Failed:   1 (0.48%)

% third_party/tlx/run_all.sh
Hello! (Facebook-only)
Need to build triton in this script? {y|n} n
Run all LITs? {y|n} n
Run core Triton python unit tests? {y|n} n
Run all TLX unit tests? {y|n} y
Running TLX Unit Tests
============================= test session starts =============================
platform linux -- Python 3.11.13, pytest-8.3.4, pluggy-1.5.0
rootdir: /data/users/pchen7e4/triton
configfile: pyproject.toml
plugins: xdist-3.7.0, forked-1.6.0, typeguard-4.3.0
collected 109 items

python/test/unit/language/test_tlx.py ...sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.............s................. [100%]

======================= 33 passed, 76 skipped in 50.92s =======================

Run TLX tutorial kernels (correctness|performance|no)? {c|p|n} c
Verifying correctness of TLX tutorial kernels
============================= test session starts =============================
platform linux -- Python 3.11.13, pytest-8.3.4, pluggy-1.5.0
rootdir: /data/users/pchen7e4/triton
configfile: pyproject.toml
plugins: xdist-3.7.0, forked-1.6.0, typeguard-4.3.0
collected 17 items

third_party/tlx/tutorials/amd-gemm-pipelined.py s                        [  5%]
third_party/tlx/tutorials/blackwell-fa-ws-persistent_test.py .           [ 11%]
third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py . [ 17%]
third_party/tlx/tutorials/blackwell-fa-ws-pipelined_test.py .            [ 23%]
third_party/tlx/tutorials/blackwell-fa-ws_test.py .                      [ 29%]
third_party/tlx/tutorials/blackwell-gemm-clc.py .                        [ 35%]
third_party/tlx/tutorials/blackwell-gemm-pipelined.py .                  [ 41%]
third_party/tlx/tutorials/blackwell-gemm-ws.py .                         [ 47%]
third_party/tlx/tutorials/blackwell-grouped-gemm.py .                    [ 52%]
third_party/tlx/tutorials/hopper-fa-ws-pipelined-pingpong_test.py s      [ 58%]
third_party/tlx/tutorials/hopper-fa-ws-pipelined_test.py s               [ 64%]
third_party/tlx/tutorials/hopper-fa-ws_test.py s                         [ 70%]
third_party/tlx/tutorials/hopper-gemm-pipelined_test.py s                [ 76%]
third_party/tlx/tutorials/hopper-gemm-ws_test.py s                       [ 82%]
third_party/tlx/tutorials/hopper-persistent-gemm-ws-cooperative.py s     [ 88%]
third_party/tlx/tutorials/hopper-persistent-gemm-ws-pingpong.py s        [ 94%]
third_party/tlx/tutorials/vector-add2.py .                               [100%]
...
=========== 9 passed, 8 skipped, 4 warnings in 126.92s (0:02:06) ===========
```

Pull Request resolved: #647
Reviewed By: htyu
Differential Revision: D86467280
Pulled By: pchen7e2
fbshipit-source-id: 8795aeb097d8fc284bd8fba84de994d88167adf7
1 parent 6f77f69 commit da23473

File tree

2 files changed: +21 −1 lines changed


test/Conversion/tritonnvidiagpu_to_llvm.mlir

Lines changed: 7 additions & 0 deletions
```diff
@@ -79,6 +79,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     tt.return
   }

+  // CHECK-LABEL: arrive_barrier_remote
+  tt.func @arrive_barrier_remote(%alloc: !ttg.memdesc<1xi64, #shared0, #ttng.shared_cluster_memory>, %pred: i1) {
+    // CHECK: "@$0 mbarrier.arrive.shared::cluster.b64 _, [$1], 2;", "b,r" %{{.*}}
+    ttng.arrive_barrier %alloc, 2, %pred : !ttg.memdesc<1xi64, #shared0, #ttng.shared_cluster_memory>
+    tt.return
+  }
+
   // CHECK-LABEL: wait_barrier_named
   tt.func @wait_barrier_named(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) {
     %c9_i32 = arith.constant 9 : i32
```

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp

Lines changed: 14 additions & 1 deletion
```diff
@@ -33,6 +33,8 @@

 using namespace mlir;
 using namespace mlir::triton;
+namespace ttg = mlir::triton::gpu;
+namespace ttng = mlir::triton::nvidia_gpu;

 namespace {
 struct FenceAsyncSharedOpConversion
@@ -234,9 +236,20 @@ struct ArriveBarrierOpConversion
   LogicalResult
   matchAndRewrite(triton::nvidia_gpu::ArriveBarrierOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    bool isRemoteBarrier = false;
+    if (auto barType = dyn_cast<ttg::MemDescType>(op.getAlloc().getType())) {
+      isRemoteBarrier =
+          isa<ttng::SharedClusterMemorySpaceAttr>(barType.getMemorySpace());
+    }
+
     // TODO: Add phase result as needed.
     std::stringstream ptxAsm;
-    ptxAsm << "@$0 mbarrier.arrive.shared::cta.b64 _, [$1]";
+    ptxAsm << "@$0 mbarrier.arrive.shared::";
+    if (isRemoteBarrier)
+      ptxAsm << "cluster";
+    else
+      ptxAsm << "cta";
+    ptxAsm << ".b64 _, [$1]";
     if (op.getCount() > 1) {
       ptxAsm << ", " << op.getCount();
     }
```
