diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index 3b84e919c8bd..cc285e9835de 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -638,6 +638,38 @@ def expected() -> None: tvm.ir.assert_structural_equal(func, expected) +def test_tir_macro_block_name_suffix(): + @T.macro + def operation(A, idx): + with T.block("op"): + v = T.axis.remap("S", [idx]) + A[v] = A[v] * T.float32(2) + + @T.prim_func(private=True) + def func_w_macro(a: T.handle) -> None: + A = T.match_buffer(a, [10]) + for i in T.serial(0, 10): + operation(A, i) + operation(A, i) + operation(A, i) + + @T.prim_func(private=True) + def expected(a: T.handle) -> None: + A = T.match_buffer(a, [10]) + for i in T.serial(0, 10): + with T.block("op"): + v = T.axis.remap("S", [i]) + A[v] = A[v] * T.float32(2) + with T.block("op_1"): + v = T.axis.remap("S", [i]) + A[v] = A[v] * T.float32(2) + with T.block("op_2"): + v = T.axis.remap("S", [i]) + A[v] = A[v] * T.float32(2) + + tvm.ir.assert_structural_equal(func_w_macro, expected) + + def test_ifexp(): @T.prim_func(private=True) def func(A: T.buffer((128, 128), "float32")):