Skip to content

Commit 516d919

Browse files
authored
Merge pull request #23 from zk-rabbit/feat/quadratic-ext-field
feat: introduce quadratic extension field
2 parents 7c358dd + 2b9643d commit 516d919

File tree

18 files changed

+890
-447
lines changed

18 files changed

+890
-447
lines changed

benchmark/field/mul_benchmark.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44

55
func.func @mul(%arg0 : i256) -> i256 attributes { llvm.emit_c_interface } {
66
%0 = field.pf.encapsulate %arg0 : i256 -> !F
7-
%1 = field.pf.mul %0, %0 : !F
8-
%2 = field.pf.mul %0, %1 : !F
7+
%1 = field.mul %0, %0 : !F
8+
%2 = field.mul %0, %1 : !F
99
%3 = field.pf.extract %2 : !F -> i256
1010
return %3 : i256
1111
}
1212

1313
func.func @mont_mul(%arg0 : i256) -> i256 attributes { llvm.emit_c_interface } {
1414
%0 = field.pf.encapsulate %arg0 : i256 -> !F
15-
%1 = field.pf.mont_mul %0, %0 {montgomery = #mont} : !F
16-
%2 = field.pf.mont_mul %0, %1 {montgomery = #mont} : !F
15+
%1 = field.mont_mul %0, %0 {montgomery = #mont} : !F
16+
%2 = field.mont_mul %0, %1 {montgomery = #mont} : !F
1717
%3 = field.pf.extract %2 : !F -> i256
1818
return %3 : i256
1919
}

tests/Dialect/Field/prime_field_to_mod_arith.mlir

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: zkir-opt -prime-field-to-mod-arith --split-input-file %s | FileCheck %s --enable-var-scope
1+
// RUN: zkir-opt -field-to-mod-arith --split-input-file %s | FileCheck %s --enable-var-scope
22
!PF1 = !field.pf<3:i32>
33
!PFv = tensor<4x!PF1>
44
#root_elem = #field.pf_elem<2:i32> : !PF1
@@ -59,76 +59,76 @@ func.func @test_lower_extract_vec(%lhs : tensor<4x!PF1>) -> tensor<4xi32> {
5959
// CHECK-LABEL: @test_lower_to_mont
6060
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
6161
func.func @test_lower_to_mont(%lhs : !PF1) -> !PF1 {
62-
// CHECK-NOT: field.pf.to_mont
62+
// CHECK-NOT: field.to_mont
6363
// CHECK: %[[RES:.*]] = mod_arith.to_mont %[[LHS]] {montgomery = #mod_arith.montgomery<[[T]]>} : [[T]]
64-
%res = field.pf.to_mont %lhs {montgomery=#mont} : !PF1
64+
%res = field.to_mont %lhs {montgomery=#mont} : !PF1
6565
// CHECK: return %[[RES]] : [[T]]
6666
return %res : !PF1
6767
}
6868

6969
// CHECK-LABEL: @test_lower_to_mont_vec
7070
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
7171
func.func @test_lower_to_mont_vec(%lhs : !PFv) -> !PFv {
72-
// CHECK-NOT: field.pf.to_mont
72+
// CHECK-NOT: field.to_mont
7373
// CHECK: %[[RES:.*]] = mod_arith.to_mont %[[LHS]] {montgomery = #mod_arith.montgomery<[[E:.*]]>} : [[T]]
74-
%res = field.pf.to_mont %lhs {montgomery=#mont} : !PFv
74+
%res = field.to_mont %lhs {montgomery=#mont} : !PFv
7575
// CHECK: return %[[RES]] : [[T]]
7676
return %res : !PFv
7777
}
7878

7979
// CHECK-LABEL: @test_lower_from_mont
8080
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
8181
func.func @test_lower_from_mont(%lhs : !PF1) -> !PF1 {
82-
// CHECK-NOT: field.pf.from_mont
82+
// CHECK-NOT: field.from_mont
8383
// CHECK: %[[RES:.*]] = mod_arith.from_mont %[[LHS]] {montgomery = #mod_arith.montgomery<[[T]]>} : [[T]]
84-
%res = field.pf.from_mont %lhs {montgomery=#mont} : !PF1
84+
%res = field.from_mont %lhs {montgomery=#mont} : !PF1
8585
// CHECK: return %[[RES]] : [[T]]
8686
return %res : !PF1
8787
}
8888

8989
// CHECK-LABEL: @test_lower_from_mont_vec
9090
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
9191
func.func @test_lower_from_mont_vec(%lhs : !PFv) -> !PFv {
92-
// CHECK-NOT: field.pf.from_mont
92+
// CHECK-NOT: field.from_mont
9393
// CHECK: %[[RES:.*]] = mod_arith.from_mont %[[LHS]] {montgomery = #mod_arith.montgomery<[[E:.*]]>} : [[T]]
94-
%res = field.pf.from_mont %lhs {montgomery=#mont} : !PFv
94+
%res = field.from_mont %lhs {montgomery=#mont} : !PFv
9595
// CHECK: return %[[RES]] : [[T]]
9696
return %res : !PFv
9797
}
9898

9999
// CHECK-LABEL: @test_lower_inverse
100100
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
101101
func.func @test_lower_inverse(%lhs : !PF1) -> !PF1 {
102-
// CHECK-NOT: field.pf.inverse
102+
// CHECK-NOT: field.inverse
103103
// CHECK: %[[RES:.*]] = mod_arith.inverse %[[LHS]] : [[T]]
104-
%res = field.pf.inverse %lhs : !PF1
104+
%res = field.inverse %lhs : !PF1
105105
return %res : !PF1
106106
}
107107

108108
// CHECK-LABEL: @test_lower_inverse_vec
109109
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
110110
func.func @test_lower_inverse_vec(%lhs : !PFv) -> !PFv {
111-
// CHECK-NOT: field.pf.inverse
111+
// CHECK-NOT: field.inverse
112112
// CHECK: %[[RES:.*]] = mod_arith.inverse %[[LHS]] : [[T]]
113-
%res = field.pf.inverse %lhs : !PFv
113+
%res = field.inverse %lhs : !PFv
114114
return %res : !PFv
115115
}
116116

117117
// CHECK-LABEL: @test_lower_negate
118118
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
119119
func.func @test_lower_negate(%lhs : !PF1) -> !PF1 {
120-
// CHECK-NOT: field.pf.negate
120+
// CHECK-NOT: field.negate
121121
// CHECK: %[[RES:.*]] = mod_arith.negate %[[LHS]] : [[T]]
122-
%res = field.pf.negate %lhs : !PF1
122+
%res = field.negate %lhs : !PF1
123123
return %res : !PF1
124124
}
125125

126126
// CHECK-LABEL: @test_lower_negate_vec
127127
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
128128
func.func @test_lower_negate_vec(%lhs : !PF1) -> !PF1 {
129-
// CHECK-NOT: field.pf.negate
129+
// CHECK-NOT: field.negate
130130
// CHECK: %[[RES:.*]] = mod_arith.negate %[[LHS]] : [[T]]
131-
%res = field.pf.negate %lhs : !PF1
131+
%res = field.negate %lhs : !PF1
132132
return %res : !PF1
133133
}
134134

@@ -138,17 +138,17 @@ func.func @test_lower_add() -> !PF1 {
138138
// CHECK: %[[C0:.*]] = mod_arith.constant 4 : [[T]]
139139
%c0 = field.pf.constant 4 : !PF1
140140
// CHECK: %[[RES:.*]] = mod_arith.add %[[C0]], %[[C0]] : [[T]]
141-
%res = field.pf.add %c0, %c0 : !PF1
141+
%res = field.add %c0, %c0 : !PF1
142142
// CHECK: return %[[RES]] : [[T]]
143143
return %res : !PF1
144144
}
145145

146146
// CHECK-LABEL: @test_lower_add_vec
147147
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] {
148148
func.func @test_lower_add_vec(%lhs : !PFv, %rhs : !PFv) -> !PFv {
149-
// CHECK-NOT: field.pf.add
149+
// CHECK-NOT: field.add
150150
// CHECK: %[[RES:.*]] = mod_arith.add %[[LHS]], %[[RHS]] : tensor<4x!z3_i32_>
151-
%res = field.pf.add %lhs, %rhs : !PFv
151+
%res = field.add %lhs, %rhs : !PFv
152152
// CHECK: return %[[RES]] : [[T]]
153153
return %res : !PFv
154154
}
@@ -159,17 +159,17 @@ func.func @test_lower_double() -> !PF1 {
159159
// CHECK: %[[C0:.*]] = mod_arith.constant 4 : [[T]]
160160
%c0 = field.pf.constant 4 : !PF1
161161
// CHECK: %[[RES:.*]] = mod_arith.add %[[C0]], %[[C0]] : [[T]]
162-
%res = field.pf.double %c0 : !PF1
162+
%res = field.double %c0 : !PF1
163163
// CHECK: return %[[RES]] : [[T]]
164164
return %res : !PF1
165165
}
166166

167167
// CHECK-LABEL: @test_lower_double_vec
168168
// CHECK-SAME: (%[[VAL:.*]]: [[T:.*]]) -> [[T]] {
169169
func.func @test_lower_double_vec(%val : !PFv) -> !PFv {
170-
// CHECK-NOT: field.pf.double
170+
// CHECK-NOT: field.double
171171
// CHECK: %[[RES:.*]] = mod_arith.add %[[VAL]], %[[VAL]] : [[T]]
172-
%res = field.pf.double %val : !PFv
172+
%res = field.double %val : !PFv
173173
// CHECK: return %[[RES]] : [[T]]
174174
return %res : !PFv
175175
}
@@ -182,17 +182,17 @@ func.func @test_lower_sub() -> !PF1 {
182182
// CHECK: %[[C1:.*]] = mod_arith.constant 5 : [[T]]
183183
%c1 = field.pf.constant 5 : !PF1
184184
// CHECK: %[[RES:.*]] = mod_arith.sub %[[C0]], %[[C1]] : [[T]]
185-
%res = field.pf.sub %c0, %c1 : !PF1
185+
%res = field.sub %c0, %c1 : !PF1
186186
// CHECK: return %[[RES]] : [[T]]
187187
return %res : !PF1
188188
}
189189

190190
// CHECK-LABEL: @test_lower_sub_vec
191191
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] {
192192
func.func @test_lower_sub_vec(%lhs : !PFv, %rhs : !PFv) -> !PFv {
193-
// CHECK-NOT: field.pf.sub
193+
// CHECK-NOT: field.sub
194194
// CHECK: %[[RES:.*]] = mod_arith.sub %[[LHS]], %[[RHS]] : [[T]]
195-
%res = field.pf.sub %lhs, %rhs : !PFv
195+
%res = field.sub %lhs, %rhs : !PFv
196196
// CHECK: return %[[RES]] : [[T]]
197197
return %res : !PFv
198198
}
@@ -203,17 +203,17 @@ func.func @test_lower_mul() -> !PF1 {
203203
// CHECK: %[[C0:.*]] = mod_arith.constant 4 : [[T]]
204204
%c0 = field.pf.constant 4 : !PF1
205205
// CHECK: %[[RES:.*]] = mod_arith.mul %[[C0]], %[[C0]] : [[T]]
206-
%res = field.pf.mul %c0, %c0 : !PF1
206+
%res = field.mul %c0, %c0 : !PF1
207207
// CHECK: return %[[RES]] : [[T]]
208208
return %res : !PF1
209209
}
210210

211211
// CHECK-LABEL: @test_lower_mul_vec
212212
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] {
213213
func.func @test_lower_mul_vec(%lhs : !PFv, %rhs : !PFv) -> !PFv {
214-
// CHECK-NOT: field.pf.mul
214+
// CHECK-NOT: field.mul
215215
// CHECK: %[[RES:.*]] = mod_arith.mul %[[LHS]], %[[RHS]] : [[T]]
216-
%res = field.pf.mul %lhs, %rhs : !PFv
216+
%res = field.mul %lhs, %rhs : !PFv
217217
// CHECK: return %[[RES]] : [[T]]
218218
return %res : !PFv
219219
}
@@ -224,17 +224,17 @@ func.func @test_lower_square() -> !PF1 {
224224
// CHECK: %[[C0:.*]] = mod_arith.constant 4 : [[T]]
225225
%c0 = field.pf.constant 4 : !PF1
226226
// CHECK: %[[RES:.*]] = mod_arith.mul %[[C0]], %[[C0]] : [[T]]
227-
%res = field.pf.square %c0 : !PF1
227+
%res = field.square %c0 : !PF1
228228
// CHECK: return %[[RES]] : [[T]]
229229
return %res : !PF1
230230
}
231231

232232
// CHECK-LABEL: @test_lower_square_vec
233233
// CHECK-SAME: (%[[VAL:.*]]: [[T:.*]]) -> [[T]] {
234234
func.func @test_lower_square_vec(%val : !PFv) -> !PFv {
235-
// CHECK-NOT: field.pf.square
235+
// CHECK-NOT: field.square
236236
// CHECK: %[[RES:.*]] = mod_arith.mul %[[VAL]], %[[VAL]] : [[T]]
237-
%res = field.pf.square %val : !PFv
237+
%res = field.square %val : !PFv
238238
// CHECK: return %[[RES]] : [[T]]
239239
return %res : !PFv
240240
}
@@ -245,17 +245,17 @@ func.func @test_lower_mont_mul() -> !PF1 {
245245
// CHECK: %[[C0:.*]] = mod_arith.constant 2 : [[T]]
246246
%c0 = field.pf.constant 2 : !PF1
247247
// CHECK: %[[RES:.*]] = mod_arith.mont_mul %[[C0]], %[[C0]] {montgomery = #mod_arith.montgomery<[[T]]>} : [[T]]
248-
%res = field.pf.mont_mul %c0, %c0 {montgomery = #mont} : !PF1
248+
%res = field.mont_mul %c0, %c0 {montgomery = #mont} : !PF1
249249
// CHECK: return %[[RES]] : [[T]]
250250
return %res : !PF1
251251
}
252252

253253
// CHECK-LABEL: @test_lower_mont_mul_vec
254254
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] {
255255
func.func @test_lower_mont_mul_vec(%lhs : !PFv, %rhs : !PFv) -> !PFv {
256-
// CHECK-NOT: field.pf.mont_mul
256+
// CHECK-NOT: field.mont_mul
257257
// CHECK: %[[RES:.*]] = mod_arith.mont_mul %[[LHS]], %[[RHS]] {montgomery = #mod_arith.montgomery<[[E:.*]]>} : [[T]]
258-
%res = field.pf.mont_mul %lhs, %rhs {montgomery = #mont} : !PFv
258+
%res = field.mont_mul %lhs, %rhs {montgomery = #mont} : !PFv
259259
// CHECK: return %[[RES]] : [[T]]
260260
return %res : !PFv
261261
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// RUN: zkir-opt -field-to-mod-arith --split-input-file %s | FileCheck %s --enable-var-scope
2+
!mod = !mod_arith.int<7:i32>
3+
#mont = #mod_arith.montgomery<!mod>
4+
!PF = !field.pf<7:i32>
5+
#beta = #field.pf_elem<6:i32> : !PF
6+
!QF = !field.f2<!PF, #beta>
7+
8+
// CHECK-LABEL: @test_lower_inverse
9+
// CHECK-SAME: (%[[ARG0:.*]]: [[T:.*]], %[[ARG1:.*]]: [[T]]) -> ([[T]], [[T]]) {
10+
func.func @test_lower_inverse(%arg0: !QF) -> !QF {
11+
// CHECK-NOT: field.inverse
12+
%0 = field.inverse %arg0 : !QF
13+
return %0 : !QF
14+
}
15+
16+
// CHECK-LABEL: @test_lower_double
17+
// CHECK-SAME: (%[[ARG0:.*]]: [[T:.*]], %[[ARG1:.*]]: [[T]]) -> ([[T]], [[T]]) {
18+
func.func @test_lower_double(%arg0: !QF) -> !QF {
19+
// CHECK-NOT: field.double
20+
%0 = field.double %arg0 : !QF
21+
return %0 : !QF
22+
}
23+
24+
// CHECK-LABEL: @test_lower_square
25+
// CHECK-SAME: (%[[ARG0:.*]]: [[T:.*]], %[[ARG1:.*]]: [[T]]) -> ([[T]], [[T]]) {
26+
func.func @test_lower_square(%arg0: !QF) -> !QF {
27+
// CHECK-NOT: field.square
28+
%0 = field.square %arg0 : !QF
29+
return %0 : !QF
30+
}
31+
32+
// CHECK-LABEL: @test_lower_add
33+
// CHECK-SAME: (%[[ARG0:.*]]: [[T:.*]], %[[ARG1:.*]]: [[T]], %[[ARG2:.*]]: [[T]], %[[ARG3:.*]]: [[T]]) -> ([[T]], [[T]]) {
34+
func.func @test_lower_add(%arg0: !QF, %arg1: !QF) -> !QF {
35+
// CHECK: %[[C0:.*]] = mod_arith.add %[[ARG0]], %[[ARG2]] : [[T]]
36+
// CHECK: %[[C1:.*]] = mod_arith.add %[[ARG1]], %[[ARG3]] : [[T]]
37+
// CHECK: return %[[C0]], %[[C1]] : [[T]], [[T]]
38+
%0 = field.add %arg0, %arg1 : !QF
39+
return %0 : !QF
40+
}
41+
42+
// CHECK-LABEL: @test_lower_mul
43+
// CHECK-SAME: (%[[ARG0:.*]]: [[T:.*]], %[[ARG1:.*]]: [[T]], %[[ARG2:.*]]: [[T]], %[[ARG3:.*]]: [[T]]) -> ([[T]], [[T]]) {
44+
func.func @test_lower_mul(%arg0: !QF, %arg1: !QF) -> !QF {
45+
// CHECK: %[[BETA:.*]] = mod_arith.constant 6 : [[T]]
46+
// CHECK: %[[V0:.*]] = mod_arith.mul %[[ARG0]], %[[ARG2]] : [[T]]
47+
// CHECK: %[[V1:.*]] = mod_arith.mul %[[ARG1]], %[[ARG3]] : [[T]]
48+
// CHECK: %[[BETATIMESV1:.*]] = mod_arith.mul %[[BETA]], %[[V1]] : [[T]]
49+
// CHECK: %[[C0:.*]] = mod_arith.add %[[V0]], %[[BETATIMESV1]] : [[T]]
50+
// CHECK: %[[SUMLHS:.*]] = mod_arith.add %[[ARG0]], %[[ARG1]] : [[T]]
51+
// CHECK: %[[SUMRHS:.*]] = mod_arith.add %[[ARG2]], %[[ARG3]] : [[T]]
52+
// CHECK: %[[SUMPRODUCT:.*]] = mod_arith.mul %[[SUMLHS]], %[[SUMRHS]] : [[T]]
53+
// CHECK: %[[TMP:.*]] = mod_arith.sub %[[SUMPRODUCT]], %[[V0]] : [[T]]
54+
// CHECK: %[[C1:.*]] = mod_arith.sub %[[TMP]], %[[V1]] : [[T]]
55+
// CHECK: return %[[C0]], %[[C1]] : [[T]], [[T]]
56+
%0 = field.mul %arg0, %arg1 : !QF
57+
return %0 : !QF
58+
}
59+
60+
// CHECK-LABEL: @test_lower_from_elements
61+
// CHECK-SAME: (%[[ARG0:.*]]: [[T:.*]], %[[ARG1:.*]]: [[T]]) -> tensor<2x2x[[T]]> {
62+
func.func @test_lower_from_elements(%arg0: !PF, %arg1: !PF) -> tensor<2x!QF> {
63+
%0 = field.f2.constant %arg0, %arg1 : !QF
64+
%1 = field.f2.constant %arg1, %arg0 : !QF
65+
tensor.from_elements %arg0, %arg1 : tensor<2x!PF>
66+
// CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[ARG0]] : tensor<2x2x[[T]]>
67+
%2 = tensor.from_elements %0, %1 : tensor<2x!QF>
68+
// CHECK: return %[[TENSOR]] : tensor<2x2x[[T]]>
69+
return %2 : tensor<2x!QF>
70+
}
71+
72+
// CHECK-LABEL: @test_lower_from_mont
73+
// CHECK-SAME: (%[[ARG0:.*]]: [[T:.*]], %[[ARG1:.*]]: [[T]]) -> ([[T]], [[T]]) {
74+
func.func @test_lower_from_mont(%arg0: !QF) -> !QF {
75+
%0 = field.from_mont %arg0 {montgomery=#mont} : !QF
76+
return %0 : !QF
77+
}
78+
79+
// CHECK-LABEL: @test_lower_to_mont
80+
// CHECK-SAME: (%[[ARG0:.*]]: [[T:.*]], %[[ARG1:.*]]: [[T]]) -> ([[T]], [[T]]) {
81+
func.func @test_lower_to_mont(%arg0: !QF) -> !QF {
82+
%0 = field.to_mont %arg0 {montgomery=#mont} : !QF
83+
return %0 : !QF
84+
}
85+
86+
// CHECK-LABEL: @test_lower_mont_mul
87+
// CHECK-SAME: (%[[ARG0:.*]]: [[T:.*]], %[[ARG1:.*]]: [[T]], %[[ARG2:.*]]: [[T]], %[[ARG3:.*]]: [[T]]) -> ([[T]], [[T]]) {
88+
func.func @test_lower_mont_mul(%arg0: !QF, %arg1: !QF) -> !QF {
89+
%0 = field.mont_mul %arg0, %arg1 {montgomery=#mont} : !QF
90+
return %0 : !QF
91+
}
92+
93+
// CHECK-LABEL: @test_lower_tensor_extract
94+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2x2x[[T:.*]]>) -> ([[T]], [[T]]) {
95+
func.func @test_lower_tensor_extract(%arg0: tensor<3x2x!QF>) -> !QF {
96+
// CHECK: %[[I1:.*]] = arith.constant 1 : index
97+
%i1 = arith.constant 1 : index
98+
99+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
100+
// CHECK: %[[VALUE0:.*]] = tensor.extract %[[ARG0]][%[[I1]], %[[I1]], %[[C0]]] : tensor<3x2x2x[[T]]>
101+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
102+
// CHECK: %[[VALUE1:.*]] = tensor.extract %[[ARG0]][%[[I1]], %[[I1]], %[[C1]]] : tensor<3x2x2x[[T]]>
103+
%1 = tensor.extract %arg0[%i1, %i1] : tensor<3x2x!QF>
104+
// CHECK: return %[[VALUE0]], %[[VALUE1]] : [[T]], [[T]]
105+
return %1 : !QF
106+
}
107+
108+
// CHECK-LABEL: @test_lower_memref
109+
// CHECK-SAME: (%[[ARG0:.*]]: memref<3x2x2x[[T:.*]]>) -> ([[T]], [[T]]) {
110+
func.func @test_lower_memref(%arg0: memref<3x2x!QF>) -> !QF {
111+
%t = bufferization.to_tensor %arg0 : memref<3x2x!QF> to tensor<3x2x!QF>
112+
%i1 = arith.constant 1 : index
113+
%1 = tensor.extract %t[%i1, %i1] : tensor<3x2x!QF>
114+
return %1 : !QF
115+
}

tests/Dialect/ModArith/mod_arith_to_arith.mlir

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -71,32 +71,6 @@ func.func @test_lower_extract_vec(%lhs : !Zpv) -> tensor<4xi32> {
7171
return %res : tensor<4xi32>
7272
}
7373

74-
// CHECK-LABEL: @test_lower_reduce
75-
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
76-
func.func @test_lower_reduce(%lhs : !Zp) -> !Zp {
77-
// CHECK-NOT: mod_arith.reduce
78-
// CHECK: %[[CMOD:.*]] = arith.constant 65537 : [[T]]
79-
// CHECK: %[[REMS:.*]] = arith.remsi %[[LHS]], %[[CMOD]] : [[T]]
80-
// CHECK: %[[ADD:.*]] = arith.addi %[[REMS]], %[[CMOD]] : [[T]]
81-
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[T]]
82-
// CHECK: return %[[REM]] : [[T]]
83-
%res = mod_arith.reduce %lhs: !Zp
84-
return %res : !Zp
85-
}
86-
87-
// CHECK-LABEL: @test_lower_reduce_vec
88-
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
89-
func.func @test_lower_reduce_vec(%lhs : !Zpv) -> !Zpv {
90-
// CHECK-NOT: mod_arith.reduce
91-
// CHECK: %[[CMOD:.*]] = arith.constant dense<65537> : [[T]]
92-
// CHECK: %[[REMS:.*]] = arith.remsi %[[LHS]], %[[CMOD]] : [[T]]
93-
// CHECK: %[[ADD:.*]] = arith.addi %[[REMS]], %[[CMOD]] : [[T]]
94-
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[T]]
95-
// CHECK: return %[[REM]] : [[T]]
96-
%res = mod_arith.reduce %lhs: !Zpv
97-
return %res : !Zpv
98-
}
99-
10074
// CHECK-LABEL: @test_lower_inverse
10175
// CHECK-SAME: (%[[INPUT:.*]]: [[T:.*]]) -> [[T]] {
10276
func.func @test_lower_inverse(%lhs : !Zp) -> !Zp {

0 commit comments

Comments
 (0)