@@ -346,6 +346,78 @@ def body(smem_ref, value_to_store):
346346 with self .assertRaisesRegex (NotImplementedError , "Conflicting transforms" ):
347347 mgpu .infer_transforms (self .module )
348348
def test_infer_transforms_for_slice_smem_op_derives_from_user(self):
  """Checks that a `SliceSMEMOp` gets its output transforms from its user.

  Builds a function in which a 64x64 bf16 workgroup-memory slice is read by a
  single `vector.load` annotated with the WGMMA layout, runs
  `mgpu.infer_transforms` on the module, and then verifies that the slice op
  has no input transforms and exactly one output transform list: a (8, 64)
  tile followed by a 128 swizzle.
  """
  dims = (64, 64)
  dtype = ir.BF16Type.get()
  i32 = ir.IntegerType.get_signless(32)
  workgroup = ir.Attribute.parse("#gpu.address_space<workgroup>")
  # Ops created while the function body is traced; filled in by `body`.
  ops = {}

  with ir.InsertionPoint(self.module.body):

    @func.FuncOp.from_py_func(i32)
    def body(offset):
      smem_ref_ty = ir.MemRefType.get(dims, dtype, memory_space=workgroup)
      ops["slice"] = mgpu.dialect.SliceSMEMOp(smem_ref_ty, offset)
      zero = arith.constant(i32, 0)
      # One zero index per dimension of the loaded tile.
      indices = [zero for _ in dims]
      vector_ty = ir.VectorType.get(dims, dtype)
      ops["load"] = vector.LoadOp(vector_ty, ops["slice"].result, indices)

  # Only the load is annotated; the slice op's transforms are left for the
  # inference pass to fill in.
  wgmma_layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)
  ops["load"].attributes["out_layouts"] = ir.ArrayAttr.get([wgmma_layout])

  mgpu.infer_transforms(self.module)

  tile = mgpu.dialect.TileTransformAttr.get((8, 64))
  swizzle = mgpu.dialect.SwizzleTransformAttr.get(128)
  expected_transforms = ir.ArrayAttr.get([tile, swizzle])

  slice_op = ops["slice"]
  self.assertEmpty(inference_utils.in_transforms(slice_op))
  self.assertSequenceEqual(
      inference_utils.out_transforms(slice_op), [expected_transforms]
  )
384+
def test_infer_transforms_for_slice_smem_op_raises_on_mismatches(self):
  """Checks that differently annotated users of one slice are rejected.

  A single `SliceSMEMOp` result is read by two `vector.load` ops. The first
  is annotated with the WGMMA layout; the second with a strided fragment
  layout plus an explicit transpose in its `in_transforms`.
  `mgpu.infer_transforms` is expected to raise `NotImplementedError` with a
  message matching "Conflicting transforms".
  """
  dims = (64, 64)
  dtype = ir.BF16Type.get()
  i32 = ir.IntegerType.get_signless(32)
  workgroup = ir.Attribute.parse("#gpu.address_space<workgroup>")
  # Ops created while the function body is traced; filled in by `body`.
  ops = {}

  with ir.InsertionPoint(self.module.body):

    @func.FuncOp.from_py_func(i32)
    def body(offset):
      smem_ref_ty = ir.MemRefType.get(dims, dtype, memory_space=workgroup)
      ops["slice"] = mgpu.dialect.SliceSMEMOp(smem_ref_ty, offset)
      zero = arith.constant(i32, 0)
      # One zero index per dimension of the loaded tile.
      indices = [zero for _ in dims]
      # Two loads from the same slice, created one after the other.
      ops["loads"] = [
          vector.LoadOp(
              ir.VectorType.get(dims, dtype), ops["slice"].result, indices
          )
          for _ in range(2)
      ]

  wgmma_load, strided_load = ops["loads"]

  wgmma_layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)
  wgmma_load.attributes["out_layouts"] = ir.ArrayAttr.get([wgmma_layout])

  strided_layout = layouts_lib.to_layout_attr(
      fa.WGStridedFragLayout(dims, vec_size=4)
  )
  strided_load.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout])
  # The second load additionally carries an explicit transpose transform on
  # its memref operand.
  transpose = mgpu.dialect.TransposeTransformAttr.get((1, 0))
  strided_load.attributes["in_transforms"] = ir.ArrayAttr.get(
      [ir.ArrayAttr.get([transpose])]
  )

  with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
    mgpu.infer_transforms(self.module)
420+
349421
# When the file is executed directly, hand control to absltest and have it
# discover tests through the JAX test loader.
if __name__ == "__main__":
  parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
0 commit comments