
Commit 40c9b1c

[Gluon] Fix linear layout MLIR->Python; fix CTA layout equality (#7230)
This PR makes layouts always materialize their CTA layouts, so that `BlockedLayout([1], [32], [4], [0]) == BlockedLayout([1], [32], [4], [0], [1], [1], [0])`. This is especially important because layouts raised from MLIR always have CTA layouts attached.
1 parent 3c893cf commit 40c9b1c
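
For a concrete picture of what "materialize" means here: the Gluon layout classes appear to be frozen dataclasses whose optional CTA fields previously stayed `None` until lowering, so the two spellings above compared unequal. A minimal standalone sketch of the fix (field names follow the diff; the real class lives in `python/triton/experimental/gluon/language/_layouts.py`):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass(frozen=True)
class BlockedLayout:
    size_per_thread: List[int]
    threads_per_warp: List[int]
    warps_per_cta: List[int]
    order: List[int]
    ctas_per_cga: Optional[List[int]] = None
    cta_split_num: Optional[List[int]] = None
    cta_order: Optional[List[int]] = None

    def __post_init__(self):
        # Materialize the default CTA layout in place, so the auto-generated
        # dataclass __eq__ compares identical fields for both spellings.
        # object.__setattr__ is needed because the dataclass is frozen.
        rank = len(self.size_per_thread)
        object.__setattr__(self, "ctas_per_cga", self.ctas_per_cga or [1] * rank)
        object.__setattr__(self, "cta_split_num", self.cta_split_num or [1] * rank)
        object.__setattr__(self, "cta_order", self.cta_order or list(reversed(range(rank))))

# Both spellings now compare equal:
assert BlockedLayout([1], [32], [4], [0]) == BlockedLayout([1], [32], [4], [0], [1], [1], [0])
```

With the defaults materialized eagerly, both spellings carry identical field values, including layouts raised from MLIR, which always arrive with explicit CTA fields.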

File tree

- python/src/gluon_ir.cc
- python/test/gluon/test_frontend.py
- python/triton/experimental/gluon/language/_layouts.py
- python/triton/language/__init__.py

4 files changed: +51 -45 lines changed

python/src/gluon_ir.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ py::object layoutToGluon(Attribute layout) {
130130
return layouts.DistributedLinearLayout(
131131
ll.getBases().lookup(kReg), ll.getBases().lookup(kLane),
132132
ll.getBases().lookup(kWarp), ll.getBases().lookup(kBlock),
133-
ll.getOutDimSizes());
133+
toStdVector(ArrayRef(llvm::to_vector(ll.getOutDimSizes()))));
134134
} else if (auto nvmma = dyn_cast<ttg::NVMMASharedEncodingAttr>(layout)) {
135135
auto ctaLayout = nvmma.getCTALayout();
136136
return layouts.NVMMASharedLayout(
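
Presumably `getOutDimSizes()` returns a lazy range rather than a container here; materializing it with `llvm::to_vector` and converting through `toStdVector` hands the binding layer a plain `std::vector` it can turn into a Python list, which is what the Python-side `DistributedLinearLayout` constructor expects.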

python/test/gluon/test_frontend.py

Lines changed: 17 additions & 8 deletions
@@ -246,8 +246,7 @@ def shared_memory_cast_kernel():
     layout_a: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=8,
                                                       rank=2)
     layout_T: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=True, element_bitwidth=8,
-                                                      rank=2, ctas_per_cga=[1, 1], cta_split_num=[1,
-                                                                                                  1], cta_order=[1, 0])
+                                                      rank=2)
     smem = ttgl.allocate_shared_memory(ttgl.int8, [2, 256, 128], layout_a)
     perm = smem.index(0).permute((1, 0))
     ttgl.static_assert(perm.type.layout == layout_T)

@@ -613,10 +612,10 @@ def kernel():
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable>
-    tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
+    tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0_1_1_1_1_1_0_SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
     tt.return
   }
-  tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
+  tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0_1_1_1_1_1_0_SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
     tt.return
   }
 }
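
Note the knock-on effect on the mangled names in the expected IR above: because `SwizzledSharedLayout` now always carries realized CTA fields, its repr mangles to `ctas_per_cga=_1, 1_` and so on instead of `None`, and the realized CTA dims (`_1_1_1_1_1_0_`) show up in the memdesc portion of the name in place of the empty `____` run.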
@@ -855,7 +854,7 @@ def test_tensor_permute():
     a = ttgl.full([32, 16], 0, ttgl.int32, layout=layout)
     # CHECK: tt.trans{{.*}} : tensor<32x16xi32, [[BLOCKED]]> -> tensor<16x32xi32, [[BLOCKED1]]>
     res = ttgl.permute(a, [1, 0])
-    permuted_layout: ttgl.constexpr = ttgl.BlockedLayout([2, 1], [8, 4], [1, 4], [0, 1], [1, 1], [1, 1], [1, 0])
+    permuted_layout: ttgl.constexpr = ttgl.BlockedLayout([2, 1], [8, 4], [1, 4], [0, 1])
     ttgl.static_assert(permuted_layout == res.type.layout)


@@ -869,7 +868,7 @@ def test_split_join():
     b = ttgl.full([128], 2, ttgl.int32, layout)
     # CHECK: tt.join {{.*}} : tensor<128xi32, [[BLOCKED]]> -> tensor<128x2xi32, [[BLOCKED1]]>
     res = ttgl.join(a, b)
-    expect_layout: ttgl.constexpr = ttgl.BlockedLayout([2, 2], [32, 1], [4, 1], [1, 0], [1, 1], [1, 1], [1, 0])
+    expect_layout: ttgl.constexpr = ttgl.BlockedLayout([2, 2], [32, 1], [4, 1], [1, 0])
     ttgl.static_assert(res.type.layout == expect_layout)

     # CHECK: tt.split {{.*}} : tensor<128x2xi32, [[BLOCKED1]]> -> tensor<128xi32, [[BLOCKED]]>
@@ -878,6 +877,17 @@ def test_split_join():
     ttgl.static_assert(d.type.layout == layout)


+@filecheck_test
+@gluon.jit
+def test_reshape_linear_layout():
+    # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+    # CHECK: [[LINEAR:#.*]] = #ttg.linear
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [32, 1], [4, 1], [0, 1])
+    x = ttgl.full([128, 1], 1, ttgl.int32, layout=layout)
+    # CHECK: tt.reshape %{{.*}} : tensor<128x1xi32, [[BLOCKED]]> -> tensor<128xi32, [[LINEAR]]>
+    x.reshape([128])
+
+
 @filecheck_test
 @gluon.jit
 def test_tensor_reshape():
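
The new `test_reshape_linear_layout` above exercises the other half of this PR: reshaping a `128x1` blocked tensor to `128` yields a `#ttg.linear` encoding, which must then be raised from MLIR back into a Python `DistributedLinearLayout` through the `gluon_ir.cc` path fixed earlier.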
@@ -887,8 +897,7 @@ def test_tensor_reshape():
     a = ttgl.full([256], 1, ttgl.int32, layout)
     # CHECK: tt.reshape {{.*}} : tensor<256xi32, [[BLOCKED]]> -> tensor<8x4x8xi32, [[BLOCKED1]]>
     v = a.reshape([8, 4, 8])
-    expect_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1, 2], [2, 4, 4], [4, 1, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1],
-                                                       [2, 1, 0])
+    expect_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1, 2], [2, 4, 4], [4, 1, 1], [2, 1, 0])
     ttgl.static_assert(v.type.layout == expect_layout)

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 31 additions & 34 deletions
@@ -11,11 +11,13 @@
 ]


-def _realize_cta_layout(rank, ctas_per_cga, cta_split_num, cta_order):
-    ctas_per_cga = ctas_per_cga or [1] * rank
-    cta_split_num = cta_split_num or [1] * rank
-    cta_order = cta_order or list(reversed(range(rank)))
-    return ctas_per_cga, cta_split_num, cta_order
+def _realize_cta_layout(layout, rank):
+    ctas_per_cga = layout.ctas_per_cga or [1] * rank
+    cta_split_num = layout.cta_split_num or [1] * rank
+    cta_order = layout.cta_order or list(reversed(range(rank)))
+    object.__setattr__(layout, "ctas_per_cga", ctas_per_cga)
+    object.__setattr__(layout, "cta_split_num", cta_split_num)
+    object.__setattr__(layout, "cta_order", cta_order)


 class DistributedLayout:
@@ -42,25 +44,23 @@ def __post_init__(self):
         super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))

         rank = len(self.size_per_thread)
+        _realize_cta_layout(self, rank)
         assert len(self.threads_per_warp) == rank
         assert len(self.warps_per_cta) == rank
         assert len(self.order) == rank
-        assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank
-        assert self.cta_split_num is None or len(self.cta_split_num) == rank
-        assert self.cta_order is None or len(self.cta_order) == rank
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank

     def _to_ir(self, builder):
-        rank = len(self.size_per_thread)
-        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(rank, self.ctas_per_cga, self.cta_split_num,
-                                                                     self.cta_order)
         return builder.get_blocked_layout(
             self.size_per_thread,
             self.threads_per_warp,
             self.warps_per_cta,
             self.order,
-            ctas_per_cga,
-            cta_split_num,
-            cta_order,
+            self.ctas_per_cga,
+            self.cta_split_num,
+            self.cta_order,
         )

     def mangle(self) -> str:
@@ -161,21 +161,20 @@ def __post_init__(self):
         assert self.element_bitwidth in [8, 16, 32, 64]
         assert self.swizzle_byte_width in [0, 32, 64, 128]
         rank = self.rank
-        assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank
-        assert self.cta_split_num is None or len(self.cta_split_num) == rank
-        assert self.cta_order is None or len(self.cta_order) == rank
+        _realize_cta_layout(self, rank)
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank

     def _to_ir(self, builder):
-        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(self.rank, self.ctas_per_cga, self.cta_split_num,
-                                                                     self.cta_order)
         return builder.get_nvmma_shared_layout(
             self.swizzle_byte_width,
             self.element_bitwidth,
             self.transposed,
             self.fp4_padded,
-            ctas_per_cga,
-            cta_split_num,
-            cta_order,
+            self.ctas_per_cga,
+            self.cta_split_num,
+            self.cta_order,
         )

     def mangle(self) -> str:
@@ -202,22 +201,20 @@ def __post_init__(self):
         super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))

         rank = len(self.order)
-        assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank
-        assert self.cta_split_num is None or len(self.cta_split_num) == rank
-        assert self.cta_order is None or len(self.cta_order) == rank
+        _realize_cta_layout(self, rank)
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank

     def _to_ir(self, builder):
-        rank = len(self.order)
-        ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(rank, self.ctas_per_cga, self.cta_split_num,
-                                                                     self.cta_order)
         return builder.get_swizzled_shared_layout(
-            _unwrap_if_constexpr(self.vec),
-            _unwrap_if_constexpr(self.per_phase),
-            _unwrap_if_constexpr(self.max_phase),
+            self.vec,
+            self.per_phase,
+            self.max_phase,
             self.order,
-            ctas_per_cga,
-            cta_split_num,
-            cta_order,
+            self.ctas_per_cga,
+            self.cta_split_num,
+            self.cta_order,
         )

     def mangle(self) -> str:
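
Design-wise, realizing the CTA fields once in `__post_init__` rather than at `_to_ir` time means equality and mangling observe the same canonical values, and `_to_ir` can pass `self.ctas_per_cga` and friends straight through. The `object.__setattr__` calls are needed because these layout classes are presumably frozen dataclasses, as the surrounding `super().__setattr__` usage suggests.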

python/triton/language/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -287,8 +287,8 @@ def str_to_ty(name):

     if name.startswith("tensordesc"):
         inner = name.split("<")[1].rstrip(">")
-        dtype, rest = inner.split("[", maxsplit=2)
-        block_shape, rest = rest.split("]", maxsplit=2)
+        dtype, rest = inner.split("[", maxsplit=1)
+        block_shape, rest = rest.split("]", maxsplit=1)
         block_shape = [int(s.strip()) for s in block_shape.rstrip("]").split(",")]
         layout = rest.lstrip(",")
         is_gluon = len(layout)
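
The `maxsplit` fix matters once the layout portion of a tensordesc type string contains brackets of its own, which is what mangled Gluon layouts introduce. A minimal sketch; the descriptor payload below is made up for illustration and is not the exact mangling format:

```python
# Hypothetical tensordesc payload: dtype, block shape, then a layout that
# itself contains brackets (illustrative, not the real mangling format).
inner = "fp16[16, 16], [1, 0]"

# Old behavior: maxsplit=2 can yield three fragments when the layout also
# contains "[", so two-name tuple unpacking raises ValueError.
try:
    dtype, rest = inner.split("[", maxsplit=2)
except ValueError as exc:
    print("maxsplit=2 fails:", exc)  # too many values to unpack

# New behavior: maxsplit=1 always yields exactly two fragments.
dtype, rest = inner.split("[", maxsplit=1)        # "fp16", "16, 16], [1, 0]"
block_shape, rest = rest.split("]", maxsplit=1)   # "16, 16", ", [1, 0]"
block_shape = [int(s.strip()) for s in block_shape.split(",")]
layout = rest.lstrip(",")                         # " [1, 0]"
print(dtype, block_shape, layout)
```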
