
Commit 6aa49bb

[Gluon] Add mbarrier (#6997)
This implements:

- `ttgl.SwizzledSharedLayout`
- `ttgl.nvidia.hopper.mbarrier.MBarrierLayout` (convenience wrapper for `SwizzledSharedLayout`)
- `ttgl.nvidia.hopper.mbarrier.init`
- `ttgl.nvidia.hopper.mbarrier.invalidate`
- `ttgl.nvidia.hopper.mbarrier.expect`
- `ttgl.nvidia.hopper.mbarrier.wait`
- `ttgl.nvidia.hopper.mbarrier.arrive`

plus aliases in `ttgl.nvidia.blackwell.mbarrier`.

Note that I'm keeping this API functional to allow interpreting any shared allocation as an mbarrier. We can wrap with higher-level APIs at a later date if desired.
1 parent d510a3d commit 6aa49bb
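
For orientation, here is a minimal usage sketch of the new functional API, closely following the `mbarrier_kernel` test added in this commit (the kernel name below is illustrative; actually launching it requires an SM90+ GPU):

    from triton.experimental import gluon
    from triton.experimental.gluon import language as ttgl
    from triton.experimental.gluon.language.nvidia.hopper import mbarrier

    @gluon.jit
    def mbarrier_example_kernel():
        # Any shared-memory allocation can be interpreted as an mbarrier;
        # MBarrierLayout is a convenience wrapper around SwizzledSharedLayout.
        bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
        mbarrier.init(bar, count=1)        # initialize with an arrival count of 1
        mbarrier.expect(bar, 4)            # expect 4 bytes of async transactions
        mbarrier.arrive(bar, 1)            # signal one arrival
        mbarrier.wait(bar, 0, deps=[bar])  # wait on phase 0
        mbarrier.invalidate(bar)           # invalidate the barrier when done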

File tree

7 files changed, +177 -3 lines changed

python/src/gluon_ir.cc

Lines changed: 32 additions & 0 deletions

@@ -80,6 +80,17 @@ void init_gluon_ir(py::module &&m) {
             ctx, swizzleByteWidth, transposed, elementBitwidth, fp4Padded,
             ctaLayout);
       })
+      .def("get_swizzled_shared_layout",
+           [](GluonOpBuilder &self, int vec, int perPhase, int maxPhase,
+              std::vector<unsigned> &order, std::vector<unsigned> &ctasPerCga,
+              std::vector<unsigned> &ctaSplitNum,
+              std::vector<unsigned> &ctaOrder) -> Attribute {
+             auto ctx = self.getContext();
+             auto ctaLayout = ttg::CTALayoutAttr::get(ctx, ctasPerCga,
+                                                      ctaSplitNum, ctaOrder);
+             return ttg::SwizzledSharedEncodingAttr::get(
+                 ctx, vec, perPhase, maxPhase, order, ctaLayout);
+           })
       .def("get_tensor_memory_layout",
            [](GluonOpBuilder &self, std::vector<unsigned> &block, bool unpacked,
               std::vector<unsigned> &ctaSplitNum) -> Attribute {
@@ -132,6 +143,27 @@ void init_gluon_ir(py::module &&m) {
               int N) -> Value {
             return self.create<ttng::TMEMSubSliceOp>(resultTy, memDesc, N);
           })
+      .def("create_mbarrier_init",
+           [](GluonOpBuilder &self, Value memDesc, int count) {
+             self.create<ttng::InitBarrierOp>(memDesc, count);
+           })
+      .def("create_mbarrier_inval",
+           [](GluonOpBuilder &self, Value memDesc) {
+             self.create<ttng::InvalBarrierOp>(memDesc);
+           })
+      .def("create_mbarrier_expect",
+           [](GluonOpBuilder &self, Value memDesc, int bytes, Value pred) {
+             self.create<ttng::BarrierExpectOp>(memDesc, bytes, pred);
+           })
+      .def("create_mbarrier_wait",
+           [](GluonOpBuilder &self, Value memDesc, Value phase, Value pred,
+              std::vector<Value> &deps) {
+             self.create<ttng::WaitBarrierOp>(memDesc, phase, pred, deps);
+           })
+      .def("create_mbarrier_arrive",
+           [](GluonOpBuilder &self, Value memDesc, int count, Value pred) {
+             self.create<ttng::ArriveBarrierOp>(memDesc, count, pred);
+           })
       .def("create_warp_return",
            [](GluonOpBuilder &self) -> Operation * {
              return self.create<ttg::WarpReturnOp>();
python/test/gluon/test_frontend.py

Lines changed: 40 additions & 0 deletions

@@ -5,6 +5,7 @@
 from triton import knobs
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
+from triton.experimental.gluon.language.nvidia.blackwell import mbarrier
 from triton._filecheck import filecheck_test
 import triton.language as tl
 from triton._internal_testing import is_cuda
@@ -177,3 +178,42 @@ def test_warp_specialize():
         [warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
     anchor(a)
     anchor(b)
+
+
+@gluon.jit
+def mbarrier_kernel():
+    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+    mbarrier.init(bar, count=1)
+    mbarrier.expect(bar, 4)
+    mbarrier.arrive(bar, 1)
+    phase = 0
+    mbarrier.wait(bar, phase, deps=[bar])
+    mbarrier.invalidate(bar)
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
+def test_mbarrier(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    h = mbarrier_kernel.warmup(grid=(1, ))
+    expecttest.assert_expected_inline(
+        h.asm["ttgir"], """\
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func public @mbarrier_kernel() attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
+    ttng.init_barrier %0, 1 : !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
+    %true = arith.constant true loc(#loc)
+    ttng.barrier_expect %0, 4, %true : !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
+    %true_0 = arith.constant true loc(#loc)
+    ttng.arrive_barrier %0, 1, %true_0 : !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
+    %c0_i32 = arith.constant 0 : i32 loc(#loc)
+    %true_1 = arith.constant true loc(#loc)
+    ttng.wait_barrier %0, %c0_i32, %true_1 deps %0 : !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
+    ttng.inval_barrier %0 : !ttg.memdesc<1xi64, #shared, #smem, mutable> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
+""")

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 40 additions & 1 deletion

@@ -2,7 +2,12 @@
 from typing import List, Optional
 from triton.language.core import _unwrap_if_constexpr
 
-__all__ = ["BlockedLayout", "SliceLayout", "NVMMASharedLayout"]
+__all__ = [
+    "BlockedLayout",
+    "SliceLayout",
+    "NVMMASharedLayout",
+    "SwizzledSharedLayout",
+]
 
 
 def _realize_cta_layout(rank, ctas_per_cga, cta_split_num, cta_order):
@@ -123,3 +128,37 @@ def _to_ir(self, builder):
 
     def mangle(self) -> str:
         return f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}_NVMMA"
+
+
+@dataclass(frozen=True, eq=True)
+class SwizzledSharedLayout(SharedLayout):
+    vec: int
+    per_phase: int
+    max_phase: int
+    order: List[int]
+    ctas_per_cga: Optional[List[int]] = None
+    cta_split_num: Optional[List[int]] = None
+    cta_order: Optional[List[int]] = None
+
+    def __post_init__(self):
+        rank = len(self.order)
+        assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank
+        assert self.cta_split_num is None or len(self.cta_split_num) == rank
+        assert self.cta_order is None or len(self.cta_order) == rank
+
+    def _to_ir(self, builder):
+        rank = len(self.order)
+        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(rank, self.ctas_per_cga, self.cta_split_num,
+                                                                     self.cta_order)
+        return builder.get_swizzled_shared_layout(
+            _unwrap_if_constexpr(self.vec),
+            _unwrap_if_constexpr(self.per_phase),
+            _unwrap_if_constexpr(self.max_phase),
+            self.order,
+            ctas_per_cga,
+            cta_split_num,
+            cta_order,
+        )
+
+    def mangle(self) -> str:
+        return f"SSS_{self.vec}_{self.per_phase}_{self.max_phase}_{self.order}_SSS"
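
For reference, the `MBarrierLayout` convenience wrapper added later in this commit is just this layout with fixed parameters; constructing it directly is equivalent to the following sketch (based on the wrapper's defaults):

    from triton.experimental.gluon import language as ttgl

    # MBarrierLayout() with the defaults ctas_per_cga=1 and cta_split_num=1
    # resolves to this SwizzledSharedLayout instantiation.
    layout = ttgl.SwizzledSharedLayout(
        vec=1, per_phase=1, max_phase=1, order=[0],
        ctas_per_cga=[1], cta_split_num=[1], cta_order=[0],
    )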
python/triton/experimental/gluon/language/nvidia/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,3 +1,4 @@
 from . import blackwell
+from . import hopper
 
-__all__ = ["blackwell"]
+__all__ = ["blackwell", "hopper"]

python/triton/experimental/gluon/language/nvidia/blackwell.py renamed to python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 8 additions & 1 deletion

@@ -5,11 +5,18 @@
 from triton.experimental.gluon.language import _core as ttgl
 from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
 
+from ..hopper import mbarrier
+
 if TYPE_CHECKING:
     from triton._C.libtriton.gluon_ir import GluonOpBuilder
     from triton._C.libtriton import gluon_ir as ir
 
-__all__ = ["TensorMemoryLayout", "tensor_memory_descriptor", "allocate_tensor_memory"]
+__all__ = [
+    "TensorMemoryLayout",
+    "tensor_memory_descriptor",
+    "allocate_tensor_memory",
+    "mbarrier",
+]
 
 
 @dataclass(frozen=True, eq=True)
python/triton/experimental/gluon/language/nvidia/hopper/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from . import mbarrier
+
+__all__ = ["mbarrier"]
python/triton/experimental/gluon/language/nvidia/hopper/mbarrier.py

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+from triton.experimental.gluon.language._layouts import SwizzledSharedLayout
+import triton.experimental.gluon.language._core as ttgl
+from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
+
+__all__ = ["MBarrierLayout", "init", "invalidate", "expect", "wait", "arrive"]
+
+
+class MBarrierLayout(SwizzledSharedLayout):
+
+    def __init__(self, ctas_per_cga: int = 1, cta_split_num: int = 1):
+        super().__init__(
+            vec=1,
+            per_phase=1,
+            max_phase=1,
+            order=[0],
+            ctas_per_cga=[ctas_per_cga],
+            cta_split_num=[cta_split_num],
+            cta_order=[0],
+        )
+
+
+@builtin
+def init(mbarrier, count, _builder=None):
+    count = _unwrap_if_constexpr(count)
+    _builder.create_mbarrier_init(mbarrier.handle, count)
+
+
+@builtin
+def invalidate(mbarrier, _builder=None):
+    _builder.create_mbarrier_inval(mbarrier.handle)
+
+
+@builtin
+def expect(mbarrier, bytes, pred=True, _builder=None):
+    bytes = _unwrap_if_constexpr(bytes)
+    pred = ttgl.to_tensor(pred, _builder=_builder)
+    _builder.create_mbarrier_expect(mbarrier.handle, bytes, pred.handle)
+
+
+@builtin
+def wait(mbarrier, phase, pred=True, deps=(), _builder=None):
+    phase = ttgl.to_tensor(phase, _builder=_builder)
+    pred = ttgl.to_tensor(pred, _builder=_builder)
+    deps = [x.handle for x in deps]
+    _builder.create_mbarrier_wait(mbarrier.handle, phase.handle, pred.handle, deps)
+
+
+@builtin
+def arrive(mbarrier, count, pred=True, _builder=None):
+    count = _unwrap_if_constexpr(count)
+    pred = ttgl.to_tensor(pred, _builder=_builder)
+    _builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle)
