Skip to content
1 change: 1 addition & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
passmanager.py
rewrite.py
dialects/_ods_common.py
util.py

# The main _mlir module has submodules: include stubs from each.
_mlir_libs/_mlir/__init__.pyi
Expand Down
42 changes: 20 additions & 22 deletions mlir/python/mlir/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from ._arith_ops_gen import *
from ._arith_ops_gen import _Dialect
from ._arith_enum_gen import *
from ..util import is_integer_type, is_index_type, is_float_type
from array import array as _array
from typing import overload


try:
from ..ir import *
from ._ods_common import (
Expand All @@ -21,26 +23,6 @@
raise RuntimeError("Error loading imports from extension module") from e


def _isa(obj: Any, cls: type):
try:
cls(obj)
except ValueError:
return False
return True


def _is_any_of(obj: Any, classes: List[type]):
return any(_isa(obj, cls) for cls in classes)


def _is_integer_like_type(type: Type):
return _is_any_of(type, [IntegerType, IndexType])


def _is_float_type(type: Type):
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])


@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
Expand Down Expand Up @@ -96,9 +78,9 @@ def value(self):

@property
def literal_value(self) -> Union[int, float]:
if _is_integer_like_type(self.type):
if is_integer_type(self.type) or is_index_type(self.type):
return IntegerAttr(self.value).value
elif _is_float_type(self.type):
elif is_float_type(self.type):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")
Expand All @@ -108,3 +90,19 @@ def constant(
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))


def index_cast(
in_: Value,
to: Type = None,
*,
out: Type = None,
loc: Location = None,
ip: InsertionPoint = None,
) -> Value:
if bool(to) != bool(out):
raise ValueError("either `to` or `out` must be set but not both")
res_type = out or to
if res_type is None:
res_type = IndexType.get()
return _get_op_result_or_op_results(IndexCastOp(res_type, in_, loc=loc, ip=ip))
106 changes: 38 additions & 68 deletions mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
from typing import Callable, Dict, List, Sequence, Tuple, Union

from .....ir import *
from .....util import (
is_complex_type,
is_float_type,
is_index_type,
is_integer_type,
get_floating_point_width,
)

from .... import func
from .... import linalg
Expand Down Expand Up @@ -412,21 +419,21 @@ def _cast(
)
if operand.type == to_type:
return operand
if _is_integer_type(to_type):
if is_integer_type(to_type):
return self._cast_to_integer(to_type, operand, is_unsigned_cast)
elif _is_floating_point_type(to_type):
elif is_float_type(to_type):
return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)

def _cast_to_integer(
self, to_type: Type, operand: Value, is_unsigned_cast: bool
) -> Value:
to_width = IntegerType(to_type).width
operand_type = operand.type
if _is_floating_point_type(operand_type):
if is_float_type(operand_type):
if is_unsigned_cast:
return arith.FPToUIOp(to_type, operand).result
return arith.FPToSIOp(to_type, operand).result
if _is_index_type(operand_type):
if is_index_type(operand_type):
return arith.IndexCastOp(to_type, operand).result
# Assume integer.
from_width = IntegerType(operand_type).width
Expand All @@ -444,13 +451,13 @@ def _cast_to_floating_point(
self, to_type: Type, operand: Value, is_unsigned_cast: bool
) -> Value:
operand_type = operand.type
if _is_integer_type(operand_type):
if is_integer_type(operand_type):
if is_unsigned_cast:
return arith.UIToFPOp(to_type, operand).result
return arith.SIToFPOp(to_type, operand).result
# Assume FloatType.
to_width = _get_floating_point_width(to_type)
from_width = _get_floating_point_width(operand_type)
to_width = get_floating_point_width(to_type)
from_width = get_floating_point_width(operand_type)
if to_width > from_width:
return arith.ExtFOp(to_type, operand).result
elif to_width < from_width:
Expand All @@ -466,89 +473,89 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
return self._cast(type_var_name, operand, True)

def _unary_exp(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if is_float_type(x.type):
return math.ExpOp(x).result
raise NotImplementedError("Unsupported 'exp' operand: {x}")

def _unary_log(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if is_float_type(x.type):
return math.LogOp(x).result
raise NotImplementedError("Unsupported 'log' operand: {x}")

def _unary_abs(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if is_float_type(x.type):
return math.AbsFOp(x).result
raise NotImplementedError("Unsupported 'abs' operand: {x}")

def _unary_ceil(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if is_float_type(x.type):
return math.CeilOp(x).result
raise NotImplementedError("Unsupported 'ceil' operand: {x}")

def _unary_floor(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if is_float_type(x.type):
return math.FloorOp(x).result
raise NotImplementedError("Unsupported 'floor' operand: {x}")

def _unary_negf(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if is_float_type(x.type):
return arith.NegFOp(x).result
if _is_complex_type(x.type):
if is_complex_type(x.type):
return complex.NegOp(x).result
raise NotImplementedError("Unsupported 'negf' operand: {x}")

def _binary_add(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.AddFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.AddIOp(lhs, rhs).result
if _is_complex_type(lhs.type):
if is_complex_type(lhs.type):
return complex.AddOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")

def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.SubFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.SubIOp(lhs, rhs).result
if _is_complex_type(lhs.type):
if is_complex_type(lhs.type):
return complex.SubOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")

def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.MulFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MulIOp(lhs, rhs).result
if _is_complex_type(lhs.type):
if is_complex_type(lhs.type):
return complex.MulOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")

def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.MaximumFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MaxSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")

def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.MaximumFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MaxUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")

def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.MinimumFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MinSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")

def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if is_float_type(lhs.type):
return arith.MinimumFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MinUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")

Expand Down Expand Up @@ -609,40 +616,3 @@ def _add_type_mapping(
)
type_mapping[name] = element_or_self_type
block_arg_types.append(element_or_self_type)


def _is_complex_type(t: Type) -> bool:
return ComplexType.isinstance(t)


def _is_floating_point_type(t: Type) -> bool:
# TODO: Create a FloatType in the Python API and implement the switch
# there.
return (
F64Type.isinstance(t)
or F32Type.isinstance(t)
or F16Type.isinstance(t)
or BF16Type.isinstance(t)
)


def _is_integer_type(t: Type) -> bool:
return IntegerType.isinstance(t)


def _is_index_type(t: Type) -> bool:
return IndexType.isinstance(t)


def _get_floating_point_width(t: Type) -> int:
# TODO: Create a FloatType in the Python API and implement the switch
# there.
if F64Type.isinstance(t):
return 64
if F32Type.isinstance(t):
return 32
if F16Type.isinstance(t):
return 16
if BF16Type.isinstance(t):
return 16
raise NotImplementedError(f"Unhandled floating point type switch {t}")
5 changes: 3 additions & 2 deletions mlir/python/mlir/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@

from ._memref_ops_gen import *
from ._ods_common import _dispatch_mixed_values, MixedValues
from .arith import ConstantOp, _is_integer_like_type
from .arith import ConstantOp
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
from ..util import is_integer_like_type


def _is_constant_int_like(i):
return (
isinstance(i, Value)
and isinstance(i.owner, Operation)
and isinstance(i.owner.opview, ConstantOp)
and _is_integer_like_type(i.type)
and is_integer_like_type(i.type)
)


Expand Down
Loading