diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 6d746d73b1be..a7fbd4897325 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -16,9 +16,11 @@ # under the License. """IRBuilder for TIR""" +import contextlib import functools import inspect import sys +import threading from numbers import Integral from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -87,6 +89,35 @@ # pylint: enable=unused-import +_block_name_suffix = threading.local() + + +def _get_block_name_suffix() -> str: + """Get the current block name suffix for macro expansion.""" + return getattr(_block_name_suffix, "value", "") + + +@contextlib.contextmanager +def block_name_suffix_context(block_suffix: str): + """Context manager to set block name suffix during macro expansion. + + Parameters + ---------- + block_suffix : str + The suffix to append to block names (e.g., "_1", "_2"). + + Yields + ------ + None + """ + old_suffix = getattr(_block_name_suffix, "value", "") + _block_name_suffix.value = block_suffix + try: + yield + finally: + _block_name_suffix.value = old_suffix + + def buffer( shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], dtype: str = "float32", @@ -352,6 +383,9 @@ def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame: res : frame.BlockFrame The BlockFrame. """ + block_suffix = _get_block_name_suffix() + if block_suffix and name: + name = name + block_suffix return _ffi_api.Block(name, no_realize) # type: ignore[attr-defined] # pylint: disable=no-member @@ -2107,6 +2141,7 @@ def wrapped(*args, **kwargs): "func_ret", "match_buffer", "block", + "block_name_suffix_context", "init", "where", "reads", diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index c7d5dc756b32..bcac49733d00 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -21,7 +21,7 @@ from tvm.ir.base import deprecated from tvm.tir import Buffer, PrimFunc -from ...ir_builder.tir import buffer, ptr +from ...ir_builder.tir import block_name_suffix_context, buffer, ptr from .._core import parse, scan_macro, utils from ..core.parser import Parser, ScriptMacro @@ -90,11 +90,25 @@ def decorator_wrapper(func): class TIRMacro(ScriptMacro): - """Specialization of the ScriptMacro class for TIR.""" + """Specialization of the ScriptMacro class for TIR. + + Attributes + ---------- + call_count : int + Counter for the number of times this macro has been invoked. + Used to generate unique block name suffixes. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.call_count = 0 def parse_macro(self, parser: Parser) -> None: macro_def = self.get_macro_def() - parser.visit_body(macro_def.body) + suffix = f"_{self.call_count}" if self.call_count > 0 else "" + self.call_count += 1 + with block_name_suffix_context(suffix): + parser.visit_body(macro_def.body) def macro(*args, hygienic: bool = True) -> Callable: