Skip to content

Commit 7643c43

Browse files
committed
start to port tensor indexing
1 parent e184645 commit 7643c43

File tree

11 files changed

+1011
-80
lines changed

11 files changed

+1011
-80
lines changed

mlir_utils/dialects/ext/arith.py

Lines changed: 74 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import operator
12
from copy import deepcopy
23
from functools import partialmethod, cached_property
34
from typing import Union, Optional
@@ -116,7 +117,7 @@ class ArithValueMeta(type(Value)):
116117
"""
117118

118119
def __call__(cls, *args, **kwargs):
119-
"""Orchestrate the Python object protocol for Indexing dialect extension
120+
"""Orchestrate the Python object protocol for mlir
120121
values in order to handle wrapper arbitrary Python objects.
121122
122123
Args:
@@ -132,7 +133,7 @@ def __call__(cls, *args, **kwargs):
132133
if len(args) != 1:
133134
raise ValueError("Only one non-kw arg supported.")
134135
arg = args[0]
135-
arg_copy = None
136+
fold = None
136137
if isinstance(arg, (OpView, Operation, Value)):
137138
# wrap an already created Value (or op the produces a Value)
138139
if isinstance(arg, (Operation, OpView)):
@@ -143,13 +144,15 @@ def __call__(cls, *args, **kwargs):
143144
dtype = kwargs.get("dtype")
144145
if dtype is not None and not isinstance(dtype, Type):
145146
raise ValueError(f"{dtype=} is expected to be an ir.Type.")
147+
fold = kwargs.get("fold")
148+
if fold is not None and not isinstance(fold, bool):
149+
raise ValueError(f"{fold=} is expected to be a bool.")
146150
# If we're wrapping a numpy array (effectively a tensor literal),
147151
# then we want to make sure no one else has access to that memory.
148152
# Otherwise, the array will get funneled down to DenseElementsAttr.get,
149153
# which by default (through the Python buffer protocol) does not copy;
150154
# see mlir/lib/Bindings/Python/IRAttributes.cpp#L556
151-
arg_copy = deepcopy(arg)
152-
return constant(arg_copy, dtype)
155+
val = constant(deepcopy(arg), dtype)
153156
else:
154157
raise NotImplementedError(f"{cls.__name__} doesn't support wrapping {arg}.")
155158

@@ -161,7 +164,7 @@ def __call__(cls, *args, **kwargs):
161164
# the Python object protocol; first an object is new'ed and then
162165
# it is init'ed. Note we pass arg_copy here in case a subclass wants to
163166
# inspect the literal.
164-
cls.__init__(cls_obj, val)
167+
cls.__init__(cls_obj, val, fold=fold)
165168
return cls_obj
166169

167170

@@ -252,14 +255,28 @@ def _binary_op(
252255
if lhs.type != rhs.type:
253256
raise ValueError(f"{lhs=} {rhs=} must have the same type.")
254257

255-
op = op.capitalize()
256-
lhs, rhs = lhs, rhs
257-
if _is_floating_point_type(lhs.dtype):
258-
op = getattr(arith_dialect, f"{op}FOp")
259-
elif _is_integer_like_type(lhs.dtype):
260-
op = getattr(arith_dialect, f"{op}IOp")
258+
if lhs.fold() and lhs.fold():
259+
klass = lhs.__class__
260+
# if both operands are constants (results of an arith.constant op)
261+
# then both have a literal value (i.e. Python value).
262+
lhs, rhs = lhs.literal_value, rhs.literal_value
263+
# if we're folding constants (self._fold = True) then we just carry out
264+
# the corresponding operation on the literal values; e.g., operator.add.
265+
# note this is the same as op = operator.__dict__[op].
266+
if predicate is not None:
267+
op = predicate
268+
op = operator.attrgetter(op)(operator)
269+
return klass(op(lhs, rhs), fold=True)
261270
else:
262-
raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
271+
op = op.capitalize()
272+
lhs, rhs = lhs, rhs
273+
if _is_floating_point_type(lhs.dtype):
274+
op = getattr(arith_dialect, f"{op}FOp")
275+
elif _is_integer_like_type(lhs.dtype):
276+
op = getattr(arith_dialect, f"{op}IOp")
277+
else:
278+
raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
279+
263280
if predicate is not None:
264281
if _is_floating_point_type(lhs.dtype):
265282
# ordered comparison - see above
@@ -289,9 +306,18 @@ class ArithValue(Value, metaclass=ArithValueMeta):
289306
Value.__init__
290307
"""
291308

292-
def __init__(self, val):
309+
def __init__(self, val, *, fold: Optional[bool] = None):
310+
self._fold = fold if fold is not None else False
293311
super().__init__(val)
294312

313+
def is_constant(self) -> bool:
314+
return isinstance(self.owner, Operation) and isinstance(
315+
self.owner.opview, arith_dialect.ConstantOp
316+
)
317+
318+
def fold(self) -> bool:
319+
return self.is_constant() and self._fold
320+
295321
def __str__(self):
296322
return f"{self.__class__.__name__}({self.get_name()}, {self.type})"
297323

@@ -306,8 +332,29 @@ def __repr__(self):
306332
__radd__ = partialmethod(_binary_op, op="add")
307333
__rsub__ = partialmethod(_binary_op, op="sub")
308334
__rmul__ = partialmethod(_binary_op, op="mul")
309-
__eq__ = partialmethod(_binary_op, op="cmp", predicate="eq")
310-
__ne__ = partialmethod(_binary_op, op="cmp", predicate="ne")
335+
336+
def __eq__(self, other):
337+
if not isinstance(other, self.__class__):
338+
try:
339+
other = self.__class__(other, dtype=self.type)
340+
except NotImplementedError as e:
341+
assert "doesn't support wrapping" in str(e)
342+
return False
343+
if self is other:
344+
return True
345+
return _binary_op(self, other, op="cmp", predicate="eq")
346+
347+
def __ne__(self, other):
348+
if not isinstance(other, self.__class__):
349+
try:
350+
other = self.__class__(other, dtype=self.type)
351+
except NotImplementedError as e:
352+
assert "doesn't support wrapping" in str(e)
353+
return True
354+
if self is other:
355+
return False
356+
return _binary_op(self, other, op="cmp", predicate="ne")
357+
311358
__le__ = partialmethod(_binary_op, op="cmp", predicate="le")
312359
__lt__ = partialmethod(_binary_op, op="cmp", predicate="lt")
313360
__ge__ = partialmethod(_binary_op, op="cmp", predicate="ge")
@@ -342,3 +389,15 @@ def isinstance(other: Value):
342389
or _is_index_type(other.type)
343390
or _is_complex_type(other.type)
344391
)
392+
393+
@cached_property
394+
def literal_value(self) -> Union[int, float, bool]:
395+
if not self.is_constant():
396+
raise ValueError("Can't build literal from non-constant Scalar")
397+
return self.owner.opview.literal_value
398+
399+
def __int__(self):
400+
return int(self.literal_value)
401+
402+
def __float__(self):
403+
return float(self.literal_value)

mlir_utils/dialects/ext/func.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ def body_builder_wrapper(self, *call_args):
8282
func_op.regions[0].blocks.append(*input_types)
8383
with InsertionPoint(func_op.regions[0].blocks[0]):
8484
results = get_result_or_results(
85-
self.body_builder(*func_op.regions[0].blocks[0].arguments)
85+
self.body_builder(
86+
*[maybe_cast(a) for a in func_op.regions[0].blocks[0].arguments]
87+
)
8688
)
8789
if results is not None:
8890
if isinstance(results, (tuple, list)):

mlir_utils/dialects/ext/scf.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,30 @@ def _for(
5656

5757
for_ = region_op(_for, terminator=yield__)
5858

59+
# def range_(
60+
# start,
61+
# stop=None,
62+
# step=None,
63+
# iter_args: Optional[Sequence[Value]] = None,
64+
# *,
65+
# loc=None,
66+
# ip=None,
67+
# ):
68+
# for_op = _for(start, stop, step, iter_args, loc=loc, ip=ip)
69+
# iv = maybe_cast(for_op.induction_variable)
70+
# for_iter_args = tuple(map(maybe_cast, for_op.inner_iter_args))
71+
# results = tuple(map(maybe_cast, for_op.results_))
72+
# with InsertionPoint(for_op.body):
73+
# previous_frame = inspect.currentframe().f_back
74+
# _update_caller_vars(previous_frame, iter_args, for_iter_args)
75+
#
76+
# if len(results) > 1:
77+
# yield iv, results
78+
# elif len(results) == 1:
79+
# yield iv, results[0]
80+
# else:
81+
# yield iv
82+
5983

6084
def range_(
6185
start,

0 commit comments

Comments
 (0)