Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# under the License.
"""IRBuilder for TIR"""

import contextlib
import functools
import inspect
import sys
import threading
from numbers import Integral
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -87,6 +89,35 @@
# pylint: enable=unused-import


_block_name_suffix = threading.local()


def _get_block_name_suffix() -> str:
"""Get the current block name suffix for macro expansion."""
return getattr(_block_name_suffix, "value", "")


@contextlib.contextmanager
def block_name_suffix_context(block_suffix: str):
"""Context manager to set block name suffix during macro expansion.

Parameters
----------
block_suffix : str
The suffix to append to block names (e.g., "_1", "_2").

Yields
------
None
"""
old_suffix = getattr(_block_name_suffix, "value", "")
_block_name_suffix.value = block_suffix
try:
yield
finally:
_block_name_suffix.value = old_suffix


def buffer(
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
dtype: str = "float32",
Expand Down Expand Up @@ -352,6 +383,9 @@ def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
res : frame.BlockFrame
The BlockFrame.
"""
block_suffix = _get_block_name_suffix()
if block_suffix and name:
name = name + block_suffix
return _ffi_api.Block(name, no_realize) # type: ignore[attr-defined] # pylint: disable=no-member


Expand Down Expand Up @@ -2107,6 +2141,7 @@ def wrapped(*args, **kwargs):
"func_ret",
"match_buffer",
"block",
"block_name_suffix_context",
"init",
"where",
"reads",
Expand Down
20 changes: 17 additions & 3 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc

from ...ir_builder.tir import buffer, ptr
from ...ir_builder.tir import block_name_suffix_context, buffer, ptr
from .._core import parse, scan_macro, utils
from ..core.parser import Parser, ScriptMacro

Expand Down Expand Up @@ -90,11 +90,25 @@ def decorator_wrapper(func):


class TIRMacro(ScriptMacro):
"""Specialization of the ScriptMacro class for TIR."""
"""Specialization of the ScriptMacro class for TIR.

Attributes
----------
call_count : int
Counter for the number of times this macro has been invoked.
Used to generate unique block name suffixes.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.call_count = 0

def parse_macro(self, parser: Parser) -> None:
macro_def = self.get_macro_def()
parser.visit_body(macro_def.body)
suffix = f"_{self.call_count}" if self.call_count > 0 else ""
self.call_count += 1
with block_name_suffix_context(suffix):
parser.visit_body(macro_def.body)


def macro(*args, hygienic: bool = True) -> Callable:
Expand Down
Loading