1
+ from textwrap import dedent
2
+
1
3
import mlir .extras .types as T
2
4
import pytest
3
5
4
6
from mlir .extras .ast .canonicalize import canonicalize
5
7
from mlir .extras .dialects .ext import arith
8
+ from mlir .extras .dialects .ext .arith import Scalar
6
9
from mlir .extras .dialects .ext .func import func
7
10
8
11
# noinspection PyUnresolvedReferences
@@ -27,3 +30,337 @@ def foo():
27
30
row_l : T .f32 () = 0.0
28
31
29
32
filecheck_with_comments (ctx .module )
33
+
34
+
35
+ def test_arithmetic (ctx : MLIRContext ):
36
+ one = arith .constant (1 )
37
+ two = arith .constant (2 )
38
+ one + two
39
+ one - two
40
+ one / two
41
+ one // two
42
+ one % two
43
+
44
+ one = arith .constant (1.0 )
45
+ two = arith .constant (2.0 )
46
+ one + two
47
+ one - two
48
+ one / two
49
+ try :
50
+ one // two
51
+ except ValueError as e :
52
+ assert (
53
+ str (e )
54
+ == "floordiv not supported for lhs=Scalar(%cst = arith.constant 1.000000e+00 : f32)"
55
+ )
56
+ one % two
57
+
58
+ ctx .module .operation .verify ()
59
+ filecheck (
60
+ dedent (
61
+ """\
62
+ module {
63
+ %c1_i32 = arith.constant 1 : i32
64
+ %c2_i32 = arith.constant 2 : i32
65
+ %0 = arith.addi %c1_i32, %c2_i32 : i32
66
+ %1 = arith.subi %c1_i32, %c2_i32 : i32
67
+ %2 = arith.divsi %c1_i32, %c2_i32 : i32
68
+ %3 = arith.floordivsi %c1_i32, %c2_i32 : i32
69
+ %4 = arith.remsi %c1_i32, %c2_i32 : i32
70
+ %cst = arith.constant 1.000000e+00 : f32
71
+ %cst_0 = arith.constant 2.000000e+00 : f32
72
+ %5 = arith.addf %cst, %cst_0 : f32
73
+ %6 = arith.subf %cst, %cst_0 : f32
74
+ %7 = arith.divf %cst, %cst_0 : f32
75
+ %8 = arith.remf %cst, %cst_0 : f32
76
+ }
77
+ """
78
+ ),
79
+ ctx .module ,
80
+ )
81
+
82
+
83
+ def test_r_arithmetic (ctx : MLIRContext ):
84
+ one = arith .constant (1 )
85
+ two = arith .constant (2 )
86
+ one - two
87
+ two - one
88
+
89
+ ctx .module .operation .verify ()
90
+ filecheck (
91
+ dedent (
92
+ """\
93
+ module {
94
+ %c1_i32 = arith.constant 1 : i32
95
+ %c2_i32 = arith.constant 2 : i32
96
+ %0 = arith.subi %c1_i32, %c2_i32 : i32
97
+ %1 = arith.subi %c2_i32, %c1_i32 : i32
98
+ }
99
+ """
100
+ ),
101
+ ctx .module ,
102
+ )
103
+
104
+
105
+ def test_arith_cmp (ctx : MLIRContext ):
106
+ one = arith .constant (1 )
107
+ two = arith .constant (2 )
108
+ one < two
109
+ one <= two
110
+ one > two
111
+ one >= two
112
+ one == two
113
+ one != two
114
+ one & two
115
+ one | two
116
+ assert one ._ne (two )
117
+ assert not one ._eq (two )
118
+
119
+ one = arith .constant (1.0 )
120
+ two = arith .constant (2.0 )
121
+ one < two
122
+ one <= two
123
+ one > two
124
+ one >= two
125
+ one == two
126
+ one != two
127
+ assert one ._ne (two )
128
+ assert not one ._eq (two )
129
+
130
+ ctx .module .operation .verify ()
131
+ filecheck (
132
+ dedent (
133
+ """\
134
+ module {
135
+ %c1_i32 = arith.constant 1 : i32
136
+ %c2_i32 = arith.constant 2 : i32
137
+ %0 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
138
+ %1 = arith.cmpi sle, %c1_i32, %c2_i32 : i32
139
+ %2 = arith.cmpi sgt, %c1_i32, %c2_i32 : i32
140
+ %3 = arith.cmpi sge, %c1_i32, %c2_i32 : i32
141
+ %4 = arith.cmpi eq, %c1_i32, %c2_i32 : i32
142
+ %5 = arith.cmpi ne, %c1_i32, %c2_i32 : i32
143
+ %6 = arith.andi %c1_i32, %c2_i32 : i32
144
+ %7 = arith.ori %c1_i32, %c2_i32 : i32
145
+ %cst = arith.constant 1.000000e+00 : f32
146
+ %cst_0 = arith.constant 2.000000e+00 : f32
147
+ %8 = arith.cmpf olt, %cst, %cst_0 : f32
148
+ %9 = arith.cmpf ole, %cst, %cst_0 : f32
149
+ %10 = arith.cmpf ogt, %cst, %cst_0 : f32
150
+ %11 = arith.cmpf oge, %cst, %cst_0 : f32
151
+ %12 = arith.cmpf oeq, %cst, %cst_0 : f32
152
+ %13 = arith.cmpf one, %cst, %cst_0 : f32
153
+ }
154
+ """
155
+ ),
156
+ ctx .module ,
157
+ )
158
+
159
+
160
+ def test_arith_cmp_literals (ctx : MLIRContext ):
161
+ one = arith .constant (1 )
162
+ two = 2
163
+ one < two
164
+ one <= two
165
+ one > two
166
+ one >= two
167
+ one == two
168
+ one != two
169
+ one & two
170
+ one | two
171
+
172
+ one = arith .constant (1.0 )
173
+ two = 2.0
174
+ one < two
175
+ one <= two
176
+ one > two
177
+ one >= two
178
+ one == two
179
+ one != two
180
+
181
+ ctx .module .operation .verify ()
182
+ filecheck (
183
+ dedent (
184
+ """\
185
+ module {
186
+ %c1_i32 = arith.constant 1 : i32
187
+ %c2_i32 = arith.constant 2 : i32
188
+ %0 = arith.cmpi slt, %c1_i32, %c2_i32 : i32
189
+ %c2_i32_0 = arith.constant 2 : i32
190
+ %1 = arith.cmpi sle, %c1_i32, %c2_i32_0 : i32
191
+ %c2_i32_1 = arith.constant 2 : i32
192
+ %2 = arith.cmpi sgt, %c1_i32, %c2_i32_1 : i32
193
+ %c2_i32_2 = arith.constant 2 : i32
194
+ %3 = arith.cmpi sge, %c1_i32, %c2_i32_2 : i32
195
+ %c2_i32_3 = arith.constant 2 : i32
196
+ %4 = arith.cmpi eq, %c1_i32, %c2_i32_3 : i32
197
+ %c2_i32_4 = arith.constant 2 : i32
198
+ %5 = arith.cmpi ne, %c1_i32, %c2_i32_4 : i32
199
+ %c2_i32_5 = arith.constant 2 : i32
200
+ %6 = arith.andi %c1_i32, %c2_i32_5 : i32
201
+ %c2_i32_6 = arith.constant 2 : i32
202
+ %7 = arith.ori %c1_i32, %c2_i32_6 : i32
203
+ %cst = arith.constant 1.000000e+00 : f32
204
+ %cst_7 = arith.constant 2.000000e+00 : f32
205
+ %8 = arith.cmpf olt, %cst, %cst_7 : f32
206
+ %cst_8 = arith.constant 2.000000e+00 : f32
207
+ %9 = arith.cmpf ole, %cst, %cst_8 : f32
208
+ %cst_9 = arith.constant 2.000000e+00 : f32
209
+ %10 = arith.cmpf ogt, %cst, %cst_9 : f32
210
+ %cst_10 = arith.constant 2.000000e+00 : f32
211
+ %11 = arith.cmpf oge, %cst, %cst_10 : f32
212
+ %cst_11 = arith.constant 2.000000e+00 : f32
213
+ %12 = arith.cmpf oeq, %cst, %cst_11 : f32
214
+ %cst_12 = arith.constant 2.000000e+00 : f32
215
+ %13 = arith.cmpf one, %cst, %cst_12 : f32
216
+ }
217
+ """
218
+ ),
219
+ ctx .module ,
220
+ )
221
+
222
+
223
+ def test_scalar_promotion (ctx : MLIRContext ):
224
+ one = arith .constant (1 )
225
+ one + 2
226
+ one - 2
227
+ one / 2
228
+ one // 2
229
+ one % 2
230
+
231
+ one = arith .constant (1.0 )
232
+ one + 2.0
233
+ one - 2.0
234
+ one / 2.0
235
+ one % 2.0
236
+
237
+ ctx .module .operation .verify ()
238
+ # CHECK: %[[C1_I32:.*]] = arith.constant 1 : i32
239
+ # CHECK: %[[VAL_0:.*]] = arith.constant 2 : i32
240
+ # CHECK: %[[VAL_1:.*]] = arith.addi %[[C1_I32]], %[[VAL_0]] : i32
241
+ # CHECK: %[[VAL_2:.*]] = arith.constant 2 : i32
242
+ # CHECK: %[[VAL_3:.*]] = arith.subi %[[C1_I32]], %[[VAL_2]] : i32
243
+ # CHECK: %[[VAL_4:.*]] = arith.constant 2 : i32
244
+ # CHECK: %[[VAL_5:.*]] = arith.divsi %[[C1_I32]], %[[VAL_4]] : i32
245
+ # CHECK: %[[VAL_6:.*]] = arith.constant 2 : i32
246
+ # CHECK: %[[VAL_7:.*]] = arith.floordivsi %[[C1_I32]], %[[VAL_6]] : i32
247
+ # CHECK: %[[VAL_8:.*]] = arith.constant 2 : i32
248
+ # CHECK: %[[VAL_9:.*]] = arith.remsi %[[C1_I32]], %[[VAL_8]] : i32
249
+ # CHECK: %[[VAL_10:.*]] = arith.constant 1.000000e+00 : f32
250
+ # CHECK: %[[VAL_11:.*]] = arith.constant 2.000000e+00 : f32
251
+ # CHECK: %[[VAL_12:.*]] = arith.addf %[[VAL_10]], %[[VAL_11]] : f32
252
+ # CHECK: %[[VAL_13:.*]] = arith.constant 2.000000e+00 : f32
253
+ # CHECK: %[[VAL_14:.*]] = arith.subf %[[VAL_10]], %[[VAL_13]] : f32
254
+ # CHECK: %[[VAL_15:.*]] = arith.constant 2.000000e+00 : f32
255
+ # CHECK: %[[VAL_16:.*]] = arith.divf %[[VAL_10]], %[[VAL_15]] : f32
256
+ # CHECK: %[[VAL_17:.*]] = arith.constant 2.000000e+00 : f32
257
+ # CHECK: %[[VAL_18:.*]] = arith.remf %[[VAL_10]], %[[VAL_17]] : f32
258
+
259
+ filecheck_with_comments (ctx .module )
260
+
261
+
262
+ def test_rscalar_promotion (ctx : MLIRContext ):
263
+ one = arith .constant (1 )
264
+ 2 + one
265
+ 2 - one
266
+ 2 / one
267
+ 2 // one
268
+ 2 % one
269
+
270
+ one = arith .constant (1.0 )
271
+ 2.0 + one
272
+ 2.0 - one
273
+ 2.0 / one
274
+ 2.0 % one
275
+
276
+ ctx .module .operation .verify ()
277
+ correct = dedent (
278
+ """\
279
+ module {
280
+ %c1_i32 = arith.constant 1 : i32
281
+ %c2_i32 = arith.constant 2 : i32
282
+ %0 = arith.addi %c2_i32, %c1_i32 : i32
283
+ %c2_i32_0 = arith.constant 2 : i32
284
+ %1 = arith.subi %c2_i32_0, %c1_i32 : i32
285
+ %c2_i32_1 = arith.constant 2 : i32
286
+ %2 = arith.divsi %c2_i32_1, %c1_i32 : i32
287
+ %c2_i32_2 = arith.constant 2 : i32
288
+ %3 = arith.floordivsi %c2_i32_2, %c1_i32 : i32
289
+ %c2_i32_3 = arith.constant 2 : i32
290
+ %4 = arith.remsi %c2_i32_3, %c1_i32 : i32
291
+ %cst = arith.constant 1.000000e+00 : f32
292
+ %cst_4 = arith.constant 2.000000e+00 : f32
293
+ %5 = arith.addf %cst_4, %cst : f32
294
+ %cst_5 = arith.constant 2.000000e+00 : f32
295
+ %6 = arith.subf %cst_5, %cst : f32
296
+ %cst_6 = arith.constant 2.000000e+00 : f32
297
+ %7 = arith.divf %cst_6, %cst : f32
298
+ %cst_7 = arith.constant 2.000000e+00 : f32
299
+ %8 = arith.remf %cst_7, %cst : f32
300
+ }
301
+ """
302
+ )
303
+ filecheck (correct , ctx .module )
304
+
305
+
306
+ def test_arith_rcmp_literals (ctx : MLIRContext ):
307
+ one = 1
308
+ two = arith .constant (2 )
309
+ one < two
310
+ one <= two
311
+ one > two
312
+ one >= two
313
+ one == two
314
+ one != two
315
+ one & two
316
+ one | two
317
+
318
+ one = 1.0
319
+ two = arith .constant (2.0 )
320
+ one < two
321
+ one <= two
322
+ one > two
323
+ one >= two
324
+ one == two
325
+ one != two
326
+
327
+ ctx .module .operation .verify ()
328
+ filecheck (
329
+ dedent (
330
+ """\
331
+ module {
332
+ %c2_i32 = arith.constant 2 : i32
333
+ %c1_i32 = arith.constant 1 : i32
334
+ %0 = arith.cmpi sgt, %c2_i32, %c1_i32 : i32
335
+ %c1_i32_0 = arith.constant 1 : i32
336
+ %1 = arith.cmpi sge, %c2_i32, %c1_i32_0 : i32
337
+ %c1_i32_1 = arith.constant 1 : i32
338
+ %2 = arith.cmpi slt, %c2_i32, %c1_i32_1 : i32
339
+ %c1_i32_2 = arith.constant 1 : i32
340
+ %3 = arith.cmpi sle, %c2_i32, %c1_i32_2 : i32
341
+ %c1_i32_3 = arith.constant 1 : i32
342
+ %4 = arith.cmpi eq, %c2_i32, %c1_i32_3 : i32
343
+ %c1_i32_4 = arith.constant 1 : i32
344
+ %5 = arith.cmpi ne, %c2_i32, %c1_i32_4 : i32
345
+ %c1_i32_5 = arith.constant 1 : i32
346
+ %6 = arith.andi %c1_i32_5, %c2_i32 : i32
347
+ %c1_i32_6 = arith.constant 1 : i32
348
+ %7 = arith.ori %c1_i32_6, %c2_i32 : i32
349
+ %cst = arith.constant 2.000000e+00 : f32
350
+ %cst_7 = arith.constant 1.000000e+00 : f32
351
+ %8 = arith.cmpf ogt, %cst, %cst_7 : f32
352
+ %cst_8 = arith.constant 1.000000e+00 : f32
353
+ %9 = arith.cmpf oge, %cst, %cst_8 : f32
354
+ %cst_9 = arith.constant 1.000000e+00 : f32
355
+ %10 = arith.cmpf olt, %cst, %cst_9 : f32
356
+ %cst_10 = arith.constant 1.000000e+00 : f32
357
+ %11 = arith.cmpf ole, %cst, %cst_10 : f32
358
+ %cst_11 = arith.constant 1.000000e+00 : f32
359
+ %12 = arith.cmpf oeq, %cst, %cst_11 : f32
360
+ %cst_12 = arith.constant 1.000000e+00 : f32
361
+ %13 = arith.cmpf one, %cst, %cst_12 : f32
362
+ }
363
+ """
364
+ ),
365
+ ctx .module ,
366
+ )
0 commit comments