
Commit b64e85b

[Gluon] Add NVMMASharedLayout constructor with default swizzle choice (#7534)
This mirrors the attribute builder here: https://github.com/triton-lang/triton/blob/1031dc78060fc5f63c3fbcdd04d495d2428bc862/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td#L480
1 parent 9164c06 commit b64e85b
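
For readers unfamiliar with the new entry point, here is a minimal usage sketch. The import path, the 128x64 block shape, and the fp16 dtype are illustrative assumptions; only gl.NVMMASharedLayout.get_default_for itself comes from this commit (see the tutorial change below).

# Hypothetical usage sketch; block shape, dtype, and import are assumptions, not from this commit.
from triton.experimental.gluon import language as gl

# Request the default-swizzled shared-memory layout for a 128x64 fp16 tile.
layout = gl.NVMMASharedLayout.get_default_for([128, 64], gl.float16)
# The contiguous dim is 64 fp16 elements = 128 bytes, so the constructor
# picks the 128-byte swizzle pattern (the largest one that fits).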

2 files changed: +52 -32 lines changed

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 50 additions & 0 deletions
@@ -233,6 +233,15 @@ def type(self):
         return constexpr_type(self)
 
 
+def _get_shape_per_cta(shape, cta_split_num):
+    shape_per_cta = shape
+    if cta_split_num is not None:
+        assert len(cta_split_num) == len(shape)
+        for dim in range(len(shape_per_cta)):
+            shape_per_cta[dim] /= cta_split_num[dim]
+    return shape_per_cta
+
+
 @dataclass(frozen=True)
 class NVMMASharedLayout(SharedLayout):
     """
@@ -286,6 +295,47 @@ def _to_ir(self, builder):
             self.cta_order,
         )
 
+    @staticmethod
+    def get_default_for(block_shape, dtype, transposed=False, fp4_padded=False, ctas_per_cga=None, cta_split_num=None,
+                        cta_order=None):
+        """Returns an NVMMASharedLayout with default swizzling for a given shape.
+
+        This picks the largest swizzle pattern compatible with the shape, which
+        allows emitting the fewest TMA or MMA messages.
+        """
+        packing_factor = 2 if fp4_padded else 1
+        shape_per_cta = _get_shape_per_cta(block_shape, cta_split_num)
+        rank = len(block_shape)
+        if transposed:
+            shape_per_cta = shape_per_cta[1:] + shape_per_cta[:1]
+        contig_dim_size = shape_per_cta[-1] * packing_factor
+        contig_dim_bytes = contig_dim_size * dtype.primitive_bitwidth // 8
+        if contig_dim_bytes >= 128 and contig_dim_bytes % 128 == 0:
+            swizzle_byte_width = 128
+        elif contig_dim_bytes >= 64 and contig_dim_bytes % 64 == 0:
+            swizzle_byte_width = 64
+        elif contig_dim_bytes >= 32 and contig_dim_bytes % 32 == 0:
+            swizzle_byte_width = 32
+        else:
+            swizzle_byte_width = 0
+
+        flatten_outer_dim = 1
+        for size in shape_per_cta[:-1]:
+            flatten_outer_dim *= size
+        if len(block_shape) < 2 or flatten_outer_dim < 8:
+            swizzle_byte_width = 0
+
+        return NVMMASharedLayout(
+            swizzle_byte_width=swizzle_byte_width,
+            element_bitwidth=dtype.primitive_bitwidth,
+            rank=rank,
+            transposed=transposed,
+            fp4_padded=fp4_padded,
+            ctas_per_cga=ctas_per_cga,
+            cta_split_num=cta_split_num,
+            cta_order=cta_order,
+        )
+
     def mangle(self) -> str:
         return f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}_NVMMA"

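To make the selection rule in the docstring above concrete, here is a simplified, standalone restatement of the swizzle choice. It ignores fp4 padding, transposition, and CTA splitting, and the helper name and example shapes are made up for illustration; it is not part of the commit.

# Simplified restatement of the swizzle-width choice, for illustration only.
# Ignores fp4 padding, transposition, and CTA splitting.
def default_swizzle_byte_width(block_shape, element_bitwidth):
    contig_dim_bytes = block_shape[-1] * element_bitwidth // 8
    # Pick the largest swizzle pattern that evenly divides the contiguous dim.
    for width in (128, 64, 32):
        if contig_dim_bytes >= width and contig_dim_bytes % width == 0:
            swizzle = width
            break
    else:
        swizzle = 0
    # Swizzling is disabled when the flattened outer dims have fewer than 8 rows.
    outer = 1
    for size in block_shape[:-1]:
        outer *= size
    if len(block_shape) < 2 or outer < 8:
        swizzle = 0
    return swizzle

assert default_swizzle_byte_width([128, 64], 16) == 128  # 64 * 2 B = 128 B rows
assert default_swizzle_byte_width([128, 32], 16) == 64   # 32 * 2 B = 64 B rows
assert default_swizzle_byte_width([4, 64], 16) == 0      # only 4 outer rows
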
python/tutorials/gluon/01-attention-forward.py

Lines changed: 2 additions & 32 deletions
@@ -68,36 +68,6 @@ def get_mma_instr_shape(shape, element_ty):
     return (m, n, k)
 
 
-@gl.constexpr_function
-def get_nvmma_layout(shape, element_ty, order=[1, 0], fp4_padded=False):
-    packing_factor = 2 if fp4_padded else 1
-
-    contig_dim_size = shape[order[0]] * packing_factor * element_ty.primitive_bitwidth // 8
-    if contig_dim_size >= 128 and contig_dim_size % 128 == 0:
-        swizzle_byte_width = 128
-    elif contig_dim_size >= 64 and contig_dim_size % 64 == 0:
-        swizzle_byte_width = 64
-    elif contig_dim_size >= 32 and contig_dim_size % 32 == 0:
-        swizzle_byte_width = 32
-    else:
-        swizzle_byte_width = 0
-
-    flatten_outer_dim = 1
-    for i in range(1, len(shape)):
-        flatten_outer_dim *= shape[order[i]]
-    if len(shape) < 2 or flatten_outer_dim < 8:
-        swizzle_byte_width = 0
-    transposed = order[0] == 0
-
-    return gl.NVMMASharedLayout(
-        swizzle_byte_width=swizzle_byte_width,
-        element_bitwidth=element_ty.primitive_bitwidth,
-        rank=len(shape),
-        transposed=transposed,
-        fp4_padded=fp4_padded,
-    )
-
-
 @gl.constexpr_function
 def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
     instr_shape = get_mma_instr_shape(shape, dtype)
@@ -995,8 +965,8 @@ def torch_dtype_to_triton(dtype):
 
 
 def make_tensor_desc(x, shape, strides, block_shape):
-    layout = get_nvmma_layout(block_shape, torch_dtype_to_triton(x.dtype))
-    return TensorDescriptor(x, shape=shape, strides=strides, block_shape=block_shape, layout=layout.value)
+    layout = gl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(x.dtype))
+    return TensorDescriptor(x, shape=shape, strides=strides, block_shape=block_shape, layout=layout)
 
 
 def attention_forward(q, k, v, causal, sm_scale):
