Skip to content

Commit 3b6adc9

Browse files
authored
move operator overloading (#145)
1 parent 94061d5 commit 3b6adc9

File tree

3 files changed

+374
-388
lines changed

3 files changed

+374
-388
lines changed

tests/test_arith.py

Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from textwrap import dedent
2+
13
import mlir.extras.types as T
24
import pytest
35

46
from mlir.extras.ast.canonicalize import canonicalize
57
from mlir.extras.dialects.ext import arith
8+
from mlir.extras.dialects.ext.arith import Scalar
69
from mlir.extras.dialects.ext.func import func
710

811
# noinspection PyUnresolvedReferences
@@ -27,3 +30,337 @@ def foo():
2730
row_l: T.f32() = 0.0
2831

2932
filecheck_with_comments(ctx.module)
33+
34+
35+
def test_arithmetic(ctx: MLIRContext):
36+
one = arith.constant(1)
37+
two = arith.constant(2)
38+
one + two
39+
one - two
40+
one / two
41+
one // two
42+
one % two
43+
44+
one = arith.constant(1.0)
45+
two = arith.constant(2.0)
46+
one + two
47+
one - two
48+
one / two
49+
try:
50+
one // two
51+
except ValueError as e:
52+
assert (
53+
str(e)
54+
== "floordiv not supported for lhs=Scalar(%cst = arith.constant 1.000000e+00 : f32)"
55+
)
56+
one % two
57+
58+
ctx.module.operation.verify()
59+
filecheck(
60+
dedent(
61+
"""\
62+
module {
63+
%c1_i32 = arith.constant 1 : i32
64+
%c2_i32 = arith.constant 2 : i32
65+
%0 = arith.addi %c1_i32, %c2_i32 : i32
66+
%1 = arith.subi %c1_i32, %c2_i32 : i32
67+
%2 = arith.divsi %c1_i32, %c2_i32 : i32
68+
%3 = arith.floordivsi %c1_i32, %c2_i32 : i32
69+
%4 = arith.remsi %c1_i32, %c2_i32 : i32
70+
%cst = arith.constant 1.000000e+00 : f32
71+
%cst_0 = arith.constant 2.000000e+00 : f32
72+
%5 = arith.addf %cst, %cst_0 : f32
73+
%6 = arith.subf %cst, %cst_0 : f32
74+
%7 = arith.divf %cst, %cst_0 : f32
75+
%8 = arith.remf %cst, %cst_0 : f32
76+
}
77+
"""
78+
),
79+
ctx.module,
80+
)
81+
82+
83+
def test_r_arithmetic(ctx: MLIRContext):
84+
one = arith.constant(1)
85+
two = arith.constant(2)
86+
one - two
87+
two - one
88+
89+
ctx.module.operation.verify()
90+
filecheck(
91+
dedent(
92+
"""\
93+
module {
94+
%c1_i32 = arith.constant 1 : i32
95+
%c2_i32 = arith.constant 2 : i32
96+
%0 = arith.subi %c1_i32, %c2_i32 : i32
97+
%1 = arith.subi %c2_i32, %c1_i32 : i32
98+
}
99+
"""
100+
),
101+
ctx.module,
102+
)
103+
104+
105+
def test_arith_cmp(ctx: MLIRContext):
106+
one = arith.constant(1)
107+
two = arith.constant(2)
108+
one < two
109+
one <= two
110+
one > two
111+
one >= two
112+
one == two
113+
one != two
114+
one & two
115+
one | two
116+
assert one._ne(two)
117+
assert not one._eq(two)
118+
119+
one = arith.constant(1.0)
120+
two = arith.constant(2.0)
121+
one < two
122+
one <= two
123+
one > two
124+
one >= two
125+
one == two
126+
one != two
127+
assert one._ne(two)
128+
assert not one._eq(two)
129+
130+
ctx.module.operation.verify()
131+
filecheck(
132+
dedent(
133+
"""\
134+
module {
135+
%c1_i32 = arith.constant 1 : i32
136+
%c2_i32 = arith.constant 2 : i32
137+
%0 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
138+
%1 = arith.cmpi sle, %c1_i32, %c2_i32 : i32
139+
%2 = arith.cmpi sgt, %c1_i32, %c2_i32 : i32
140+
%3 = arith.cmpi sge, %c1_i32, %c2_i32 : i32
141+
%4 = arith.cmpi eq, %c1_i32, %c2_i32 : i32
142+
%5 = arith.cmpi ne, %c1_i32, %c2_i32 : i32
143+
%6 = arith.andi %c1_i32, %c2_i32 : i32
144+
%7 = arith.ori %c1_i32, %c2_i32 : i32
145+
%cst = arith.constant 1.000000e+00 : f32
146+
%cst_0 = arith.constant 2.000000e+00 : f32
147+
%8 = arith.cmpf olt, %cst, %cst_0 : f32
148+
%9 = arith.cmpf ole, %cst, %cst_0 : f32
149+
%10 = arith.cmpf ogt, %cst, %cst_0 : f32
150+
%11 = arith.cmpf oge, %cst, %cst_0 : f32
151+
%12 = arith.cmpf oeq, %cst, %cst_0 : f32
152+
%13 = arith.cmpf one, %cst, %cst_0 : f32
153+
}
154+
"""
155+
),
156+
ctx.module,
157+
)
158+
159+
160+
def test_arith_cmp_literals(ctx: MLIRContext):
161+
one = arith.constant(1)
162+
two = 2
163+
one < two
164+
one <= two
165+
one > two
166+
one >= two
167+
one == two
168+
one != two
169+
one & two
170+
one | two
171+
172+
one = arith.constant(1.0)
173+
two = 2.0
174+
one < two
175+
one <= two
176+
one > two
177+
one >= two
178+
one == two
179+
one != two
180+
181+
ctx.module.operation.verify()
182+
filecheck(
183+
dedent(
184+
"""\
185+
module {
186+
%c1_i32 = arith.constant 1 : i32
187+
%c2_i32 = arith.constant 2 : i32
188+
%0 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
189+
%c2_i32_0 = arith.constant 2 : i32
190+
%1 = arith.cmpi sle, %c1_i32, %c2_i32_0 : i32
191+
%c2_i32_1 = arith.constant 2 : i32
192+
%2 = arith.cmpi sgt, %c1_i32, %c2_i32_1 : i32
193+
%c2_i32_2 = arith.constant 2 : i32
194+
%3 = arith.cmpi sge, %c1_i32, %c2_i32_2 : i32
195+
%c2_i32_3 = arith.constant 2 : i32
196+
%4 = arith.cmpi eq, %c1_i32, %c2_i32_3 : i32
197+
%c2_i32_4 = arith.constant 2 : i32
198+
%5 = arith.cmpi ne, %c1_i32, %c2_i32_4 : i32
199+
%c2_i32_5 = arith.constant 2 : i32
200+
%6 = arith.andi %c1_i32, %c2_i32_5 : i32
201+
%c2_i32_6 = arith.constant 2 : i32
202+
%7 = arith.ori %c1_i32, %c2_i32_6 : i32
203+
%cst = arith.constant 1.000000e+00 : f32
204+
%cst_7 = arith.constant 2.000000e+00 : f32
205+
%8 = arith.cmpf olt, %cst, %cst_7 : f32
206+
%cst_8 = arith.constant 2.000000e+00 : f32
207+
%9 = arith.cmpf ole, %cst, %cst_8 : f32
208+
%cst_9 = arith.constant 2.000000e+00 : f32
209+
%10 = arith.cmpf ogt, %cst, %cst_9 : f32
210+
%cst_10 = arith.constant 2.000000e+00 : f32
211+
%11 = arith.cmpf oge, %cst, %cst_10 : f32
212+
%cst_11 = arith.constant 2.000000e+00 : f32
213+
%12 = arith.cmpf oeq, %cst, %cst_11 : f32
214+
%cst_12 = arith.constant 2.000000e+00 : f32
215+
%13 = arith.cmpf one, %cst, %cst_12 : f32
216+
}
217+
"""
218+
),
219+
ctx.module,
220+
)
221+
222+
223+
def test_scalar_promotion(ctx: MLIRContext):
224+
one = arith.constant(1)
225+
one + 2
226+
one - 2
227+
one / 2
228+
one // 2
229+
one % 2
230+
231+
one = arith.constant(1.0)
232+
one + 2.0
233+
one - 2.0
234+
one / 2.0
235+
one % 2.0
236+
237+
ctx.module.operation.verify()
238+
# CHECK: %[[C1_I32:.*]] = arith.constant 1 : i32
239+
# CHECK: %[[VAL_0:.*]] = arith.constant 2 : i32
240+
# CHECK: %[[VAL_1:.*]] = arith.addi %[[C1_I32]], %[[VAL_0]] : i32
241+
# CHECK: %[[VAL_2:.*]] = arith.constant 2 : i32
242+
# CHECK: %[[VAL_3:.*]] = arith.subi %[[C1_I32]], %[[VAL_2]] : i32
243+
# CHECK: %[[VAL_4:.*]] = arith.constant 2 : i32
244+
# CHECK: %[[VAL_5:.*]] = arith.divsi %[[C1_I32]], %[[VAL_4]] : i32
245+
# CHECK: %[[VAL_6:.*]] = arith.constant 2 : i32
246+
# CHECK: %[[VAL_7:.*]] = arith.floordivsi %[[C1_I32]], %[[VAL_6]] : i32
247+
# CHECK: %[[VAL_8:.*]] = arith.constant 2 : i32
248+
# CHECK: %[[VAL_9:.*]] = arith.remsi %[[C1_I32]], %[[VAL_8]] : i32
249+
# CHECK: %[[VAL_10:.*]] = arith.constant 1.000000e+00 : f32
250+
# CHECK: %[[VAL_11:.*]] = arith.constant 2.000000e+00 : f32
251+
# CHECK: %[[VAL_12:.*]] = arith.addf %[[VAL_10]], %[[VAL_11]] : f32
252+
# CHECK: %[[VAL_13:.*]] = arith.constant 2.000000e+00 : f32
253+
# CHECK: %[[VAL_14:.*]] = arith.subf %[[VAL_10]], %[[VAL_13]] : f32
254+
# CHECK: %[[VAL_15:.*]] = arith.constant 2.000000e+00 : f32
255+
# CHECK: %[[VAL_16:.*]] = arith.divf %[[VAL_10]], %[[VAL_15]] : f32
256+
# CHECK: %[[VAL_17:.*]] = arith.constant 2.000000e+00 : f32
257+
# CHECK: %[[VAL_18:.*]] = arith.remf %[[VAL_10]], %[[VAL_17]] : f32
258+
259+
filecheck_with_comments(ctx.module)
260+
261+
262+
def test_rscalar_promotion(ctx: MLIRContext):
263+
one = arith.constant(1)
264+
2 + one
265+
2 - one
266+
2 / one
267+
2 // one
268+
2 % one
269+
270+
one = arith.constant(1.0)
271+
2.0 + one
272+
2.0 - one
273+
2.0 / one
274+
2.0 % one
275+
276+
ctx.module.operation.verify()
277+
correct = dedent(
278+
"""\
279+
module {
280+
%c1_i32 = arith.constant 1 : i32
281+
%c2_i32 = arith.constant 2 : i32
282+
%0 = arith.addi %c2_i32, %c1_i32 : i32
283+
%c2_i32_0 = arith.constant 2 : i32
284+
%1 = arith.subi %c2_i32_0, %c1_i32 : i32
285+
%c2_i32_1 = arith.constant 2 : i32
286+
%2 = arith.divsi %c2_i32_1, %c1_i32 : i32
287+
%c2_i32_2 = arith.constant 2 : i32
288+
%3 = arith.floordivsi %c2_i32_2, %c1_i32 : i32
289+
%c2_i32_3 = arith.constant 2 : i32
290+
%4 = arith.remsi %c2_i32_3, %c1_i32 : i32
291+
%cst = arith.constant 1.000000e+00 : f32
292+
%cst_4 = arith.constant 2.000000e+00 : f32
293+
%5 = arith.addf %cst_4, %cst : f32
294+
%cst_5 = arith.constant 2.000000e+00 : f32
295+
%6 = arith.subf %cst_5, %cst : f32
296+
%cst_6 = arith.constant 2.000000e+00 : f32
297+
%7 = arith.divf %cst_6, %cst : f32
298+
%cst_7 = arith.constant 2.000000e+00 : f32
299+
%8 = arith.remf %cst_7, %cst : f32
300+
}
301+
"""
302+
)
303+
filecheck(correct, ctx.module)
304+
305+
306+
def test_arith_rcmp_literals(ctx: MLIRContext):
307+
one = 1
308+
two = arith.constant(2)
309+
one < two
310+
one <= two
311+
one > two
312+
one >= two
313+
one == two
314+
one != two
315+
one & two
316+
one | two
317+
318+
one = 1.0
319+
two = arith.constant(2.0)
320+
one < two
321+
one <= two
322+
one > two
323+
one >= two
324+
one == two
325+
one != two
326+
327+
ctx.module.operation.verify()
328+
filecheck(
329+
dedent(
330+
"""\
331+
module {
332+
%c2_i32 = arith.constant 2 : i32
333+
%c1_i32 = arith.constant 1 : i32
334+
%0 = arith.cmpi sgt, %c2_i32, %c1_i32 : i32
335+
%c1_i32_0 = arith.constant 1 : i32
336+
%1 = arith.cmpi sge, %c2_i32, %c1_i32_0 : i32
337+
%c1_i32_1 = arith.constant 1 : i32
338+
%2 = arith.cmpi slt, %c2_i32, %c1_i32_1 : i32
339+
%c1_i32_2 = arith.constant 1 : i32
340+
%3 = arith.cmpi sle, %c2_i32, %c1_i32_2 : i32
341+
%c1_i32_3 = arith.constant 1 : i32
342+
%4 = arith.cmpi eq, %c2_i32, %c1_i32_3 : i32
343+
%c1_i32_4 = arith.constant 1 : i32
344+
%5 = arith.cmpi ne, %c2_i32, %c1_i32_4 : i32
345+
%c1_i32_5 = arith.constant 1 : i32
346+
%6 = arith.andi %c1_i32_5, %c2_i32 : i32
347+
%c1_i32_6 = arith.constant 1 : i32
348+
%7 = arith.ori %c1_i32_6, %c2_i32 : i32
349+
%cst = arith.constant 2.000000e+00 : f32
350+
%cst_7 = arith.constant 1.000000e+00 : f32
351+
%8 = arith.cmpf ogt, %cst, %cst_7 : f32
352+
%cst_8 = arith.constant 1.000000e+00 : f32
353+
%9 = arith.cmpf oge, %cst, %cst_8 : f32
354+
%cst_9 = arith.constant 1.000000e+00 : f32
355+
%10 = arith.cmpf olt, %cst, %cst_9 : f32
356+
%cst_10 = arith.constant 1.000000e+00 : f32
357+
%11 = arith.cmpf ole, %cst, %cst_10 : f32
358+
%cst_11 = arith.constant 1.000000e+00 : f32
359+
%12 = arith.cmpf oeq, %cst, %cst_11 : f32
360+
%cst_12 = arith.constant 1.000000e+00 : f32
361+
%13 = arith.cmpf one, %cst, %cst_12 : f32
362+
}
363+
"""
364+
),
365+
ctx.module,
366+
)

0 commit comments

Comments
 (0)