Commit c2cfe83

[Gluon] Expose DotOperandLayout (#7730)
Needed for WGMMA with LHS in registers
1 parent: 6578b58
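
The sketch below shows roughly how the newly exposed layout is meant to be used from a Gluon kernel: build an NVMMA (Hopper MMA) layout, wrap it in a DotOperandLayout for the LHS, and materialize a register-resident tensor in that layout. It is an illustration, not part of the commit; the kernel name is made up, the imports are the usual Gluon ones, and the layout parameters are copied from the test added in this change.

# Illustrative sketch only (not part of this commit); parameters mirror the
# new test_dot_operand_layout test below.
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def wgmma_lhs_in_registers_sketch():
    # Parent MMA layout: NVMMA version 3, i.e. Hopper WGMMA.
    mma_layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1],
                                                             instr_shape=[16, 32, 16])
    # operand_index=0 selects the LHS; k_width=2 packs two fp16 elements per 32 bits.
    lhs_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mma_layout, k_width=2)
    # A tensor held in registers with the dot-operand layout.
    a = ttgl.full([256, 128], 0.0, ttgl.float16, lhs_layout)
    ttgl.static_assert(a.type.layout == lhs_layout)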

File tree: 3 files changed (+59, -7 lines)

python/src/gluon_ir.cc (11 additions & 0 deletions)

@@ -90,6 +90,7 @@ struct GluonLayouts {
   py::handle BlockedLayout;
   py::handle SliceLayout;
   py::handle DistributedLinearLayout;
+  py::handle DotOperandLayout;
   py::handle NVMMADistributedLayout;
   py::handle NVMMASharedLayout;
   py::handle SwizzledSharedLayout;
@@ -102,6 +103,7 @@ struct GluonLayouts {
     SliceLayout = py::object(layouts.attr("SliceLayout")).release();
     DistributedLinearLayout =
         py::object(layouts.attr("DistributedLinearLayout")).release();
+    DotOperandLayout = py::object(layouts.attr("DotOperandLayout")).release();
     NVMMADistributedLayout =
         py::object(layouts.attr("NVMMADistributedLayout")).release();
     NVMMASharedLayout = py::object(layouts.attr("NVMMASharedLayout")).release();
@@ -155,6 +157,9 @@ py::object layoutToGluon(Attribute layout) {
         ll.getBases().lookup(kReg), ll.getBases().lookup(kLane),
         ll.getBases().lookup(kWarp), ll.getBases().lookup(kBlock),
         toStdVector(ArrayRef(llvm::to_vector(ll.getOutDimSizes()))));
+  } else if (auto dotOp = dyn_cast<ttg::DotOperandEncodingAttr>(layout)) {
+    return layouts.DotOperandLayout(
+        dotOp.getOpIdx(), layoutToGluon(dotOp.getParent()), dotOp.getKWidth());
   } else if (auto mma = dyn_cast<ttg::NvidiaMmaEncodingAttr>(layout)) {
     auto ctaLayout = mma.getCTALayout();
     return layouts.NVMMADistributedLayout(
@@ -259,6 +264,12 @@ void init_gluon_ir(py::module &&m) {
                                /*requiresSurjective=*/true);
             return ttg::LinearEncodingAttr::get(ctx, ll);
           })
+      .def("get_dot_operand_layout",
+           [](GluonOpBuilder &self, unsigned opIdx, Attribute parent,
+              unsigned kWidth) -> Attribute {
+             return self.getChecked<ttg::DotOperandEncodingAttr>(
+                 self.getContext(), opIdx, parent, kWidth);
+           })
       .def("get_mma_layout",
            [](GluonOpBuilder &self, std::vector<unsigned> &version,
              std::vector<unsigned> &warpsPerCta,

python/test/gluon/test_frontend.py (21 additions & 7 deletions)

@@ -119,8 +119,8 @@ def shared_memory_kernel(XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr, layout_
                          layout_b: ttgl.constexpr, smem_layout: ttgl.constexpr):
     unused = ttgl.allocate_shared_memory(ttgl.int32, [XBLOCK, YBLOCK], smem_layout)
     a = ttgl.full([XBLOCK, YBLOCK], 0, ttgl.int32, layout_a)
-    tl.static_assert(a.numel == unused.numel)
-    tl.static_assert(unused.numel == XBLOCK * YBLOCK)
+    ttgl.static_assert(a.numel == unused.numel)
+    ttgl.static_assert(unused.numel == XBLOCK * YBLOCK)
     mem = ttgl.allocate_shared_memory(ttgl.int32, a.shape, smem_layout, a)
     b = mem.load(layout_b)  # noqa: F841
     mem.store(a)
@@ -641,7 +641,7 @@ def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr):
     mbarrier.init(bar, count=1)
 
     tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
-    tl.static_assert(input_desc.block_type.nbytes == XBLOCK * XBLOCK * 2)
+    ttgl.static_assert(input_desc.block_type.nbytes == XBLOCK * XBLOCK * 2)
     mbarrier.expect(bar, input_desc.block_type.nbytes)
     mbarrier.wait(bar, 0)
 
@@ -941,7 +941,7 @@ def reduce_kernel(out):
     ttgl.static_assert(pairs[0].type.layout == ttgl.SliceLayout(0, layout))
     ttgl.static_assert(pairs[1].type.layout == ttgl.SliceLayout(0, layout))
     result = scalar + s1 + pairs[0] + pairs[1]
-    tl.store(out + ttgl.arange(0, 16, s0.type.layout), result)
+    ttgl.store(out + ttgl.arange(0, 16, s0.type.layout), result)
 
 
 @pytest.mark.parametrize("target", ALL_TARGETS)
@@ -1057,8 +1057,8 @@ def test_elementwise_core():
 
 @gluon.jit
 def linear_layout_kernel():
-    ll: tl.constexpr = ttgl.DistributedLinearLayout(reg_bases=[[1]], lane_bases=[[2], [4], [8], [16], [32]],
-                                                    warp_bases=[[64], [128]], block_bases=[], shape=[256])
+    ll: ttgl.constexpr = ttgl.DistributedLinearLayout(reg_bases=[[1]], lane_bases=[[2], [4], [8], [16], [32]],
+                                                      warp_bases=[[64], [128]], block_bases=[], shape=[256])
     ttgl.arange(0, 256, layout=ll)
 
 
@@ -1077,6 +1077,20 @@ def test_linear_layout(target):
 """)
 
 
+@filecheck_test
+@gluon.jit
+def test_dot_operand_layout():
+    # CHECK: [[NVMMA:#.*]] = #ttg.nvidia_mma
+    # CHECK: test_dot_operand_layout
+    mma_layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1],
+                                                             instr_shape=[16, 32, 16])
+    layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mma_layout, k_width=2)
+    # CHECK: arith.constant {{.*}} tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[NVMMA]], kWidth = 2}>>
+    x = ttgl.full([256, 128], 0.0, ttgl.float16, layout)
+    y = x.sum(axis=1)
+    ttgl.static_assert(y.type.layout.parent == layout)
+
+
 @filecheck_test
 @gluon.jit
 def test_tensor_permute():
@@ -1201,7 +1215,7 @@ def async_copy_kernel(inp, xnumel, XBLOCK: ttgl.constexpr):
     smem = ttgl.allocate_shared_memory(inp.dtype.element_ty, [XBLOCK], ttgl.SwizzledSharedLayout(1, 1, 1, order=[0]))
     block_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
     xindex = ttgl.arange(0, XBLOCK, block_layout)
-    mask = tl.max_constancy(xindex < xnumel, 2)
+    mask = ttgl.max_constancy(xindex < xnumel, 2)
 
     async_copy.async_copy_global_to_shared(smem, inp + xindex)
     async_copy.async_copy_global_to_shared(smem, inp + xindex, mask, cache_modifier=".ca", eviction_policy="evict_last",

python/triton/experimental/gluon/language/_layouts.py (27 additions & 0 deletions)

@@ -7,6 +7,7 @@
     "BlockedLayout",
     "SliceLayout",
     "DistributedLinearLayout",
+    "DotOperandLayout",
     "NVMMADistributedLayout",
     "NVMMASharedLayout",
     "SwizzledSharedLayout",
@@ -181,6 +182,32 @@ def mangle(self):
         return f"DLL{self.reg_bases}_{self.lane_bases}_{self.warp_bases}_{self.block_bases}_{self.shape}DLL"
 
 
+@dataclass(frozen=True)
+class DotOperandLayout(DistributedLayout):
+    """
+    Represents a layout for a dot operand.
+
+    Args:
+        operand_index (int): 0 for LHS and 1 for RHS of the dot operation.
+        parent (DistributedLayout): The parent layout, representing the MMA.
+        k_width (int): Number of elements per 32-bits.
+    """
+    operand_index: int
+    parent: DistributedLayout
+    k_width: int
+
+    def __post_init__(self):
+        super().__setattr__("operand_index", _unwrap_if_constexpr(self.operand_index))
+        super().__setattr__("parent", _unwrap_if_constexpr(self.parent))
+        super().__setattr__("k_width", _unwrap_if_constexpr(self.k_width))
+
+    def _to_ir(self, builder):
+        return builder.get_dot_operand_layout(self.operand_index, self.parent._to_ir(builder), self.k_width)
+
+    def mangle(self) -> str:
+        return f"DO{self.operand_index}_{self.parent.mangle()}_{self.k_width}DO"
+
+
 @dataclass(frozen=True)
 class NVMMADistributedLayout(DistributedLayout):
     """
