
Commit 3c893cf

[Gluon][TTNG] Add async_copy ops including mbarrier arrive op (#7220)
This adds:
- `ttgl.nvidia.ampere.async_copy.async_copy_global_to_shared`
- `ttgl.nvidia.ampere.async_copy.mbarrier_arrive`
- `ttgl.nvidia.ampere.async_copy.commit_group`
- `ttgl.nvidia.ampere.async_copy.wait_group`
- `ttgl.max_constancy`
- `ttgl.max_contiguous`
- `ttgl.multiple_of`

It also adds a new `ttng.async_copy_mbarrier_arrive` op to allow mbarrier synchronization of `cp.async` ops. The interface on this one is a bit odd: it looks like you can pass any pointer and mask, as with a normal load, but the op will fail to convert to LLVM if the layout can't be proven compatible with `cp.async`. Hence I exposed the axis-analysis annotation ops. There might be a better interface where you explicitly load chunks of 4, 8, or 16 bytes so we don't rely on axis analysis, but this seems okay as a first draft.
1 parent 5a311bb commit 3c893cf
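For orientation, here is a rough usage sketch of the new API, modeled on the tests added below. The kernel name `copy_1d_kernel`, the final masked store, and the host-side launch are illustrative rather than part of this change, and running it assumes an Ampere-or-newer GPU:

```python
import torch
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia.ampere import async_copy


@gluon.jit
def copy_1d_kernel(out, inp, xnumel, XBLOCK: ttgl.constexpr):
    # Stage a 1D tile in shared memory with cp.async, then wait via commit/wait groups.
    smem = ttgl.allocate_shared_memory(inp.dtype.element_ty, [XBLOCK],
                                       ttgl.SwizzledSharedLayout(1, 1, 1, order=[0]))
    block_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
    xindex = ttgl.arange(0, XBLOCK, block_layout)
    # Annotate the mask so axis analysis can prove a cp.async-compatible (>= 4 byte) access.
    mask = ttgl.max_constancy(xindex < xnumel, 2)
    async_copy.async_copy_global_to_shared(smem, inp + xindex, mask)
    async_copy.commit_group()
    async_copy.wait_group(0)
    ttgl.store(out + xindex, smem.load(block_layout), mask)


# Illustrative launch; requires an Ampere-or-newer GPU.
inp = torch.randn(128, dtype=torch.float16, device="cuda")
out = torch.empty_like(inp)
copy_1d_kernel[(1, )](out, inp, inp.numel(), XBLOCK=128)
```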

File tree

19 files changed: +353, -85 lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [
   let description = [{
     This operation copies data from global memory to local memory asynchronously.
     This is analogue to tt.load except the data are copied to local memory pointed
-    by by the memory descriptor instead of a distributed tensor. The rest of the
+    to by the memory descriptor instead of a distributed tensor. The rest of the
     operands are the same as tt.load.
   }];

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 9 additions & 0 deletions
@@ -262,6 +262,15 @@ def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier"> {
   let hasVerifier = 1;
 }
 
+def TTNG_AsyncCopyMbarrierArriveOp : TTNG_Op<"async_copy_mbarrier_arrive"> {
+  let summary = "arrive on mbarrier once all previously issued copies are completed";
+  let arguments = (ins
+    Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$barrier,
+    UnitAttr:$noIncrement
+  );
+  let assemblyFormat = "$barrier attr-dict `:` qualified(type($barrier))";
+}
+
 def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local"> {
   let summary = "copy data based on descriptor from global memory to local memory asynchronously";

python/src/gluon_ir.cc

Lines changed: 23 additions & 0 deletions
@@ -279,6 +279,29 @@ void init_gluon_ir(py::module &&m) {
                  blockTy.getShape(), blockTy.getElementType(), layout);
              return triton::TensorDescType::get(ctx, blockTyLayout, isSigned);
            })
+      .def("create_async_copy_global_to_local",
+           [](GluonOpBuilder &self, Value smem, Value pointer, Value mask,
+              tt::CacheModifier cacheModifier,
+              tt::EvictionPolicy evictionPolicy, bool isVolatile) {
+             self.create<ttg::AsyncCopyGlobalToLocalOp>(
+                 pointer, smem, mask, /*other*/ Value{}, cacheModifier,
+                 evictionPolicy, isVolatile);
+           })
+      .def("create_async_copy_mbarrier_arrive",
+           [](GluonOpBuilder &self, Value mbarrier, bool incrementCount) {
+             self.create<ttng::AsyncCopyMbarrierArriveOp>(mbarrier,
+                                                          !incrementCount);
+           })
+      .def("create_async_commit_group",
+           [](GluonOpBuilder &self) {
+             ValueRange tokens;
+             self.create<ttg::AsyncCommitGroupOp>(tokens);
+           })
+      .def("create_async_wait_group",
+           [](GluonOpBuilder &self, int num) {
+             ValueRange tokens;
+             self.create<ttg::AsyncWaitOp>(tokens, num);
+           })
       .def("create_convert_layout",
            [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
              return self.create<ttg::ConvertLayoutOp>(resultTy, value);

python/test/gluon/test_core.py

Lines changed: 36 additions & 2 deletions
@@ -1,9 +1,10 @@
 import torch
 import pytest
 
-from triton._internal_testing import is_cuda
+from triton._internal_testing import is_ampere_or_newer, is_hopper
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
+from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier
 from triton.experimental.gluon.language.nvidia.hopper import tma
 
 
@@ -45,7 +46,7 @@ def tma_kernel(desc):
     alloc._keep_alive()
 
 
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires Hopper")
+@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
 def test_tma():
     out = torch.ones((16, 16), dtype=torch.float16, device="cuda")
     layout = ttgl.NVMMASharedLayout(
@@ -59,3 +60,36 @@ def test_tma():
     desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(out, [16, 16], layout)
     tma_kernel[(1, )](desc)
     torch.testing.assert_close(out, torch.zeros_like(out))
+
+
+@gluon.jit
+def async_copy_mbarrier_kernel(out, inp, xnumel, XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr):
+    smem = ttgl.allocate_shared_memory(inp.dtype.element_ty, [XBLOCK, YBLOCK],
+                                       ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]))
+    block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0])
+    xindex = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(1, block_layout))[:, None]
+    yindex = ttgl.arange(0, YBLOCK, ttgl.SliceLayout(0, block_layout))[None, :]
+    mask = xindex < xnumel
+    async_copy.async_copy_global_to_shared(
+        smem,
+        inp + xindex * YBLOCK + yindex,
+        mask,
+    )
+    mbar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+    mbarrier.init(mbar, count=1)
+    async_copy.mbarrier_arrive(mbar)
+    mbarrier.arrive(mbar)
+    mbarrier.wait(mbar, 0)
+
+    val = smem.load(block_layout)
+    ttgl.store(out + xindex * YBLOCK + yindex, val)
+
+
+@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires Ampere")
+def test_async_copy_mbarrier():
+    tensor_opts = dict(dtype=torch.float, device="cuda")
+    out = torch.empty((32, 32), **tensor_opts)
+    inp = torch.randn((20, 32), **tensor_opts)
+    async_copy_mbarrier_kernel[(1, )](out, inp, inp.shape[0], XBLOCK=32, YBLOCK=32)
+    torch.testing.assert_close(out[:20], inp)
+    torch.testing.assert_close(out[20:], torch.zeros((12, 32), **tensor_opts))

python/test/gluon/test_frontend.py

Lines changed: 58 additions & 10 deletions
@@ -8,11 +8,11 @@
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia import blackwell
-from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout
+from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout, async_copy
 from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
 from triton._filecheck import filecheck_test, run_parser
 import triton.language as tl
-from triton._internal_testing import is_cuda
+from triton._internal_testing import is_ampere_or_newer, is_blackwell, is_hopper
 from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
 
 TARGET_PAT = re.compile('ttg.target = "[^"]*"')
@@ -117,8 +117,7 @@ def tensor_memory_kernel(layout: ttgl.constexpr, tmem_layout: ttgl.constexpr):
     buffers.index(i).load(layout)
 
 
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10,
-                    reason="Requires blackwell tensor cores")
+@pytest.mark.skipif(not is_blackwell(), reason="Requires blackwell tensor cores")
 def test_tensor_memory(fresh_knobs):
     knobs.compilation.disable_line_info = True
 
@@ -373,13 +372,13 @@ def mbarrier_kernel():
     bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
     mbarrier.init(bar, count=1)
     mbarrier.expect(bar, 4)
-    mbarrier.arrive(bar, 1)
+    mbarrier.arrive(bar, count=1)
     phase = 0
     mbarrier.wait(bar, phase, deps=[bar])
     mbarrier.invalidate(bar)
 
 
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
+@pytest.mark.skipif(not is_hopper(), reason="Requires hopper or newer")
 def test_mbarrier(fresh_knobs):
     knobs.compilation.disable_line_info = True
 
@@ -415,8 +414,7 @@ def tcgen05_mma_kernel(nvmma_layout: ttgl.constexpr, acc_layout: ttgl.constexpr)
     blackwell.tcgen05_mma(a, b, acc)
 
 
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10,
-                    reason="Requires blackwell tensor core")
+@pytest.mark.skipif(not is_blackwell(), reason="Requires blackwell tensor core")
 def test_tcgen05_mma(fresh_knobs):
     knobs.compilation.disable_line_info = True
 
@@ -460,7 +458,7 @@ def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr):
     tma.store_wait(0)
 
 
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="TMA requires at least Hopper")
+@pytest.mark.skipif(not is_hopper(), reason="TMA requires at least Hopper")
 def test_async_tma(fresh_knobs):
     knobs.compilation.disable_line_info = True
 
@@ -519,7 +517,7 @@ def async_tma_blackwell_kernel(input_desc, XBLOCK: ttgl.constexpr):
     tma.store_wait(0)
 
 
-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10, reason="Requires Blackwell")
+@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
 def test_async_tma_blackwell(fresh_knobs):
     knobs.compilation.disable_line_info = True
 
@@ -955,3 +953,53 @@ def test_inline_asm_elementwise():
     x = ttgl.arange(0, 16, layout)
     # CHECK: elementwise_inline_asm {{.*}} : tensor<16xi32, [[BLOCKED:#.*]]> -> tensor<16xi32, [[BLOCKED]]>
     ttgl.inline_asm_elementwise("mov $0, $0;", "=r,r", [x], dtype=x.dtype, is_pure=True, pack=1)
+
+
+@gluon.jit
+def async_copy_kernel(inp, xnumel, XBLOCK: ttgl.constexpr):
+    smem = ttgl.allocate_shared_memory(inp.dtype.element_ty, [XBLOCK], ttgl.SwizzledSharedLayout(1, 1, 1, order=[0]))
+    block_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
+    xindex = ttgl.arange(0, XBLOCK, block_layout)
+    mask = tl.max_constancy(xindex < xnumel, 2)
+
+    async_copy.async_copy_global_to_shared(smem, inp + xindex, mask)
+    async_copy.async_copy_global_to_shared(smem, inp + xindex, mask, cache_modifier=".ca", eviction_policy="evict_last",
+                                           volatile=True)
+
+    mbar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+    async_copy.mbarrier_arrive(mbar)
+    async_copy.mbarrier_arrive(mbar, increment_count=False)
+    async_copy.commit_group()
+    async_copy.wait_group(0)
+
+
+@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires ampere")
+def test_async_copy(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    h = async_copy_kernel.warmup(MockTensor(ttgl.float16), xnumel=100, XBLOCK=128, sanitize_overflow=False, grid=(1, ))
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["ttgir"]), """\
+#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#loc = loc(unknown)
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @async_copy_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc(unknown), %arg1: i32 loc(unknown)) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128xf16, #shared, #smem, mutable> loc(#loc)
+    %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> loc(#loc)
+    %2 = tt.splat %arg1 : i32 -> tensor<128xi32, #blocked> loc(#loc)
+    %3 = arith.cmpi slt, %1, %2 {tt.constancy = dense<2> : tensor<1xi32>} : tensor<128xi32, #blocked> loc(#loc)
+    %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x!tt.ptr<f16>, #blocked> loc(#loc)
+    %5 = tt.addptr %4, %1 : tensor<128x!tt.ptr<f16>, #blocked>, tensor<128xi32, #blocked> loc(#loc)
+    %6 = ttg.async_copy_global_to_local %5, %0 mask %3 : tensor<128x!tt.ptr<f16>, #blocked> -> <128xf16, #shared, #smem, mutable> loc(#loc)
+    %7 = ttg.async_copy_global_to_local %5, %0 mask %3 cacheModifier = ca evictionPolicy = evict_last {isVolatile = true} : tensor<128x!tt.ptr<f16>, #blocked> -> <128xf16, #shared, #smem, mutable> loc(#loc)
+    %8 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
+    ttng.async_copy_mbarrier_arrive %8 : !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
+    ttng.async_copy_mbarrier_arrive %8 {noIncrement} : !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
+    %9 = ttg.async_commit_group loc(#loc)
+    %10 = ttg.async_wait {num = 0 : i32} loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+""")

python/triton/_internal_testing.py

Lines changed: 8 additions & 0 deletions
@@ -38,6 +38,14 @@ def is_cuda():
     return False if target is None else target.backend == "cuda"
 
 
+def is_ampere_or_newer():
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 8
+
+
+def is_blackwell():
+    return is_cuda() and torch.cuda.get_device_capability()[0] == 10
+
+
 def is_hopper():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 9

python/triton/experimental/gluon/language/_core.py

Lines changed: 4 additions & 1 deletion
@@ -44,10 +44,14 @@
 
 _IMPORT_FROM_TRITON: List[str] = [
     "expand_dims",
+    "inline_asm_elementwise",
     "join",
     "load",
     "maximum",
+    "max_constancy",
+    "max_contiguous",
     "minimum",
+    "multiple_of",
     "permute",
     "program_id",
     "reduce",
@@ -58,7 +62,6 @@
     "store",
     "to_tensor",
     "where",
-    "inline_asm_elementwise",
 ]
 
 __all__ = [

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from . import async_copy, mbarrier
+
+__all__ = ["async_copy", "mbarrier"]

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+from ..._semantic import _check
+from ..._core import _unwrap_if_constexpr, builtin
+from triton._C.libtriton import ir
+
+__all__ = [
+    "async_copy_global_to_shared",
+    "mbarrier_arrive",
+    "commit_group",
+    "wait_group",
+]
+
+
+@builtin
+def async_copy_global_to_shared(smem, pointer, mask=None, cache_modifier="", eviction_policy="", volatile=False,
+                                _semantic=None):
+    mask = _unwrap_if_constexpr(mask)
+    cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier)
+    eviction_policy = _semantic._str_to_eviction_policy(eviction_policy)
+    volatile = _unwrap_if_constexpr(volatile)
+    if mask is not None:
+        pointer, mask = _semantic.broadcast_impl_value(pointer, mask)
+    _check(
+        smem.shape == pointer.shape, lambda:
+        f"expected smem shape to match pointer shape but got smem.shape = {smem.shape}, pointer.shape = {pointer.shape}"
+    )
+    mask_handle = mask.handle if mask is not None else ir.value()
+    _semantic.builder.create_async_copy_global_to_local(smem.handle, pointer.handle, mask_handle, cache_modifier,
+                                                        eviction_policy, volatile)
+
+
+@builtin
+def mbarrier_arrive(mbarrier, increment_count=True, _semantic=None):
+    """Arrive on the mbarrier once all outstanding async copies are complete.
+    """
+    increment_count = _unwrap_if_constexpr(increment_count)
+    _semantic.builder.create_async_copy_mbarrier_arrive(mbarrier.handle, increment_count)
+
+
+@builtin
+def commit_group(_semantic=None):
+    _semantic.builder.create_async_commit_group()
+
+
+@builtin
+def wait_group(num_outstanding=0, _semantic=None):
+    num_outstanding = _unwrap_if_constexpr(num_outstanding)
+    _semantic.builder.create_async_wait_group(num_outstanding)
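
Besides `commit_group`/`wait_group`, completion of the copies can be observed through an mbarrier via `mbarrier_arrive`, as `test_async_copy_mbarrier` above exercises. A condensed sketch of that pattern; the kernel name, 1D shape, and layouts are illustrative rather than part of this change:

```python
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier


@gluon.jit
def copy_with_mbarrier_kernel(out, inp, XBLOCK: ttgl.constexpr):
    block_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
    xindex = ttgl.arange(0, XBLOCK, block_layout)
    smem = ttgl.allocate_shared_memory(inp.dtype.element_ty, [XBLOCK],
                                       ttgl.SwizzledSharedLayout(1, 1, 1, order=[0]))
    async_copy.async_copy_global_to_shared(smem, inp + xindex)

    # Completion of the cp.async copies is signalled on the mbarrier instead of via wait_group.
    mbar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(mbar, count=1)
    async_copy.mbarrier_arrive(mbar)  # arrives once the outstanding copies complete
    mbarrier.arrive(mbar)             # regular arrive, as in async_copy_mbarrier_kernel above
    mbarrier.wait(mbar, 0)            # wait on phase 0

    ttgl.store(out + xindex, smem.load(block_layout))
```

As in the frontend test, `async_copy.mbarrier_arrive(mbar, increment_count=False)` maps to the `noIncrement` form of `ttng.async_copy_mbarrier_arrive`.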
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+from triton.experimental.gluon.language._layouts import SwizzledSharedLayout
+from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
+
+__all__ = ["arrive", "init", "invalidate", "MBarrierLayout", "wait"]
+
+
+class MBarrierLayout(SwizzledSharedLayout):
+
+    def __init__(self, ctas_per_cga: int = 1, cta_split_num: int = 1):
+        super().__init__(
+            vec=1,
+            per_phase=1,
+            max_phase=1,
+            order=[0],
+            ctas_per_cga=[ctas_per_cga],
+            cta_split_num=[cta_split_num],
+            cta_order=[0],
+        )
+
+
+@builtin
+def init(mbarrier, count, _semantic=None):
+    count = _unwrap_if_constexpr(count)
+    _semantic.builder.create_mbarrier_init(mbarrier.handle, count)
+
+
+@builtin
+def invalidate(mbarrier, _semantic=None):
+    _semantic.builder.create_mbarrier_inval(mbarrier.handle)
+
+
+@builtin
+def wait(mbarrier, phase, pred=True, deps=(), _semantic=None):
+    phase = _semantic.to_tensor(phase)
+    pred = _semantic.to_tensor(pred)
+    deps = [x.handle for x in deps]
+    _semantic.builder.create_mbarrier_wait(mbarrier.handle, phase.handle, pred.handle, deps)
+
+
+@builtin
+def arrive(mbarrier, *, pred=True, _semantic=None):
+    count = 1
+    pred = _semantic.to_tensor(pred)
+    _semantic.builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle)
