Commit fd5fb0c

[GLUON] Async WGMMA support (#7313)
Adding `warpgroup_mma_wait` op.
1 parent 1ab9f65 commit fd5fb0c
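
Taken together with `is_async=True` on `warpgroup_mma`, this completes the asynchronous WGMMA flow: issue the MMA without blocking, then drain it before the accumulator is read. A minimal sketch of the pattern, pieced together from the tests in this commit (kernel-body fragment only; `a_smem`, `b_smem`, `acc`, and `out_ptrs` are placeholders, and the import names follow the package layout in the diff below):

```python
# Assumed imports, following the file layout touched by this commit:
#   from triton.experimental.gluon import language as ttgl
#   from triton.experimental.gluon.language.nvidia import hopper

# Issue the WGMMA asynchronously; `acc` is not safe to read yet.
acc = hopper.warpgroup_mma(a_smem, b_smem, acc, is_async=True)

# ...independent work can overlap with the in-flight MMA here...

# Block until at most `num_outstanding` async MMAs remain; 0 drains them all.
hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc])

ttgl.store(out_ptrs, acc)
```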

File tree

python/src/gluon_ir.cc
python/test/gluon/test_core.py
python/test/gluon/test_frontend.py
python/triton/experimental/gluon/language/nvidia/hopper/__init__.py

4 files changed: +47 −6 lines

4 files changed

+47
-6
lines changed

python/src/gluon_ir.cc

Lines changed: 4 additions & 1 deletion
@@ -391,7 +391,10 @@ void init_gluon_ir(py::module &&m) {
             return self.create<ttng::WarpGroupDotOp>(
                 a, b, acc, useAcc, precision, maxNumImpreciseAcc, isAsync);
           })
-
+      .def("create_warpgroup_mma_wait",
+           [](GluonOpBuilder &self, std::vector<Value> &deps, int pendings) {
+             self.create<ttng::WarpGroupDotWaitOp>(deps, pendings);
+           })
       .def("create_tmem_alloc",
            [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
              return self.create<ttng::TMEMAllocOp>(resultTy, value);

python/test/gluon/test_core.py

Lines changed: 8 additions & 4 deletions
@@ -100,7 +100,7 @@ def test_async_copy_mbarrier():
 
 
 @gluon.jit
-def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr):
+def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, ASYNC: ttgl.constexpr):
     block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
     mma_layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1],
                                                              instr_shape=[16, 32, 16])
@@ -121,19 +121,23 @@ def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttg
     a_shmem = ttgl.allocate_shared_memory(ttgl.float16, [M, K], nvmma_layout, A)
     b_shmem = ttgl.allocate_shared_memory(ttgl.float16, [K, N], nvmma_layout, B)
 
-    acc = hopper.warpgroup_mma(a_shmem, b_shmem, acc)
+    acc = hopper.warpgroup_mma(a_shmem, b_shmem, acc, is_async=ASYNC)
+
+    if ASYNC:
+        hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc])
 
     ttgl.store(out + out_offs_m * N + out_offs_n, acc)
 
 
 @pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
-def test_warpgroup_mma():
+@pytest.mark.parametrize("ASYNC", [True, False])
+def test_warpgroup_mma(ASYNC):
     torch.manual_seed(0)
     M, N, K = 64, 32, 32
     a = torch.randn((M, K), device="cuda", dtype=torch.float16)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16)
     out = torch.zeros((M, N), device="cuda", dtype=torch.float16)
-    warpgroup_mma_kernel[(1, )](a, b, out, M, N, K)
+    warpgroup_mma_kernel[(1, )](a, b, out, M, N, K, ASYNC)
 
     ref = torch.matmul(a, b)
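
Note the shape of the change: when ASYNC is false, `warpgroup_mma` presumably completes the MMA itself, so the explicit wait is only needed on the `is_async=True` path, which is why it sits behind `if ASYNC`. For reference, the async variant of the test boils down to a driver like the following (a sketch: it assumes a Hopper GPU and that `warpgroup_mma_kernel` defined above is in scope):

```python
import torch

def run_async_warpgroup_mma():
    torch.manual_seed(0)
    M, N, K = 64, 32, 32
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    out = torch.zeros((M, N), device="cuda", dtype=torch.float16)

    # ASYNC=True routes the kernel through warpgroup_mma(..., is_async=True)
    # followed by hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc]).
    warpgroup_mma_kernel[(1, )](a, b, out, M, N, K, True)

    # Compare against a plain matmul; exact tolerances are left to the test above.
    ref = torch.matmul(a, b)
    print((out - ref).abs().max().item())
```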

python/test/gluon/test_frontend.py

Lines changed: 27 additions & 0 deletions
@@ -482,6 +482,33 @@ def test_warpgroup_mma(fresh_knobs):
 """)
 
 
+@gluon.jit
+def warpgroup_mma_wait_kernel():
+    layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16])
+    acc = ttgl.full([128, 128], 0, dtype=ttgl.float16, layout=layout)
+    hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc])
+
+
+@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper WGMMA")
+def test_warpgroup_mma_wait(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    h = warpgroup_mma_wait_kernel.warmup(grid=(1, ))
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["source"]), """\
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @warpgroup_mma_wait_kernel() attributes {noinline = false} {
+    %cst = arith.constant 0.000000e+00 : f16 loc(#loc)
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #mma> loc(#loc)
+    %0 = ttng.warp_group_dot_wait %cst_0 {pendings = 1 : i32} : tensor<128x128xf16, #mma> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
+""")
+
+
 @gluon.jit
 def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr):
     smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
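
Outside the expecttest harness, the same IR can be reproduced with a few lines (a sketch: the import paths are inferred from the package layout in this commit, and compiling for a non-Hopper target may reject the layout):

```python
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia import hopper


@gluon.jit
def wait_only_kernel():
    # Same body as warpgroup_mma_wait_kernel above.
    layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1],
                                                         instr_shape=[16, 32, 16])
    acc = ttgl.full([128, 128], 0, dtype=ttgl.float16, layout=layout)
    hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc])


# Compile without launching and dump the Triton-GPU IR; the wait appears as
# ttng.warp_group_dot_wait with pendings = 1.
h = wait_only_kernel.warmup(grid=(1, ))
print(h.asm["source"])
```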

python/triton/experimental/gluon/language/nvidia/hopper/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -2,7 +2,7 @@
 from . import mbarrier, tma
 from ... import _core
 
-__all__ = ["async_copy", "fence_async_shared", "mbarrier", "tma", "warpgroup_mma"]
+__all__ = ["async_copy", "fence_async_shared", "mbarrier", "tma", "warpgroup_mma", "warpgroup_mma_wait"]
 
 
 @_core.builtin
@@ -25,3 +25,10 @@ def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_
     handle = _semantic.builder.create_warpgroup_mma(a.handle, b.handle, acc.handle, use_acc.handle, precision,
                                                     max_num_imprecise_acc, is_async)
     return _core.tensor(handle, acc.type)
+
+
+@_core.builtin
+def warpgroup_mma_wait(num_outstanding=0, deps=None, _semantic=None):
+    deps = [x.handle for x in deps] if deps is not None else []
+    num_outstanding = _core._unwrap_if_constexpr(num_outstanding)
+    _semantic.builder.create_warpgroup_mma_wait(deps, num_outstanding)
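
Reading the new builtin together with the pybind11 binding in gluon_ir.cc: `deps` is unwrapped to a list of value handles, and `num_outstanding` is forwarded as the `pendings` operand of `ttng::WarpGroupDotWaitOp` (visible as `pendings = 1` in the frontend test above). A hedged usage note; the behavioral comments are inferred from the op's name and the tests, not from documentation:

```python
# Inside a Gluon kernel, after one or more warpgroup_mma(..., is_async=True) calls:

# Default num_outstanding=0: presumably drain all in-flight async MMAs before
# `acc` is read.
hopper.warpgroup_mma_wait(deps=[acc])

# num_outstanding=1 (the value the tests in this commit use): allow one MMA to
# remain outstanding, e.g. while the next one is being issued.
hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc])
```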

0 commit comments