Skip to content

Commit ced7181

Browse files
authored
[TVMScript] Add block name suffix management for TIR macros (#18465)
## Related Issue closes #18344 ## Why When a `T.macro` containing a block was called multiple times in a TIR function, all expanded blocks had the same name, causing a "Duplicated block name" error in meta_schedule. ## How Implemented automatic block name suffixing during macro expansion
1 parent 9826150 commit ced7181

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
# under the License.
1717
"""IRBuilder for TIR"""
1818

19+
import contextlib
1920
import functools
2021
import inspect
2122
import sys
23+
import threading
2224
from numbers import Integral
2325
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2426

@@ -87,6 +89,35 @@
8789
# pylint: enable=unused-import
8890

8991

92+
_block_name_suffix = threading.local()
93+
94+
95+
def _get_block_name_suffix() -> str:
96+
"""Get the current block name suffix for macro expansion."""
97+
return getattr(_block_name_suffix, "value", "")
98+
99+
100+
@contextlib.contextmanager
101+
def block_name_suffix_context(block_suffix: str):
102+
"""Context manager to set block name suffix during macro expansion.
103+
104+
Parameters
105+
----------
106+
block_suffix : str
107+
The suffix to append to block names (e.g., "_1", "_2").
108+
109+
Yields
110+
------
111+
None
112+
"""
113+
old_suffix = getattr(_block_name_suffix, "value", "")
114+
_block_name_suffix.value = block_suffix
115+
try:
116+
yield
117+
finally:
118+
_block_name_suffix.value = old_suffix
119+
120+
90121
def buffer(
91122
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
92123
dtype: str = "float32",
@@ -352,6 +383,9 @@ def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
352383
res : frame.BlockFrame
353384
The BlockFrame.
354385
"""
386+
block_suffix = _get_block_name_suffix()
387+
if block_suffix and name:
388+
name = name + block_suffix
355389
return _ffi_api.Block(name, no_realize) # type: ignore[attr-defined] # pylint: disable=no-member
356390

357391

@@ -2135,6 +2169,7 @@ def wrapped(*args, **kwargs):
21352169
"func_ret",
21362170
"match_buffer",
21372171
"block",
2172+
"block_name_suffix_context",
21382173
"init",
21392174
"where",
21402175
"reads",

python/tvm/script/parser/tir/entry.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tvm.ir.base import deprecated
2222
from tvm.tir import Buffer, PrimFunc
2323

24-
from ...ir_builder.tir import buffer, ptr
24+
from ...ir_builder.tir import block_name_suffix_context, buffer, ptr
2525
from .._core import parse, scan_macro, utils
2626
from ..core.parser import Parser, ScriptMacro
2727

@@ -90,11 +90,25 @@ def decorator_wrapper(func):
9090

9191

9292
class TIRMacro(ScriptMacro):
93-
"""Specialization of the ScriptMacro class for TIR."""
93+
"""Specialization of the ScriptMacro class for TIR.
94+
95+
Attributes
96+
----------
97+
call_count : int
98+
Counter for the number of times this macro has been invoked.
99+
Used to generate unique block name suffixes.
100+
"""
101+
102+
def __init__(self, *args, **kwargs):
103+
super().__init__(*args, **kwargs)
104+
self.call_count = 0
94105

95106
def parse_macro(self, parser: Parser) -> None:
96107
macro_def = self.get_macro_def()
97-
parser.visit_body(macro_def.body)
108+
suffix = f"_{self.call_count}" if self.call_count > 0 else ""
109+
self.call_count += 1
110+
with block_name_suffix_context(suffix):
111+
parser.visit_body(macro_def.body)
98112

99113

100114
def macro(*args, hygienic: bool = True) -> Callable:

0 commit comments

Comments
 (0)