Skip to content

Commit 3d947e3

Browse files
authored
refac(elliptic_curve): refactor elliptic curve dialect lowering (#25)
* refac(elliptic_curve): fix type conversions * refac(elliptic_curve): fix `PointOp` lowering * refac(elliptic_curve): fix `ConvertPointTypeOp` lowering * refac(elliptic_curve): fix `AddOp` lowering * refac(elliptic_curve): fix `DoubleOp` lowering * refac(elliptic_curve): fix `NegateOp` lowering * refac(elliptic_curve): fix `SubOp` lowering * refac(elliptic_curve): fix `ScalarMulOp` lowering * refac(elliptic_curve): fix `MSMOp` lowering * refac(elliptic_curve): fix `ExtractOp` lowering * refac(elliptic_curve): remove `PointSetOp` + `PointSetExtractOp` * refac(elliptic_curve): remove helper functions * refac(elliptic_curve): remove `const` + `&` for `Value` and `ValueRange` parameters
1 parent 516d919 commit 3d947e3

File tree

13 files changed

+473
-743
lines changed

13 files changed

+473
-743
lines changed

tests/Dialect/EllipticCurve/elliptic_curve_syntax.mlir

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -251,26 +251,6 @@ func.func @test_scalar_mul() {
251251
return
252252
}
253253

254-
// CHECK-LABEL: @test_point_set
255-
func.func @test_point_set() {
256-
// CHECK: %[[VAR1:.*]] = field.pf.constant 1 : ![[PF:.*]]
257-
%var1 = field.pf.constant 1 : !PF
258-
// CHECK: %[[VAR5:.*]] = field.pf.constant 5 : ![[PF]]
259-
%var5 = field.pf.constant 5 : !PF
260-
261-
// CHECK: %[[AFFINE1:.*]] = elliptic_curve.point %[[VAR1]], %[[VAR5]] : ![[PF]] -> ![[AF:.*]]
262-
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
263-
// CHECK: %[[POINTS:.*]] = elliptic_curve.point_set.from_elements %[[AFFINE1]], %[[AFFINE1]], %[[AFFINE1]] : [[TAF:.*]]
264-
%points = elliptic_curve.point_set.from_elements %affine1, %affine1, %affine1 : tensor<3x!affine>
265-
// CHECK: %[[IDX1:.*]] = arith.constant 1 : index
266-
%idx1 = arith.constant 1 : index
267-
// CHECK: %[[POINT1:.*]] = elliptic_curve.point_set.extract %[[POINTS]], %[[IDX1]] : [[TAF]] -> ![[AF]]
268-
%point1 = elliptic_curve.point_set.extract %points, %idx1 : tensor<3x!affine> -> !affine
269-
// CHECK: %[[AFFINE2:.*]] = elliptic_curve.double %[[POINT1]] : ![[AF]] -> ![[JA:.*]]
270-
%doubled = elliptic_curve.double %point1 : !affine -> !jacobian
271-
return
272-
}
273-
274254
// CHECK-LABEL: @test_msm
275255
func.func @test_msm() {
276256
// CHECK: %[[VAR1:.*]] = field.pf.constant 1 : ![[PF:.*]]
@@ -282,8 +262,8 @@ func.func @test_msm() {
282262
%scalars = tensor.from_elements %var1, %var5, %var5 : tensor<3x!PF>
283263
// CHECK: %[[AFFINE1:.*]] = elliptic_curve.point %[[VAR1]], %[[VAR5]] : ![[PF]] -> ![[AF:.*]]
284264
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
285-
// CHECK: %[[POINTS:.*]] = elliptic_curve.point_set.from_elements %[[AFFINE1]], %[[AFFINE1]], %[[AFFINE1]] : [[TAF:.*]]
286-
%points = elliptic_curve.point_set.from_elements %affine1, %affine1, %affine1 : tensor<3x!affine>
265+
// CHECK: %[[POINTS:.*]] = tensor.from_elements %[[AFFINE1]], %[[AFFINE1]], %[[AFFINE1]] : [[TAF:.*]]
266+
%points = tensor.from_elements %affine1, %affine1, %affine1 : tensor<3x!affine>
287267
// CHECK: %[[MSM_RESULT:.*]] = elliptic_curve.msm %[[SCALARS]], %[[POINTS]] : [[TPF]], [[TAF]] -> ![[JA:.*]]
288268
%msm_result = elliptic_curve.msm %scalars, %points : tensor<3x!PF>, tensor<3x!affine> -> !jacobian
289269
return

tests/Dialect/EllipticCurve/elliptic_curve_to_field.mlir

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,8 @@ func.func @test_intialization_and_conversion() {
2626
%var8 = field.pf.constant 8 : !PF
2727

2828
// CHECK-NOT: elliptic_curve.point
29-
// CHECK: %[[AFFINE1:.*]] = tensor.from_elements %[[VAR1]], %[[VAR5]] : tensor<2x![[PF]]>
3029
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
31-
// CHECK-NOT: elliptic_curve.point
32-
// CHECK: %[[JACOBIAN1:.*]] = tensor.from_elements %[[VAR1]], %[[VAR5]], %[[VAR2]] : tensor<3x![[PF]]>
3330
%jacobian1 = elliptic_curve.point %var1, %var5, %var2 : !PF -> !jacobian
34-
// CHECK-NOT: elliptic_curve.point
35-
// CHECK: %[[XYZZ1:.*]] = tensor.from_elements %[[VAR1]], %[[VAR5]], %[[VAR4]], %[[VAR8]] : tensor<4x![[PF]]>
3631
%xyzz1 = elliptic_curve.point %var1, %var5, %var4, %var8 : !PF -> !xyzz
3732

3833
// CHECK-NOT: elliptic_curve.convert_point_type
@@ -177,29 +172,13 @@ func.func @test_scalar_mul() {
177172
return
178173
}
179174

180-
// CHECK-LABEL: @test_point_set
181-
func.func @test_point_set() {
182-
%var1 = field.pf.constant 1 : !PF
183-
%var2 = field.pf.constant 2 : !PF
184-
%var4 = field.pf.constant 4 : !PF
185-
%var5 = field.pf.constant 5 : !PF
186-
%var8 = field.pf.constant 8 : !PF
187-
188-
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
189-
%points = elliptic_curve.point_set.from_elements %affine1, %affine1, %affine1 : tensor<3x!affine>
190-
%idx1 = arith.constant 1 : index
191-
%point1 = elliptic_curve.point_set.extract %points, %idx1 : tensor<3x!affine> -> !affine
192-
%doubled = elliptic_curve.double %point1 : !affine -> !jacobian
193-
return
194-
}
195-
196175
func.func @test_msm() {
197176
%var1 = field.pf.constant 1 : !PF
198177
%var5 = field.pf.constant 5 : !PF
199178

200179
%scalars = tensor.from_elements %var1, %var5, %var5 : tensor<3x!PF>
201180
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
202-
%points = elliptic_curve.point_set.from_elements %affine1, %affine1, %affine1 : tensor<3x!affine>
181+
%points = tensor.from_elements %affine1, %affine1, %affine1 : tensor<3x!affine>
203182
%msm_result = elliptic_curve.msm %scalars, %points : tensor<3x!PF>, tensor<3x!affine> -> !jacobian
204183
return
205184
}

tests/Dialect/EllipticCurve/elliptic_curve_to_field_runner.mlir

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,63 +35,72 @@ func.func @test_ops_in_order() {
3535
%jacobian1 = elliptic_curve.point %var5, %var3, %var2 : !PF -> !jacobian
3636

3737
%jacobian2 = elliptic_curve.add %affine1, %jacobian1 : !affine, !jacobian -> !jacobian
38-
%extract_point1 = elliptic_curve.extract %jacobian2 : !jacobian -> tensor<3x!PF>
38+
%extract_point1x, %extract_point1y, %extract_point1z = elliptic_curve.extract %jacobian2 : !jacobian -> !PF, !PF, !PF
39+
%extract_point1 = tensor.from_elements %extract_point1x, %extract_point1y, %extract_point1z : tensor<3x!PF>
3940
%extract1 = field.pf.extract %extract_point1 : tensor<3x!PF> -> tensor<3xi32>
4041
%1 = bufferization.to_memref %extract1 : tensor<3xi32> to memref<3xi32>
4142
%U1 = memref.cast %1 : memref<3xi32> to memref<*xi32>
4243
func.call @printMemrefI32(%U1) : (memref<*xi32>) -> ()
4344

4445
%jacobian3 = elliptic_curve.sub %affine1, %jacobian2 : !affine, !jacobian -> !jacobian
45-
%extract_point2 = elliptic_curve.extract %jacobian3 : !jacobian -> tensor<3x!PF>
46+
%extract_point2x, %extract_point2y, %extract_point2z = elliptic_curve.extract %jacobian3 : !jacobian -> !PF, !PF, !PF
47+
%extract_point2 = tensor.from_elements %extract_point2x, %extract_point2y, %extract_point2z : tensor<3x!PF>
4648
%extract2 = field.pf.extract %extract_point2 : tensor<3x!PF> -> tensor<3xi32>
4749
%2 = bufferization.to_memref %extract2 : tensor<3xi32> to memref<3xi32>
4850
%U2 = memref.cast %2 : memref<3xi32> to memref<*xi32>
4951
func.call @printMemrefI32(%U2) : (memref<*xi32>) -> ()
5052

5153
%jacobian4 = elliptic_curve.negate %jacobian3 : !jacobian
52-
%extract_point3 = elliptic_curve.extract %jacobian4 : !jacobian -> tensor<3x!PF>
54+
%extract_point3x, %extract_point3y, %extract_point3z = elliptic_curve.extract %jacobian4 : !jacobian -> !PF, !PF, !PF
55+
%extract_point3 = tensor.from_elements %extract_point3x, %extract_point3y, %extract_point3z : tensor<3x!PF>
5356
%extract3 = field.pf.extract %extract_point3 : tensor<3x!PF> -> tensor<3xi32>
5457
%3 = bufferization.to_memref %extract3 : tensor<3xi32> to memref<3xi32>
5558
%U3 = memref.cast %3 : memref<3xi32> to memref<*xi32>
5659
func.call @printMemrefI32(%U3) : (memref<*xi32>) -> ()
5760

5861
%jacobian5 = elliptic_curve.double %jacobian4 : !jacobian -> !jacobian
59-
%extract_point4 = elliptic_curve.extract %jacobian5 : !jacobian -> tensor<3x!PF>
62+
%extract_point4x, %extract_point4y, %extract_point4z = elliptic_curve.extract %jacobian5 : !jacobian -> !PF, !PF, !PF
63+
%extract_point4 = tensor.from_elements %extract_point4x, %extract_point4y, %extract_point4z : tensor<3x!PF>
6064
%extract4 = field.pf.extract %extract_point4 : tensor<3x!PF> -> tensor<3xi32>
6165
%4 = bufferization.to_memref %extract4 : tensor<3xi32> to memref<3xi32>
6266
%U4 = memref.cast %4 : memref<3xi32> to memref<*xi32>
6367
func.call @printMemrefI32(%U4) : (memref<*xi32>) -> ()
6468

6569
%xyzz1 = elliptic_curve.convert_point_type %jacobian5 : !jacobian -> !xyzz
66-
%extract_point5 = elliptic_curve.extract %xyzz1 : !xyzz -> tensor<4x!PF>
70+
%extract_point5x, %extract_point5y, %extract_point5zz, %extract_point5zzz = elliptic_curve.extract %xyzz1 : !xyzz -> !PF, !PF, !PF, !PF
71+
%extract_point5 = tensor.from_elements %extract_point5x, %extract_point5y, %extract_point5zz, %extract_point5zzz : tensor<4x!PF>
6772
%extract5 = field.pf.extract %extract_point5 : tensor<4x!PF> -> tensor<4xi32>
6873
%5 = bufferization.to_memref %extract5 : tensor<4xi32> to memref<4xi32>
6974
%U5 = memref.cast %5 : memref<4xi32> to memref<*xi32>
7075
func.call @printMemrefI32(%U5) : (memref<*xi32>) -> ()
7176

7277
%affine2 = elliptic_curve.convert_point_type %xyzz1 : !xyzz -> !affine
73-
%extract_point6 = elliptic_curve.extract %affine2 : !affine -> tensor<2x!PF>
78+
%extract_point6x, %extract_point6y = elliptic_curve.extract %affine2 : !affine -> !PF, !PF
79+
%extract_point6 = tensor.from_elements %extract_point6x, %extract_point6y : tensor<2x!PF>
7480
%extract6 = field.pf.extract %extract_point6 : tensor<2x!PF> -> tensor<2xi32>
7581
%6 = bufferization.to_memref %extract6 : tensor<2xi32> to memref<2xi32>
7682
%U6 = memref.cast %6 : memref<2xi32> to memref<*xi32>
7783
func.call @printMemrefI32(%U6) : (memref<*xi32>) -> ()
7884

7985
%jacobian6 = elliptic_curve.scalar_mul %var7, %affine2 : !PF, !affine -> !jacobian
80-
%extract_point7 = elliptic_curve.extract %jacobian6 : !jacobian -> tensor<3x!PF>
86+
%extract_point7x, %extract_point7y, %extract_point7z = elliptic_curve.extract %jacobian6 : !jacobian -> !PF, !PF, !PF
87+
%extract_point7 = tensor.from_elements %extract_point7x, %extract_point7y, %extract_point7z : tensor<3x!PF>
8188
%extract7 = field.pf.extract %extract_point7 : tensor<3x!PF> -> tensor<3xi32>
8289
%7 = bufferization.to_memref %extract7 : tensor<3xi32> to memref<3xi32>
8390
%U7 = memref.cast %7 : memref<3xi32> to memref<*xi32>
8491
func.call @printMemrefI32(%U7) : (memref<*xi32>) -> ()
8592

8693
%affine3 = elliptic_curve.convert_point_type %jacobian6 : !jacobian -> !affine
87-
%extract_point8 = elliptic_curve.extract %affine3 : !affine -> tensor<2x!PF>
94+
%extract_point8x, %extract_point8y = elliptic_curve.extract %affine3 : !affine -> !PF, !PF
95+
%extract_point8 = tensor.from_elements %extract_point8x, %extract_point8y : tensor<2x!PF>
8896
%extract8 = field.pf.extract %extract_point8 : tensor<2x!PF> -> tensor<2xi32>
8997
%8 = bufferization.to_memref %extract8 : tensor<2xi32> to memref<2xi32>
9098
%U8 = memref.cast %8 : memref<2xi32> to memref<*xi32>
9199
func.call @printMemrefI32(%U8) : (memref<*xi32>) -> ()
92100

93101
%xyzz2 = elliptic_curve.add %affine3, %xyzz1 : !affine, !xyzz -> !xyzz
94-
%extract_point9 = elliptic_curve.extract %xyzz2 : !xyzz -> tensor<4x!PF>
102+
%extract_point9x, %extract_point9y, %extract_point9zz, %extract_point9zzz = elliptic_curve.extract %xyzz2 : !xyzz -> !PF, !PF, !PF, !PF
103+
%extract_point9 = tensor.from_elements %extract_point9x, %extract_point9y, %extract_point9zz, %extract_point9zzz : tensor<4x!PF>
95104
%extract9 = field.pf.extract %extract_point9 : tensor<4x!PF> -> tensor<4xi32>
96105
%9 = bufferization.to_memref %extract9 : tensor<4xi32> to memref<4xi32>
97106
%U9 = memref.cast %9 : memref<4xi32> to memref<*xi32>
@@ -134,19 +143,21 @@ func.func @test_msm() {
134143
%add2 = elliptic_curve.add %scalar_mul3, %scalar_mul4 : !jacobian, !jacobian -> !jacobian
135144
%msm_true = elliptic_curve.add %add1, %add2 : !jacobian, !jacobian -> !jacobian
136145

137-
%extract_point = elliptic_curve.extract %msm_true : !jacobian -> tensor<3x!PF>
146+
%extract_point_x, %extract_point_y, %extract_point_z = elliptic_curve.extract %msm_true : !jacobian -> !PF, !PF, !PF
147+
%extract_point = tensor.from_elements %extract_point_x, %extract_point_y, %extract_point_z : tensor<3x!PF>
138148
%extract = field.pf.extract %extract_point : tensor<3x!PF> -> tensor<3xi32>
139149
%mem = bufferization.to_memref %extract : tensor<3xi32> to memref<3xi32>
140150
%U = memref.cast %mem : memref<3xi32> to memref<*xi32>
141151
func.call @printMemrefI32(%U) : (memref<*xi32>) -> ()
142152

143153
// RUNNING MSM
144154
%scalars = tensor.from_elements %var5, %var7, %var3, %var2 : tensor<4x!PF>
145-
%points = elliptic_curve.point_set.from_elements %jacobian1, %jacobian2, %jacobian3, %jacobian4 : tensor<4x!jacobian>
155+
%points = tensor.from_elements %jacobian1, %jacobian2, %jacobian3, %jacobian4 : tensor<4x!jacobian>
146156
%msm_test = elliptic_curve.msm %scalars , %points: tensor<4x!PF>, tensor<4x!jacobian> -> !jacobian
147157

148-
%extract_point1 = elliptic_curve.extract %msm_test : !jacobian -> tensor<3x!PF>
149-
%extract1 = field.pf.extract %extract_point1 : tensor<3x!PF> -> tensor<3xi32>
158+
%extract_point1x, %extract_point1y, %extract_point1z = elliptic_curve.extract %msm_test : !jacobian -> !PF, !PF, !PF
159+
%extract_point1 = tensor.from_elements %extract_point1x, %extract_point1y, %extract_point1z : tensor<3x!PF>
160+
%extract1 = field.pf.extract %extract_point : tensor<3x!PF> -> tensor<3xi32>
150161
%mem1 = bufferization.to_memref %extract1 : tensor<3xi32> to memref<3xi32>
151162
%U1 = memref.cast %mem1 : memref<3xi32> to memref<*xi32>
152163
func.call @printMemrefI32(%U1) : (memref<*xi32>) -> ()

0 commit comments

Comments
 (0)