Skip to content

Commit 7a38d5d

Browse files
committed
implement div and promotion
1 parent 7643c43 commit 7a38d5d

File tree

9 files changed

+308
-22
lines changed

9 files changed

+308
-22
lines changed

mlir_utils/ast/canonicalize.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import inspect
55
import logging
66
import types
7-
from abc import ABC
7+
from abc import ABC, abstractmethod
88
from dis import findlinestarts
99
from opcode import opmap
1010
from types import CodeType
@@ -101,7 +101,7 @@ class BytecodePatcher(ABC):
101101
def __init__(self, context=None):
102102
self.context = context
103103

104-
@property
104+
@abstractmethod
105105
def patch_bytecode(self, code: ConcreteBytecode, original_f) -> ConcreteBytecode:
106106
pass
107107

@@ -119,10 +119,12 @@ def patch_bytecode(f, patchers: list[type(BytecodePatcher)] = None):
119119

120120
class Canonicalizer(ABC):
121121
@property
122+
@abstractmethod
122123
def cst_transformers(self) -> list[StrictTransformer]:
123124
pass
124125

125126
@property
127+
@abstractmethod
126128
def bytecode_patchers(self) -> list[BytecodePatcher]:
127129
pass
128130

mlir_utils/dialects/ext/arith.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import operator
2+
from abc import abstractmethod
23
from copy import deepcopy
34
from functools import partialmethod, cached_property
45
from typing import Union, Optional
@@ -184,7 +185,7 @@ def _arith_CmpIPredicateAttr(predicate: str | Attribute, context: Context):
184185
}
185186
if isinstance(predicate, Attribute):
186187
return predicate
187-
assert predicate in predicates, f"predicate {predicate} not in predicates"
188+
assert predicate in predicates, f"{predicate=} not in predicates"
188189
return IntegerAttr.get(
189190
IntegerType.get_signless(64, context=context), predicates[predicate]
190191
)
@@ -219,7 +220,7 @@ def _arith_CmpFPredicateAttr(predicate: str | Attribute, context: Context):
219220
}
220221
if isinstance(predicate, Attribute):
221222
return predicate
222-
assert predicate in predicates, f"predicate {predicate} not in predicates"
223+
assert predicate in predicates, f"{predicate=} not in predicates"
223224
return IntegerAttr.get(
224225
IntegerType.get_signless(64, context=context), predicates[predicate]
225226
)
@@ -247,13 +248,14 @@ def _binary_op(
247248
if loc is None:
248249
loc = get_user_code_loc()
249250
if not isinstance(rhs, lhs.__class__):
250-
rhs = lhs.__class__(rhs, dtype=lhs.type)
251+
lhs, rhs = lhs.coerce(rhs)
252+
if lhs.type != rhs.type:
253+
raise ValueError(f"{lhs=} {rhs=} must have the same type.")
254+
255+
assert op in {"add", "sub", "mul", "cmp", "truediv", "floordiv", "mod"}
251256

252-
assert op in {"add", "sub", "mul", "cmp"}
253257
if op == "cmp":
254258
assert predicate is not None
255-
if lhs.type != rhs.type:
256-
raise ValueError(f"{lhs=} {rhs=} must have the same type.")
257259

258260
if lhs.fold() and lhs.fold():
259261
klass = lhs.__class__
@@ -267,15 +269,30 @@ def _binary_op(
267269
op = predicate
268270
op = operator.attrgetter(op)(operator)
269271
return klass(op(lhs, rhs), fold=True)
272+
273+
if op == "truediv":
274+
op = "div"
275+
if op == "mod":
276+
op = "rem"
277+
278+
op = op.capitalize()
279+
if _is_floating_point_type(lhs.dtype):
280+
if op == "Floordiv":
281+
raise ValueError(f"floordiv not supported for {lhs=}")
282+
op += "F"
283+
elif _is_integer_like_type(lhs.dtype):
284+
# TODO(max): this needs to all be regularized
285+
if "div" in op.lower() or "rem" in op.lower():
286+
if not lhs.dtype.is_signless:
287+
raise ValueError(f"{op.lower()}i not supported for {lhs=}")
288+
if op == "Floordiv":
289+
op = "FloorDiv"
290+
op += "S"
291+
op += "I"
270292
else:
271-
op = op.capitalize()
272-
lhs, rhs = lhs, rhs
273-
if _is_floating_point_type(lhs.dtype):
274-
op = getattr(arith_dialect, f"{op}FOp")
275-
elif _is_integer_like_type(lhs.dtype):
276-
op = getattr(arith_dialect, f"{op}IOp")
277-
else:
278-
raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
293+
raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
294+
295+
op = getattr(arith_dialect, f"{op}Op")
279296

280297
if predicate is not None:
281298
if _is_floating_point_type(lhs.dtype):
@@ -315,6 +332,15 @@ def is_constant(self) -> bool:
315332
self.owner.opview, arith_dialect.ConstantOp
316333
)
317334

335+
@property
336+
@abstractmethod
337+
def literal_value(self):
338+
pass
339+
340+
@abstractmethod
341+
def coerce(self, other) -> tuple["ArithValue", "ArithValue"]:
342+
pass
343+
318344
def fold(self) -> bool:
319345
return self.is_constant() and self._fold
320346

@@ -329,6 +355,10 @@ def __repr__(self):
329355
__add__ = partialmethod(_binary_op, op="add")
330356
__sub__ = partialmethod(_binary_op, op="sub")
331357
__mul__ = partialmethod(_binary_op, op="mul")
358+
__truediv__ = partialmethod(_binary_op, op="truediv")
359+
__floordiv__ = partialmethod(_binary_op, op="floordiv")
360+
__mod__ = partialmethod(_binary_op, op="mod")
361+
332362
__radd__ = partialmethod(_binary_op, op="add")
333363
__rsub__ = partialmethod(_binary_op, op="sub")
334364
__rmul__ = partialmethod(_binary_op, op="mul")
@@ -401,3 +431,10 @@ def __int__(self):
401431

402432
def __float__(self):
403433
return float(self.literal_value)
434+
435+
def coerce(self, other) -> tuple["Scalar", "Scalar"]:
436+
if isinstance(other, (int, float, bool)):
437+
other = Scalar(other, dtype=self.dtype)
438+
else:
439+
raise ValueError(f"can't coerce {other=} to Scalar")
440+
return self, other

mlir_utils/dialects/ext/func.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def body_builder_wrapper(self, *call_args):
7979
loc=self.loc,
8080
ip=self.ip,
8181
)
82-
func_op.regions[0].blocks.append(*input_types)
82+
arg_locs = [get_user_code_loc()] * len(sig.parameters)
83+
func_op.regions[0].blocks.append(*input_types, arg_locs=arg_locs)
8384
with InsertionPoint(func_op.regions[0].blocks[0]):
8485
results = get_result_or_results(
8586
self.body_builder(

mlir_utils/dialects/ext/tensor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,24 @@ def __setitem__(self, idx, source):
186186
previous_frame = inspect.currentframe().f_back
187187
_update_caller_vars(previous_frame, [self], [res])
188188

189+
def coerce(self, other) -> tuple["Tensor", "Tensor"]:
190+
if isinstance(other, np.ndarray):
191+
other = Tensor(other)
192+
return self, other
193+
elif _is_scalar(other):
194+
if not self.has_static_shape():
195+
raise ValueError(
196+
f"can't coerce {other=} because {self=} doesn't have static shape"
197+
)
198+
if isinstance(other, (int, float)):
199+
other = Tensor(np.full(self.shape, other), dtype=self.dtype)
200+
return self, other
201+
elif _is_scalar(other):
202+
other = tensor.splat(self.type, other)
203+
return self, other
204+
205+
raise ValueError(f"can't coerce unknown {other=}")
206+
189207

190208
@dataclass(frozen=True)
191209
class _Indexer:

mlir_utils/types.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,27 @@
2222

2323
_index_t = lambda: IndexType.get()
2424
_bool_t = lambda: IntegerType.get_signless(1)
25+
2526
_i8_t = lambda: IntegerType.get_signless(8)
2627
_i16_t = lambda: IntegerType.get_signless(16)
2728
_i32_t = lambda: IntegerType.get_signless(32)
2829
_i64_t = lambda: IntegerType.get_signless(64)
30+
31+
_si8_t = lambda: IntegerType.get_signed(8)
32+
_si16_t = lambda: IntegerType.get_signed(16)
33+
_si32_t = lambda: IntegerType.get_signed(32)
34+
_si64_t = lambda: IntegerType.get_signed(64)
35+
36+
_ui8_t = lambda: IntegerType.get_unsigned(8)
37+
_ui16_t = lambda: IntegerType.get_unsigned(16)
38+
_ui32_t = lambda: IntegerType.get_unsigned(32)
39+
_ui64_t = lambda: IntegerType.get_unsigned(64)
40+
2941
_f16_t = lambda: F16Type.get()
3042
_f32_t = lambda: F32Type.get()
3143
_f64_t = lambda: F64Type.get()
3244
_bf16_t = lambda: BF16Type.get()
45+
3346
opaque_t = lambda dialect_namespace, buffer: OpaqueType.get(dialect_namespace, buffer)
3447

3548

@@ -40,10 +53,22 @@ def _placeholder_opaque_t():
4053
_name_to_type = {
4154
"index_t": _index_t,
4255
"bool_t": _bool_t,
56+
4357
"i8_t": _i8_t,
4458
"i16_t": _i16_t,
4559
"i32_t": _i32_t,
4660
"i64_t": _i64_t,
61+
62+
"si8_t": _si8_t,
63+
"si16_t": _si16_t,
64+
"si32_t": _si32_t,
65+
"si64_t": _si64_t,
66+
67+
"ui8_t": _ui8_t,
68+
"ui16_t": _ui16_t,
69+
"ui32_t": _ui32_t,
70+
"ui64_t": _ui64_t,
71+
4772
"f16_t": _f16_t,
4873
"f32_t": _f32_t,
4974
"f64_t": _f64_t,

mlir_utils/util.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import defaultdict
66
from functools import wraps
77
from pathlib import Path
8-
from typing import Callable, Sequence
8+
from typing import Callable, Sequence, Optional
99

1010
import mlir
1111
from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
@@ -226,18 +226,22 @@ def _update_caller_vars(previous_frame, args: Sequence, replacements: Sequence):
226226
)
227227

228228

229-
def get_user_code_loc():
229+
def get_user_code_loc(user_base: Optional[Path] = None):
230230
import mlir_utils
231231
import mlir
232232

233233
mlir_utis_root_path = Path(mlir_utils.__path__[0])
234234
mlir_root_path = Path(mlir.__path__[0])
235235

236-
prev_frame = inspect.currentframe().f_back.f_back
236+
prev_frame = inspect.currentframe().f_back
237+
if user_base is None:
238+
user_base = Path(prev_frame.f_code.co_filename)
239+
237240
while (
238241
Path(prev_frame.f_code.co_filename).is_relative_to(mlir_utis_root_path)
239242
or Path(prev_frame.f_code.co_filename).is_relative_to(mlir_root_path)
240243
or Path(prev_frame.f_code.co_filename).is_relative_to(sys.prefix)
244+
or Path(prev_frame.f_code.co_filename).is_relative_to(user_base)
241245
):
242246
prev_frame = prev_frame.f_back
243247
frame_info = inspect.getframeinfo(prev_frame)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dependencies = [
1818
test = ["pytest", "mlir-native-tools", "astpretty"]
1919
torch-mlir = ["torch-mlir-core"]
2020
jax = ["jax[cpu]", ]
21-
mlir = ["mlir-python-bindings==17.0.0.2023.7.26.12+25b8433"]
21+
mlir = ["mlir-python-bindings==17.0.0.2023.7.27.14+70aca7b1"]
2222

2323
[project.scripts]
2424
configure-mlir-python-utils = "mlir_utils:_configuration.configuration.configure_host_bindings"

tests/test_operator_overloading.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,51 @@
1313
pytest.mark.usefixtures("ctx")
1414

1515

16+
def test_arithmetic(ctx: MLIRContext):
17+
one = constant(1)
18+
two = constant(2)
19+
one + two
20+
one - two
21+
one / two
22+
one // two
23+
one % two
24+
25+
one = constant(1.0)
26+
two = constant(2.0)
27+
one + two
28+
one - two
29+
one / two
30+
try:
31+
one // two
32+
except ValueError as e:
33+
assert str(e) == "floordiv not supported for lhs=Scalar(%cst, f64)"
34+
one % two
35+
36+
ctx.module.operation.verify()
37+
filecheck(
38+
dedent(
39+
"""\
40+
module {
41+
%c1_i64 = arith.constant 1 : i64
42+
%c2_i64 = arith.constant 2 : i64
43+
%0 = arith.addi %c1_i64, %c2_i64 : i64
44+
%1 = arith.subi %c1_i64, %c2_i64 : i64
45+
%2 = arith.divsi %c1_i64, %c2_i64 : i64
46+
%3 = arith.floordivsi %c1_i64, %c2_i64 : i64
47+
%4 = arith.remsi %c1_i64, %c2_i64 : i64
48+
%cst = arith.constant 1.000000e+00 : f64
49+
%cst_0 = arith.constant 2.000000e+00 : f64
50+
%5 = arith.addf %cst, %cst_0 : f64
51+
%6 = arith.subf %cst, %cst_0 : f64
52+
%7 = arith.divf %cst, %cst_0 : f64
53+
%8 = arith.remf %cst, %cst_0 : f64
54+
}
55+
"""
56+
),
57+
ctx.module,
58+
)
59+
60+
1661
def test_tensor_arithmetic(ctx: MLIRContext):
1762
one = constant(1)
1863
assert isinstance(one, Scalar)
@@ -117,3 +162,47 @@ def test_arith_cmp(ctx: MLIRContext):
117162
),
118163
ctx.module,
119164
)
165+
166+
167+
def test_scalar_promotion(ctx: MLIRContext):
168+
one = constant(1)
169+
one + 2
170+
one - 2
171+
one / 2
172+
one // 2
173+
one % 2
174+
175+
one = constant(1.0)
176+
one + 2.0
177+
one - 2.0
178+
one / 2.0
179+
one % 2.0
180+
181+
ctx.module.operation.verify()
182+
correct = dedent(
183+
"""\
184+
module {
185+
%c1_i64 = arith.constant 1 : i64
186+
%c2_i64 = arith.constant 2 : i64
187+
%0 = arith.addi %c1_i64, %c2_i64 : i64
188+
%c2_i64_0 = arith.constant 2 : i64
189+
%1 = arith.subi %c1_i64, %c2_i64_0 : i64
190+
%c2_i64_1 = arith.constant 2 : i64
191+
%2 = arith.divsi %c1_i64, %c2_i64_1 : i64
192+
%c2_i64_2 = arith.constant 2 : i64
193+
%3 = arith.floordivsi %c1_i64, %c2_i64_2 : i64
194+
%c2_i64_3 = arith.constant 2 : i64
195+
%4 = arith.remsi %c1_i64, %c2_i64_3 : i64
196+
%cst = arith.constant 1.000000e+00 : f64
197+
%cst_4 = arith.constant 2.000000e+00 : f64
198+
%5 = arith.addf %cst, %cst_4 : f64
199+
%cst_5 = arith.constant 2.000000e+00 : f64
200+
%6 = arith.subf %cst, %cst_5 : f64
201+
%cst_6 = arith.constant 2.000000e+00 : f64
202+
%7 = arith.divf %cst, %cst_6 : f64
203+
%cst_7 = arith.constant 2.000000e+00 : f64
204+
%8 = arith.remf %cst, %cst_7 : f64
205+
}
206+
"""
207+
)
208+
filecheck(correct, ctx.module)

0 commit comments

Comments
 (0)