Skip to content

Commit b965f14

Browse files
committed
Add block name suffix management for TIR macros
1 parent 49973d1 commit b965f14

File tree

2 files changed

+56
-3
lines changed

2 files changed

+56
-3
lines changed

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
# under the License.
1717
"""IRBuilder for TIR"""
1818

19+
import contextlib
1920
import functools
2021
import inspect
2122
import sys
23+
import threading
2224
from numbers import Integral
2325
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2426

@@ -87,6 +89,36 @@
8789
# pylint: enable=unused-import
8890

8991

92+
# Thread-local storage for block name suffix in macro expansion
93+
_block_name_suffix = threading.local()
94+
95+
96+
def _get_block_name_suffix() -> str:
97+
"""Get the current block name suffix for macro expansion."""
98+
return getattr(_block_name_suffix, "value", "")
99+
100+
101+
@contextlib.contextmanager
102+
def block_name_suffix_context(suffix: str):
103+
"""Context manager to set block name suffix during macro expansion.
104+
105+
Parameters
106+
----------
107+
suffix : str
108+
The suffix to append to block names (e.g., "_1", "_2").
109+
110+
Yields
111+
------
112+
None
113+
"""
114+
old_suffix = getattr(_block_name_suffix, "value", "")
115+
_block_name_suffix.value = suffix
116+
try:
117+
yield
118+
finally:
119+
_block_name_suffix.value = old_suffix
120+
121+
90122
def buffer(
91123
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
92124
dtype: str = "float32",
@@ -352,6 +384,10 @@ def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
352384
res : frame.BlockFrame
353385
The BlockFrame.
354386
"""
387+
# Apply suffix from macro expansion context if present
388+
suffix = _get_block_name_suffix()
389+
if suffix and name:
390+
name = name + suffix
355391
return _ffi_api.Block(name, no_realize) # type: ignore[attr-defined] # pylint: disable=no-member
356392

357393

@@ -2107,6 +2143,7 @@ def wrapped(*args, **kwargs):
21072143
"func_ret",
21082144
"match_buffer",
21092145
"block",
2146+
"block_name_suffix_context",
21102147
"init",
21112148
"where",
21122149
"reads",

python/tvm/script/parser/tir/entry.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tvm.ir.base import deprecated
2222
from tvm.tir import Buffer, PrimFunc
2323

24-
from ...ir_builder.tir import buffer, ptr
24+
from ...ir_builder.tir import block_name_suffix_context, buffer, ptr
2525
from .._core import parse, scan_macro, utils
2626
from ..core.parser import Parser, ScriptMacro
2727

@@ -90,11 +90,27 @@ def decorator_wrapper(func):
9090

9191

9292
class TIRMacro(ScriptMacro):
93-
"""Specialization of the ScriptMacro class for TIR."""
93+
"""Specialization of the ScriptMacro class for TIR.
94+
95+
Attributes
96+
----------
97+
call_count : int
98+
Counter for the number of times this macro has been invoked.
99+
Used to generate unique block name suffixes.
100+
"""
101+
102+
def __init__(self, *args, **kwargs):
103+
super().__init__(*args, **kwargs)
104+
self.call_count = 0
94105

95106
def parse_macro(self, parser: Parser) -> None:
96107
macro_def = self.get_macro_def()
97-
parser.visit_body(macro_def.body)
108+
# Apply block name suffix to avoid duplicate block names
109+
110+
suffix = f"_{self.call_count}" if self.call_count > 0 else ""
111+
self.call_count += 1
112+
with block_name_suffix_context(suffix):
113+
parser.visit_body(macro_def.body)
98114

99115

100116
def macro(*args, hygienic: bool = True) -> Callable:

0 commit comments

Comments
 (0)