Skip to content

Commit ec7f59f

Browse files
authored
[TVMScript] Add test for TIR macro block name suffix handling (#18504)
## How add missing tests for #18465
1 parent 161049e commit ec7f59f

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tests/python/tvmscript/test_tvmscript_parser_tir.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,38 @@ def expected() -> None:
638638
tvm.ir.assert_structural_equal(func, expected)
639639

640640

641+
def test_tir_macro_block_name_suffix():
642+
@T.macro
643+
def operation(A, idx):
644+
with T.block("op"):
645+
v = T.axis.remap("S", [idx])
646+
A[v] = A[v] * T.float32(2)
647+
648+
@T.prim_func(private=True)
649+
def func_w_macro(a: T.handle) -> None:
650+
A = T.match_buffer(a, [10])
651+
for i in T.serial(0, 10):
652+
operation(A, i)
653+
operation(A, i)
654+
operation(A, i)
655+
656+
@T.prim_func(private=True)
657+
def expected(a: T.handle) -> None:
658+
A = T.match_buffer(a, [10])
659+
for i in T.serial(0, 10):
660+
with T.block("op"):
661+
v = T.axis.remap("S", [i])
662+
A[v] = A[v] * T.float32(2)
663+
with T.block("op_1"):
664+
v = T.axis.remap("S", [i])
665+
A[v] = A[v] * T.float32(2)
666+
with T.block("op_2"):
667+
v = T.axis.remap("S", [i])
668+
A[v] = A[v] * T.float32(2)
669+
670+
tvm.ir.assert_structural_equal(func_w_macro, expected)
671+
672+
641673
def test_ifexp():
642674
@T.prim_func(private=True)
643675
def func(A: T.buffer((128, 128), "float32")):

0 commit comments

Comments
 (0)