@@ -233,6 +233,15 @@ def type(self):
233
233
return constexpr_type (self )
234
234
235
235
236
+ def _get_shape_per_cta (shape , cta_split_num ):
237
+ shape_per_cta = shape
238
+ if cta_split_num is not None :
239
+ assert len (cta_split_num ) == len (shape )
240
+ for dim in range (len (shape_per_cta )):
241
+ shape_per_cta [dim ] /= cta_split_num [dim ]
242
+ return shape_per_cta
243
+
244
+
236
245
@dataclass (frozen = True )
237
246
class NVMMASharedLayout (SharedLayout ):
238
247
"""
@@ -286,6 +295,47 @@ def _to_ir(self, builder):
286
295
self .cta_order ,
287
296
)
288
297
298
@staticmethod
def get_default_for(block_shape, dtype, transposed=False, fp4_padded=False, ctas_per_cga=None, cta_split_num=None,
                    cta_order=None):
    """Returns an NVMMASharedLayout with default swizzling for a given shape.

    This picks the largest swizzle pattern compatible with the shape, which
    allows emitting the fewest TMA or MMA messages.

    Args:
        block_shape: Sequence of dimension sizes of the block. It is not
            modified by this call.
        dtype: Element type; only its ``primitive_bitwidth`` attribute is read.
        transposed: When true, the first dimension is rotated to the end
            before the contiguous (last) dimension is inspected. Also
            forwarded to the returned layout.
        fp4_padded: When true, the contiguous dimension size is doubled
            before its byte width is computed. Also forwarded to the
            returned layout.
        ctas_per_cga: Forwarded unchanged to the returned layout.
        cta_split_num: Optional per-dimension divisors applied to
            ``block_shape`` to get the per-CTA shape. Also forwarded to the
            returned layout.
        cta_order: Forwarded unchanged to the returned layout.

    Returns:
        An NVMMASharedLayout whose ``swizzle_byte_width`` is 128, 64, 32, or 0
        (0 when no candidate width fits, when the rank is below 2, or when
        the product of the outer dimensions is below 8).
    """
    rank = len(block_shape)
    # Hand the helper a private copy so the per-CTA computation can never
    # write into the caller's block_shape, and so tuples are accepted too.
    shape_per_cta = _get_shape_per_cta(list(block_shape), cta_split_num)
    if transposed:
        shape_per_cta = shape_per_cta[1:] + shape_per_cta[:1]

    packing_factor = 2 if fp4_padded else 1
    contig_dim_size = shape_per_cta[-1] * packing_factor
    contig_dim_bytes = contig_dim_size * dtype.primitive_bitwidth // 8

    # Product of every dimension except the contiguous one.
    flatten_outer_dim = 1
    for size in shape_per_cta[:-1]:
        flatten_outer_dim *= size

    swizzle_byte_width = 0
    if rank >= 2 and flatten_outer_dim >= 8:
        # Try the widest pattern first; the contiguous dimension must span a
        # whole number of swizzle rows of that width.
        for candidate in (128, 64, 32):
            if contig_dim_bytes >= candidate and contig_dim_bytes % candidate == 0:
                swizzle_byte_width = candidate
                break

    return NVMMASharedLayout(
        swizzle_byte_width=swizzle_byte_width,
        element_bitwidth=dtype.primitive_bitwidth,
        rank=rank,
        transposed=transposed,
        fp4_padded=fp4_padded,
        ctas_per_cga=ctas_per_cga,
        cta_split_num=cta_split_num,
        cta_order=cta_order,
    )
338
+
289
339
# Returns a string identifying this layout, built from swizzle_byte_width,
# element_bitwidth, transposed and fp4_padded, between an "NVMMA_" prefix and
# an "_NVMMA" suffix.
# NOTE(review): the whitespace inside the f-string below may be an artifact of
# how this diff was captured — confirm against the real file before relying on
# the exact output text.
def mangle (self ) -> str :
290
340
return f"NVMMA_{ self .swizzle_byte_width } _{ self .element_bitwidth } _{ self .transposed } _{ self .fp4_padded } _NVMMA"
291
341
0 commit comments