@@ -638,6 +638,38 @@ def expected() -> None:
638638 tvm .ir .assert_structural_equal (func , expected )
639639
640640
641+ def test_tir_macro_block_name_suffix ():
642+ @T .macro
643+ def operation (A , idx ):
644+ with T .block ("op" ):
645+ v = T .axis .remap ("S" , [idx ])
646+ A [v ] = A [v ] * T .float32 (2 )
647+
648+ @T .prim_func (private = True )
649+ def func_w_macro (a : T .handle ) -> None :
650+ A = T .match_buffer (a , [10 ])
651+ for i in T .serial (0 , 10 ):
652+ operation (A , i )
653+ operation (A , i )
654+ operation (A , i )
655+
656+ @T .prim_func (private = True )
657+ def expected (a : T .handle ) -> None :
658+ A = T .match_buffer (a , [10 ])
659+ for i in T .serial (0 , 10 ):
660+ with T .block ("op" ):
661+ v = T .axis .remap ("S" , [i ])
662+ A [v ] = A [v ] * T .float32 (2 )
663+ with T .block ("op_1" ):
664+ v = T .axis .remap ("S" , [i ])
665+ A [v ] = A [v ] * T .float32 (2 )
666+ with T .block ("op_2" ):
667+ v = T .axis .remap ("S" , [i ])
668+ A [v ] = A [v ] * T .float32 (2 )
669+
670+ tvm .ir .assert_structural_equal (func_w_macro , expected )
671+
672+
641673def test_ifexp ():
642674 @T .prim_func (private = True )
643675 def func (A : T .buffer ((128 , 128 ), "float32" )):
0 commit comments