diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index b74e9954d9cf..d9d8819813ed 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -32,7 +32,16 @@ structural_hash, ) from .container import Array, Map -from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelaxExpr +from .expr import ( + BaseExpr, + GlobalVar, + PrimExpr, + PrimIntExpr, + PrimFloatExpr, + PrimLogicalExpr, + Range, + RelaxExpr, +) from .function import BaseFunc, CallingConv from .global_info import GlobalInfo, DummyGlobalInfo, VDevice from .module import IRModule diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 19abb6bd1eae..56045822e239 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -16,7 +16,7 @@ # under the License. """Common expressions data structures in the IR.""" from numbers import Number -from typing import Optional +from typing import Optional, Union import tvm import tvm_ffi @@ -44,6 +44,11 @@ class PrimExpr(BaseExpr): dtype: str +PrimIntExpr = Union[PrimExpr, int] +PrimFloatExpr = Union[PrimExpr, float] +PrimLogicalExpr = Union[PrimExpr, int, bool] + + @tvm_ffi.register_object("ir.RelaxExpr") class RelaxExpr(BaseExpr): """Base class of all non-primitive expressions.""" @@ -115,11 +120,11 @@ class Range(Node, Scriptable): Parameters ---------- - begin : PrimExpr + begin : PrimIntExpr The begin value of the range when end is None. Otherwise it is the length of the range. - end : Optional[PrimExpr] + end : Optional[PrimIntExpr] The end value of the range. span : Optional[Span] @@ -136,13 +141,13 @@ class Range(Node, Scriptable): span: Optional[Span] def __init__( - self, begin: PrimExpr, end: Optional[PrimExpr] = None, span: Optional[Span] = None + self, begin: PrimIntExpr, end: Optional[PrimIntExpr] = None, span: Optional[Span] = None ) -> None: self.__init_handle_by_constructor__(_ffi_api.Range, begin, end, span) @staticmethod def from_min_extent( - min_value: PrimExpr, extent: PrimExpr, span: Optional[Span] = None + min_value: PrimIntExpr, extent: PrimIntExpr, span: Optional[Span] = None ) -> "Range": """Construct a Range by min and extent. @@ -150,10 +155,10 @@ def from_min_extent( Parameters ---------- - min_value : PrimExpr + min_value : PrimIntExpr The minimum value of the range. - extent : PrimExpr + extent : PrimIntExpr The extent of the range. span : Optional[Span] diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index a08e66789fa3..7defbf1c3708 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -22,7 +22,7 @@ import sys import threading from numbers import Integral -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, overload # isort: off from typing_extensions import Literal @@ -39,7 +39,7 @@ # pylint: disable=unused-import from tvm.target.codegen import llvm_lookup_intrinsic_id -from tvm.tir import Buffer, BufferRegion, IndexMap, PrimExpr +from tvm.tir import Buffer, BufferRegion, IndexMap, PrimExpr, PrimIntExpr, PrimLogicalExpr from tvm.tir import op as _tir_op from tvm.tir import type_annotation @@ -119,34 +119,34 @@ def block_name_suffix_context(block_suffix: str): def buffer( - shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], + shape: Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...], PrimIntExpr], dtype: str = "float32", - data: Var = None, - strides: List[PrimExpr] = None, - elem_offset: PrimExpr = None, + data: Optional[Var] = None, + strides: Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...]]] = None, + elem_offset: Optional[PrimIntExpr] = None, scope: str = "global", align: int = 0, offset_factor: int = 0, buffer_type: str = "", - axis_separators: List[int] = None, + axis_separators: Optional[List[int]] = None, ) -> Buffer: """The buffer declaration function. Parameters ---------- - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + shape : Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...], PrimIntExpr] The type of the buffer prior to flattening. dtype : str The data type in the content of the buffer. - data : Var + data : Optional[Var] The pointer to the head of the data. - strides : List[PrimExpr] + strides : Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...]]] The strides of each dimension. - elem_offset : PrimExpr + elem_offset : Optional[PrimIntExpr] The offset in terms of number of dtype elements (including lanes). scope : str @@ -161,7 +161,7 @@ def buffer( buffer_type : str The buffer type. - axis_separators : List[int] + axis_separators : Optional[List[int]] The separators between input axes when generating flattened output axes. Returns @@ -271,16 +271,16 @@ def func_ret(ret_type: Type) -> Type: def match_buffer( param: Union[Var, BufferLoad, BufferRegion], - shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] = None, + shape: Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...], PrimIntExpr]] = None, dtype: str = "float32", - data: Var = None, - strides: List[PrimExpr] = None, - elem_offset: PrimExpr = None, + data: Optional[Var] = None, + strides: Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...]]] = None, + elem_offset: Optional[PrimIntExpr] = None, scope: str = "global", align: int = -1, offset_factor: int = 0, buffer_type: str = "default", - axis_separators: List[int] = None, + axis_separators: Optional[List[int]] = None, ) -> Buffer: """The buffer match function. @@ -305,19 +305,19 @@ def match_buffer( param : Union[Var, BufferLoad, BufferRegion] The parameter of the PrimFunc to match. - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + shape : Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...], PrimIntExpr]] The type of the buffer prior to flattening. dtype : str The data type in the content of the buffer. - data : Var + data : Optional[Var] The pointer to the head of the data. - strides : List[PrimExpr] + strides : Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...]]] The strides of each dimension. - elem_offset : PrimExpr + elem_offset : Optional[PrimIntExpr] The offset in terms of number of dtype elements (including lanes). scope : str @@ -332,7 +332,7 @@ def match_buffer( buffer_type : str The buffer type. - axis_separators : List[int] + axis_separators : Optional[List[int]] The separators between input axes when generating flattened output axes. Returns @@ -346,7 +346,7 @@ def match_buffer( shape = [region.extent for region in param.region] else: raise ValueError("Shape must be specified when binding input param") - shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + shape = (shape,) if isinstance(shape, (PrimExpr, Integral, int)) else shape if strides is not None: idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else "int32" strides = [Var(s, idx_dtype) if isinstance(s, str) else s for s in strides] @@ -400,12 +400,12 @@ def init() -> frame.BlockInitFrame: return _ffi_api.Init() # type: ignore[attr-defined] # pylint: disable=no-member -def where(predicate: Union[PrimExpr, int]) -> None: +def where(predicate: PrimLogicalExpr) -> None: """The block predicate statement. Parameters ---------- - predicate : Union[PrimExpr, Literal[0, 1]] + predicate : PrimLogicalExpr The predicate condition. """ if isinstance(predicate, bool): @@ -418,27 +418,27 @@ def where(predicate: Union[PrimExpr, int]) -> None: _ffi_api.Where(predicate) # type: ignore[attr-defined] # pylint: disable=no-member -def reads(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: +def reads(*buffer_slices: Union[BufferRegion, BufferLoad]) -> None: """The block buffer region reading statement. Parameters ---------- - buffer_slices : List[Union[BufferRegion, BufferLoad]] + buffer_slices : Union[BufferRegion, BufferLoad] The array of buffer regions to read. """ if len(buffer_slices) == 1: if isinstance(buffer_slices[0], tuple): - buffer_slices = list(buffer_slices[0]) + buffer_slices = list(buffer_slices[0]) # type: ignore[assignment] elif isinstance(buffer_slices[0], list): buffer_slices = buffer_slices[0] # type: ignore[assignment] else: - buffer_slices = [buffer_slices[0]] + buffer_slices = [buffer_slices[0]] # type: ignore[assignment] else: buffer_slices = list(buffer_slices) # type: ignore[assignment] _ffi_api.Reads(buffer_slices) # type: ignore[attr-defined] # pylint: disable=no-member -def writes(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: +def writes(*buffer_slices: Union[BufferRegion, BufferLoad]) -> None: """The block buffer region writing statement. Parameters @@ -448,11 +448,11 @@ def writes(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: """ if len(buffer_slices) == 1: if isinstance(buffer_slices[0], tuple): - buffer_slices = list(buffer_slices[0]) + buffer_slices = list(buffer_slices[0]) # type: ignore[assignment] elif isinstance(buffer_slices[0], list): buffer_slices = buffer_slices[0] # type: ignore[assignment] else: - buffer_slices = [buffer_slices[0]] + buffer_slices = [buffer_slices[0]] # type: ignore[assignment] else: buffer_slices = list(buffer_slices) # type: ignore[assignment] _ffi_api.Writes(buffer_slices) # type: ignore[attr-defined] # pylint: disable=no-member @@ -470,34 +470,34 @@ def block_attr(attrs: Dict[str, Any]) -> None: def alloc_buffer( - shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], + shape: Union[List[PrimIntExpr], Tuple[PrimIntExpr], PrimIntExpr], dtype: str = "float32", - data: Var = None, - strides: List[PrimExpr] = None, - elem_offset: PrimExpr = None, + data: Optional[Var] = None, + strides: Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...]]] = None, + elem_offset: Optional[PrimIntExpr] = None, scope: str = "global", align: int = -1, offset_factor: int = 0, buffer_type: str = "default", - axis_separators: List[int] = None, + axis_separators: Optional[List[int]] = None, ) -> Buffer: """The buffer alllocation function. Parameters ---------- - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + shape : Union[List[PrimIntExpr], Tuple[PrimIntExpr], PrimIntExpr] The type of the buffer prior to flattening. dtype : str The data type in the content of the buffer. - data : Var + data : Optional[Var] The pointer to the head of the data. - strides : List[PrimExpr] + strides : Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...]]] The strides of each dimension. - elem_offset : PrimExpr + elem_offset : Optional[PrimIntExpr] The offset in terms of number of dtype elements (including lanes). scope : str @@ -512,7 +512,7 @@ def alloc_buffer( buffer_type : str The buffer type. - axis_separators : List[int] + axis_separators : Optional[List[int]] The separators between input axes when generating flattened output axes. Returns @@ -539,12 +539,14 @@ def alloc_buffer( ) -def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range: +def _as_range( + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr], +) -> ir.Range: """The range constructor. Parameters ---------- - dom : Union[Range, List[PrimExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr] The domain. Returns @@ -562,7 +564,7 @@ def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range: return ir.Range.from_min_extent(dom[0], extent) return ir.Range(dom[0], dom[1]) if hasattr(dom, "dtype"): - return ir.Range(IntImm(dom.dtype, 0), dom) + return ir.Range(IntImm(dom.dtype, 0), dom) # type: ignore[attr-defined] # pylint: disable=no-member return ir.Range(0, dom) @@ -571,18 +573,18 @@ class axis: # pylint: disable=invalid-name @staticmethod def spatial( - dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], - binding: PrimExpr, + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr], + binding: PrimIntExpr, dtype: str = "int32", ) -> Var: """The spatial block axis defining function. Parameters ---------- - dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr] The domain of the iteration variable. - binding : PrimExpr + binding : PrimIntExpr The binding value of the iteration variable. dtype : str @@ -599,18 +601,18 @@ def spatial( @staticmethod def reduce( - dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], - binding: PrimExpr, + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr], + binding: PrimIntExpr, dtype: str = "int32", ) -> Var: """The reduced block axis defining function. Parameters ---------- - dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr] The domain of the iteration variable. - binding : PrimExpr + binding : PrimIntExpr The binding value of the iteration variable. dtype : str @@ -627,18 +629,18 @@ def reduce( @staticmethod def scan( - dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], - binding: PrimExpr, + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr], + binding: PrimIntExpr, dtype: str = "int32", ) -> Var: """The scanning block axis defining function. Parameters ---------- - dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr] The domain of the iteration variable. - binding : PrimExpr + binding : PrimIntExpr The binding value of the iteration variable. dtype : str @@ -655,18 +657,18 @@ def scan( @staticmethod def opaque( - dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], - binding: PrimExpr, + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr], + binding: PrimIntExpr, dtype: str = "int32", ) -> Var: """The opaque block axis defining function. Parameters ---------- - dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr] The domain of the iteration variable. - binding : PrimExpr + binding : PrimIntExpr The binding value of the iteration variable. dtype : str @@ -681,8 +683,27 @@ def opaque( _as_range(dom), binding, dtype ) + @overload @staticmethod - def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]: + def remap(kinds: str, bindings: Union[PrimExpr, Tuple[PrimExpr]], dtype: str = "int32") -> Var: + ... + + @overload + @staticmethod + def remap(kinds: str, bindings: Tuple[PrimExpr, ...], dtype: str = "int32") -> List[Var]: + ... + + @overload + @staticmethod + def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> List[Var]: + ... + + @staticmethod + def remap( + kinds: str, + bindings: Union[List[PrimExpr], Tuple[PrimExpr, ...], PrimExpr], + dtype: str = "int32", + ) -> Union[List[Var], Var]: """The block axis remapping function. Parameters @@ -690,7 +711,7 @@ def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[L kinds : str The types of the iteration variables. - bindings : List[PrimExpr] + bindings : Union[List[PrimExpr], Tuple[PrimExpr, ...], PrimExpr] The binding values of the iteration variables. dtype : str @@ -698,9 +719,10 @@ def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[L Returns ------- - res : Var + res : Union[Var, List[Var]] The iteration variables. """ + bindings = (bindings,) if isinstance(bindings, PrimExpr) else bindings iter_vars = _ffi_api.AxisRemap( # type: ignore[attr-defined] # pylint: disable=no-member kinds, bindings, dtype ) @@ -711,26 +733,26 @@ def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[L def serial( - start: PrimExpr, - stop: PrimExpr = None, + start: PrimIntExpr, + stop: Optional[PrimIntExpr] = None, *, - annotations: Dict[str, Any] = None, - step: Optional[PrimExpr] = None, + annotations: Optional[Dict[str, Any]] = None, + step: Optional[PrimIntExpr] = None, ) -> frame.ForFrame: """The serial For statement. Parameters ---------- - start : PrimExpr + start : PrimIntExpr The minimum value of iteration. - stop : PrimExpr + stop : Optional[PrimIntExpr] The maximum value of iteration. - annotations : Dict[str, Any] + annotations : Optional[Dict[str, Any]] The optional annotations of the For statement. - step : PrimExpr + step : Optional[PrimIntExpr] The optional step value of iteration. Returns @@ -741,33 +763,33 @@ def serial( if stop is None: stop = start if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + start = IntImm(start.dtype, 0) # type: ignore[attr-defined] # pylint: disable=no-member else: start = 0 return _ffi_api.Serial(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member def parallel( - start: PrimExpr, - stop: PrimExpr = None, + start: PrimIntExpr, + stop: Optional[PrimIntExpr] = None, *, - annotations: Dict[str, Any] = None, - step: Optional[PrimExpr] = None, + annotations: Optional[Dict[str, Any]] = None, + step: Optional[PrimIntExpr] = None, ) -> frame.ForFrame: """The parallel For statement. Parameters ---------- - start : PrimExpr + start : PrimIntExpr The minimum value of iteration. - stop : PrimExpr + stop : Optional[PrimIntExpr] The maximum value of iteration. - annotations : Dict[str, Any] + annotations : Optional[Dict[str, Any]] The optional annotations of the For statement. - step : PrimExpr + step : Optional[PrimIntExpr] The optional step value of iteration. Returns @@ -778,33 +800,33 @@ def parallel( if stop is None: stop = start if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + start = IntImm(start.dtype, 0) # type: ignore[attr-defined] # pylint: disable=no-member else: start = 0 return _ffi_api.Parallel(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member def vectorized( - start: PrimExpr, - stop: PrimExpr = None, + start: PrimIntExpr, + stop: Optional[PrimIntExpr] = None, *, - annotations: Dict[str, Any] = None, - step: Optional[PrimExpr] = None, + annotations: Optional[Dict[str, Any]] = None, + step: Optional[PrimIntExpr] = None, ) -> frame.ForFrame: """The vectorized For statement. Parameters ---------- - start : PrimExpr + start : PrimIntExpr The minimum value of iteration. - stop : PrimExpr + stop : Optional[PrimIntExpr] The maximum value of iteration. - annotations : Dict[str, Any] + annotations : Optional[Dict[str, Any]] The optional annotations of the For statement. - step : PrimExpr + step : Optional[PrimIntExpr] The optional step value of iteration. Returns @@ -815,33 +837,33 @@ def vectorized( if stop is None: stop = start if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + start = IntImm(start.dtype, 0) # type: ignore[attr-defined] # pylint: disable=no-member else: start = 0 return _ffi_api.Vectorized(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member def unroll( - start: PrimExpr, - stop: PrimExpr = None, + start: PrimIntExpr, + stop: Optional[PrimIntExpr] = None, *, - annotations: Dict[str, Any] = None, - step: Optional[PrimExpr] = None, + annotations: Optional[Dict[str, Any]] = None, + step: Optional[PrimIntExpr] = None, ) -> frame.ForFrame: """The unrolled For statement. Parameters ---------- - start : PrimExpr + start : PrimIntExpr The minimum value of iteration. - stop : PrimExpr + stop : Optional[PrimIntExpr] The maximum value of iteration. - annotations : Dict[str, Any] + annotations : Optional[Dict[str, Any]] The optional annotations of the For statement. - step : PrimExpr + step : Optional[PrimIntExpr] The optional step value of iteration. Returns @@ -852,33 +874,33 @@ def unroll( if stop is None: stop = start if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + start = IntImm(start.dtype, 0) # type: ignore[attr-defined] # pylint: disable=no-member else: start = 0 return _ffi_api.Unroll(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member def thread_binding( - start: PrimExpr, - stop: PrimExpr = None, - thread: str = None, + start: PrimIntExpr, + stop: Optional[PrimIntExpr] = None, + thread: Optional[str] = None, *, - annotations: Dict[str, Any] = None, + annotations: Optional[Dict[str, Any]] = None, ) -> frame.ForFrame: """The thread-binding For statement. Parameters ---------- - start : PrimExpr + start : PrimIntExpr The minimum value of iteration. - stop : PrimExpr + stop : Optional[PrimIntExpr] The maximum value of iteration. - thread : str + thread : Optional[str] The thread for loop variable to bind. - annotations : Dict[str, Any] + annotations : Optional[Dict[str, Any]] The optional annotations of the For statement. Returns @@ -892,13 +914,13 @@ def thread_binding( thread = stop stop = start if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + start = IntImm(start.dtype, 0) # type: ignore[attr-defined] else: start = 0 elif stop is None: stop = start if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + start = IntImm(start.dtype, 0) # type: ignore[attr-defined] else: start = 0 return _ffi_api.ThreadBinding( # type: ignore[attr-defined] # pylint: disable=no-member @@ -906,7 +928,7 @@ def thread_binding( ) -def grid(*extents: PrimExpr) -> frame.ForFrame: +def grid(*extents: PrimIntExpr) -> frame.ForFrame: """The grid For statement. Parameters @@ -922,12 +944,14 @@ def grid(*extents: PrimExpr) -> frame.ForFrame: return _ffi_api.Grid(extents) # type: ignore[attr-defined] # pylint: disable=no-member -def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: disable=invalid-name +def Assert( # pylint: disable=invalid-name + condition: PrimLogicalExpr, message: str +) -> frame.AssertFrame: """Create an assertion statement. Parameters ---------- - condition : PrimExpr + condition : PrimLogicalExpr The PrimExpr to test. message : str diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 0a598e5e9bb9..b685c094f2d4 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=unused-import, redefined-builtin """Namespace for Tensor-level IR""" -from tvm.ir import PrimExpr +from tvm.ir import PrimExpr, PrimIntExpr, PrimFloatExpr, PrimLogicalExpr from tvm.runtime import const from .buffer import Buffer, decl_buffer, DataProducer diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index f5476230c19b..77fa47b35017 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -27,12 +27,12 @@ assert(isinstance(y, tvm.tir.Add)) assert(y.a == x) """ -from typing import List, Optional, Union +from typing import List, Optional, TypeVar, Union import tvm_ffi import tvm.ir._ffi_api from tvm import ir -from tvm.ir import Op, PrimExpr +from tvm.ir import Op, PrimExpr, PrimIntExpr, PrimFloatExpr, PrimLogicalExpr from tvm.ir.base import Span from tvm.runtime import Object, ObjectConvertible, Scriptable, DataType, DataTypeCode, const @@ -69,116 +69,119 @@ def _dtype_is_float(value): ) # type: ignore +Self = TypeVar("Self", bound="ExprOp") + + class ExprOp: """Operator overloading for Expr like expressions.""" # TODO(tkonolige): use inspect to add source information to these objects - def __add__(self, other: PrimExpr) -> PrimExpr: + def __add__(self, other: PrimFloatExpr) -> "Add": return _generic.add(self, other) - def __radd__(self, other: PrimExpr) -> PrimExpr: + def __radd__(self, other: PrimFloatExpr) -> "Add": return _generic.add(other, self) - def __sub__(self, other: PrimExpr) -> PrimExpr: + def __sub__(self, other: PrimFloatExpr) -> "Sub": return _generic.subtract(self, other) - def __rsub__(self, other: PrimExpr) -> PrimExpr: + def __rsub__(self, other: PrimFloatExpr) -> "Sub": return _generic.subtract(other, self) - def __mul__(self, other: PrimExpr) -> PrimExpr: + def __mul__(self, other: PrimFloatExpr) -> "Mul": return _generic.multiply(self, other) - def __rmul__(self, other: PrimExpr) -> PrimExpr: + def __rmul__(self, other: PrimFloatExpr) -> "Mul": return _generic.multiply(other, self) - def __div__(self, other: PrimExpr) -> PrimExpr: + def __div__(self, other: PrimFloatExpr) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(self, other) - def __rdiv__(self, other: PrimExpr) -> PrimExpr: + def __rdiv__(self, other: PrimFloatExpr) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(other, self) - def __truediv__(self, other: PrimExpr) -> PrimExpr: + def __truediv__(self, other: PrimFloatExpr) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(self, other) - def __rtruediv__(self, other: PrimExpr) -> PrimExpr: + def __rtruediv__(self, other: PrimFloatExpr) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(other, self) - def __floordiv__(self, other: PrimExpr) -> PrimExpr: + def __floordiv__(self, other: PrimFloatExpr) -> "FloorDiv": return _generic.floordiv(self, other) - def __rfloordiv__(self, other: PrimExpr) -> PrimExpr: + def __rfloordiv__(self, other: PrimFloatExpr) -> "FloorDiv": return _generic.floordiv(other, self, None) - def __mod__(self, other: PrimExpr) -> PrimExpr: + def __mod__(self, other: PrimFloatExpr) -> "Mod": return _ffi_api._OpFloorMod(self, other, None) # type: ignore - def __rmod__(self, other: PrimExpr) -> PrimExpr: + def __rmod__(self, other: PrimFloatExpr) -> "Mod": return _ffi_api._OpFloorMod(other, self, None) # type: ignore - def __neg__(self) -> PrimExpr: + def __neg__(self) -> "Mul": neg_one = const(-1, self.dtype) # type: ignore return self.__mul__(neg_one) - def __lshift__(self, other: PrimExpr) -> PrimExpr: + def __lshift__(self, other: PrimIntExpr) -> "Call": return _ffi_api.left_shift(self, other, None) # type: ignore - def __rlshift__(self, other: PrimExpr) -> PrimExpr: + def __rlshift__(self, other: PrimIntExpr) -> "Call": return _ffi_api.left_shift(other, self, None) # type: ignore - def __rshift__(self, other: PrimExpr) -> PrimExpr: + def __rshift__(self, other: PrimIntExpr) -> "Call": return _ffi_api.right_shift(self, other, None) # type: ignore - def __rrshift__(self, other: PrimExpr) -> PrimExpr: + def __rrshift__(self, other: PrimIntExpr) -> "Call": return _ffi_api.right_shift(other, self, None) # type: ignore - def __and__(self, other: PrimExpr) -> PrimExpr: + def __and__(self, other: PrimLogicalExpr) -> "Call": return _ffi_api.bitwise_and(self, other, None) # type: ignore - def __rand__(self, other: PrimExpr) -> PrimExpr: + def __rand__(self, other: PrimLogicalExpr) -> "Call": return _ffi_api.bitwise_and(other, self, None) # type: ignore - def __or__(self, other: PrimExpr) -> PrimExpr: + def __or__(self, other: PrimLogicalExpr) -> "Call": return _ffi_api.bitwise_or(self, other, None) # type: ignore - def __ror__(self, other: PrimExpr) -> PrimExpr: + def __ror__(self, other: PrimLogicalExpr) -> "Call": return _ffi_api.bitwise_or(other, self, None) # type: ignore - def __xor__(self, other: PrimExpr) -> PrimExpr: + def __xor__(self, other: PrimLogicalExpr) -> "Call": return _ffi_api.bitwise_xor(self, other, None) # type: ignore - def __rxor__(self, other: PrimExpr) -> PrimExpr: + def __rxor__(self, other: PrimLogicalExpr) -> "Call": return _ffi_api.bitwise_xor(other, self, None) # type: ignore - def __invert__(self) -> PrimExpr: + def __invert__(self) -> "Call": if _dtype_is_float(self): raise RuntimeError("Cannot use ~ operator on float type Expr.") return _ffi_api.bitwise_not(self, None) # type: ignore - def __lt__(self, other: PrimExpr) -> PrimExpr: + def __lt__(self, other: PrimFloatExpr) -> "LT": return _ffi_api._OpLT(self, other, None) # type: ignore - def __le__(self, other: PrimExpr) -> PrimExpr: + def __le__(self, other: PrimFloatExpr) -> "LE": return _ffi_api._OpLE(self, other, None) # type: ignore - def __eq__(self, other: PrimExpr) -> PrimExpr: + def __eq__(self, other: PrimFloatExpr) -> "EqualOp": return EqualOp(self, other) - def __ne__(self, other: PrimExpr) -> PrimExpr: + def __ne__(self, other: PrimFloatExpr) -> "NotEqualOp": return NotEqualOp(self, other) - def __gt__(self, other: PrimExpr) -> PrimExpr: + def __gt__(self, other: PrimFloatExpr) -> "GT": return _ffi_api._OpGT(self, other, None) # type: ignore - def __ge__(self, other: PrimExpr) -> PrimExpr: + def __ge__(self, other: PrimFloatExpr) -> "GE": return _ffi_api._OpGE(self, other, None) # type: ignore def __nonzero__(self): @@ -208,7 +211,7 @@ def equal(self, other: PrimExpr, span: Optional[Span] = None) -> bool: """ return _ffi_api._OpEQ(self, other, span) # type: ignore - def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr: + def astype(self: Self, dtype: str, span: Optional[Span] = None) -> Union["Cast", Self]: """Cast the expression to other type. Parameters