Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 37 additions & 34 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
assert(isinstance(y, tvm.tir.Add))
assert(y.a == x)
"""
from typing import List, Optional, Union
from typing import List, Optional, Self, Union

import tvm_ffi
import tvm.ir._ffi_api
Expand All @@ -41,6 +41,9 @@
from .buffer import Buffer, DataProducer


numeric = Union[int, float, complex]


def convert(expr) -> PrimExpr:
return _ffi_api.convert(expr)

Expand Down Expand Up @@ -74,111 +77,111 @@ class ExprOp:

# TODO(tkonolige): use inspect to add source information to these objects

def __add__(self, other: PrimExpr) -> PrimExpr:
def __add__(self, other: Union[PrimExpr, numeric]) -> "Add":
return _generic.add(self, other)

def __radd__(self, other: PrimExpr) -> PrimExpr:
def __radd__(self, other: Union[PrimExpr, numeric]) -> "Add":
return _generic.add(other, self)

def __sub__(self, other: PrimExpr) -> PrimExpr:
def __sub__(self, other: Union[PrimExpr, numeric]) -> "Sub":
return _generic.subtract(self, other)

def __rsub__(self, other: PrimExpr) -> PrimExpr:
def __rsub__(self, other: Union[PrimExpr, numeric]) -> "Sub":
return _generic.subtract(other, self)

def __mul__(self, other: PrimExpr) -> PrimExpr:
def __mul__(self, other: Union[PrimExpr, numeric]) -> "Mul":
return _generic.multiply(self, other)

def __rmul__(self, other: PrimExpr) -> PrimExpr:
def __rmul__(self, other: Union[PrimExpr, numeric]) -> "Mul":
return _generic.multiply(other, self)

def __div__(self, other: PrimExpr) -> PrimExpr:
def __div__(self, other: Union[PrimExpr, numeric]) -> "Div":
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(self, other)

def __rdiv__(self, other: PrimExpr) -> PrimExpr:
def __rdiv__(self, other: Union[PrimExpr, numeric]) -> "Div":
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(other, self)

def __truediv__(self, other: PrimExpr) -> PrimExpr:
def __truediv__(self, other: Union[PrimExpr, numeric]) -> "Div":
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(self, other)

def __rtruediv__(self, other: PrimExpr) -> PrimExpr:
def __rtruediv__(self, other: Union[PrimExpr, numeric]) -> "Div":
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(other, self)

def __floordiv__(self, other: PrimExpr) -> PrimExpr:
def __floordiv__(self, other: Union[PrimExpr, numeric]) -> "FloorDiv":
return _generic.floordiv(self, other)

def __rfloordiv__(self, other: PrimExpr) -> PrimExpr:
def __rfloordiv__(self, other: Union[PrimExpr, numeric]) -> "FloorDiv":
return _generic.floordiv(other, self, None)

def __mod__(self, other: PrimExpr) -> PrimExpr:
def __mod__(self, other: Union[PrimExpr, numeric]) -> "Mod":
return _ffi_api._OpFloorMod(self, other, None) # type: ignore

def __rmod__(self, other: PrimExpr) -> PrimExpr:
def __rmod__(self, other: Union[PrimExpr, numeric]) -> "Mod":
return _ffi_api._OpFloorMod(other, self, None) # type: ignore

def __neg__(self) -> PrimExpr:
def __neg__(self) -> "Mul":
neg_one = const(-1, self.dtype) # type: ignore
return self.__mul__(neg_one)

def __lshift__(self, other: PrimExpr) -> PrimExpr:
def __lshift__(self, other: Union[PrimExpr, int]) -> "Call":
return _ffi_api.left_shift(self, other, None) # type: ignore

def __rlshift__(self, other: PrimExpr) -> PrimExpr:
def __rlshift__(self, other: Union[PrimExpr, int]) -> "Call":
return _ffi_api.left_shift(other, self, None) # type: ignore

def __rshift__(self, other: PrimExpr) -> PrimExpr:
def __rshift__(self, other: Union[PrimExpr, int]) -> "Call":
return _ffi_api.right_shift(self, other, None) # type: ignore

def __rrshift__(self, other: PrimExpr) -> PrimExpr:
def __rrshift__(self, other: Union[PrimExpr, int]) -> "Call":
return _ffi_api.right_shift(other, self, None) # type: ignore

def __and__(self, other: PrimExpr) -> PrimExpr:
def __and__(self, other: Union[PrimExpr, int, bool]) -> "Call":
return _ffi_api.bitwise_and(self, other, None) # type: ignore

def __rand__(self, other: PrimExpr) -> PrimExpr:
def __rand__(self, other: Union[PrimExpr, int, bool]) -> "Call":
return _ffi_api.bitwise_and(other, self, None) # type: ignore

def __or__(self, other: PrimExpr) -> PrimExpr:
def __or__(self, other: Union[PrimExpr, int, bool]) -> "Call":
return _ffi_api.bitwise_or(self, other, None) # type: ignore

def __ror__(self, other: PrimExpr) -> PrimExpr:
def __ror__(self, other: Union[PrimExpr, int, bool]) -> "Call":
return _ffi_api.bitwise_or(other, self, None) # type: ignore

def __xor__(self, other: PrimExpr) -> PrimExpr:
def __xor__(self, other: Union[PrimExpr, int]) -> "Call":
return _ffi_api.bitwise_xor(self, other, None) # type: ignore

def __rxor__(self, other: PrimExpr) -> PrimExpr:
def __rxor__(self, other: Union[PrimExpr, int]) -> "Call":
return _ffi_api.bitwise_xor(other, self, None) # type: ignore

def __invert__(self) -> PrimExpr:
def __invert__(self) -> "Call":
if _dtype_is_float(self):
raise RuntimeError("Cannot use ~ operator on float type Expr.")
return _ffi_api.bitwise_not(self, None) # type: ignore

def __lt__(self, other: PrimExpr) -> PrimExpr:
def __lt__(self, other: Union[PrimExpr, numeric]) -> "LT":
return _ffi_api._OpLT(self, other, None) # type: ignore

def __le__(self, other: PrimExpr) -> PrimExpr:
def __le__(self, other: Union[PrimExpr, numeric]) -> "LE":
return _ffi_api._OpLE(self, other, None) # type: ignore

def __eq__(self, other: PrimExpr) -> PrimExpr:
def __eq__(self, other: Union[PrimExpr, numeric]) -> "EqualOp":
return EqualOp(self, other)

def __ne__(self, other: PrimExpr) -> PrimExpr:
def __ne__(self, other: Union[PrimExpr, numeric]) -> "NotEqualOp":
return NotEqualOp(self, other)

def __gt__(self, other: PrimExpr) -> PrimExpr:
def __gt__(self, other: Union[PrimExpr, numeric]) -> "GT":
return _ffi_api._OpGT(self, other, None) # type: ignore

def __ge__(self, other: PrimExpr) -> PrimExpr:
def __ge__(self, other: Union[PrimExpr, numeric]) -> "GE":
return _ffi_api._OpGE(self, other, None) # type: ignore

def __nonzero__(self):
Expand Down Expand Up @@ -208,7 +211,7 @@ def equal(self, other: PrimExpr, span: Optional[Span] = None) -> bool:
"""
return _ffi_api._OpEQ(self, other, span) # type: ignore

def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr:
def astype(self, dtype: str, span: Optional[Span] = None) -> Union["Cast", "Self"]:
"""Cast the expression to other type.

Parameters
Expand Down