Skip to content

Commit dc44ad9

Browse files
authored
Merge pull request #17 from a41-official/feat/elliptic-curve-lowering
feat: elliptic curve lowering
2 parents 7cf683a + af2f377 commit dc44ad9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+2418
-308
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,11 @@ jobs:
2525

2626
- name: Setup Bazelisk
2727
uses: bazel-contrib/[email protected]
28-
29-
- name: Mount Bazel Cache
30-
uses: actions/cache@v4
3128
with:
32-
path: "~/.cache/bazel"
33-
key: ${{ runner.os }}-zkir_bazelbuild-${{ hashFiles('.bazelversion', '.bazelrc', 'WORKSPACE', 'WORKSPACE.bazel', 'MODULE.bazel') }}
34-
restore-keys: |
35-
${{ runner.os }}-zkir_bazelbuild-
29+
bazelisk-cache: true
30+
disk-cache: ${{ runner.os }}-zkir_bazelbuild
31+
repository-cache: false
32+
external-cache: false
3633

3734
- name: Run `bazel build`
3835
run: |

tests/Dialect/EllipticCurve/elliptic_curve_syntax.mlir

Lines changed: 172 additions & 99 deletions
Large diffs are not rendered by default.
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
// RUN: zkir-opt -elliptic-curve-to-field --split-input-file %s | FileCheck %s --enable-var-scope
2+
3+
!PF = !field.pf<35:i32>
4+
5+
#1 = #field.pf_elem<1:i32> : !PF
6+
#2 = #field.pf_elem<2:i32> : !PF
7+
#3 = #field.pf_elem<3:i32> : !PF
8+
#4 = #field.pf_elem<4:i32> : !PF
9+
10+
#curve = #elliptic_curve.sw<#1, #2, (#3, #4)>
11+
!affine = !elliptic_curve.affine<#curve>
12+
!jacobian = !elliptic_curve.jacobian<#curve>
13+
!xyzz = !elliptic_curve.xyzz<#curve>
14+
15+
// CHECK-LABEL: @test_intialization_and_conversion
16+
func.func @test_intialization_and_conversion() {
17+
// CHECK: %[[VAR1:.*]] = field.pf.constant 1 : ![[PF:.*]]
18+
%var1 = field.pf.constant 1 : !PF
19+
// CHECK: %[[VAR2:.*]] = field.pf.constant 2 : ![[PF]]
20+
%var2 = field.pf.constant 2 : !PF
21+
// CHECK: %[[VAR4:.*]] = field.pf.constant 4 : ![[PF]]
22+
%var4 = field.pf.constant 4 : !PF
23+
// CHECK: %[[VAR5:.*]] = field.pf.constant 5 : ![[PF]]
24+
%var5 = field.pf.constant 5 : !PF
25+
// CHECK: %[[VAR8:.*]] = field.pf.constant 8 : ![[PF]]
26+
%var8 = field.pf.constant 8 : !PF
27+
28+
// CHECK-NOT: elliptic_curve.point
29+
// CHECK: %[[AFFINE1:.*]] = tensor.from_elements %[[VAR1]], %[[VAR5]] : tensor<2x![[PF]]>
30+
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
31+
// CHECK-NOT: elliptic_curve.point
32+
// CHECK: %[[JACOBIAN1:.*]] = tensor.from_elements %[[VAR1]], %[[VAR5]], %[[VAR2]] : tensor<3x![[PF]]>
33+
%jacobian1 = elliptic_curve.point %var1, %var5, %var2 : !PF -> !jacobian
34+
// CHECK-NOT: elliptic_curve.point
35+
// CHECK: %[[XYZZ1:.*]] = tensor.from_elements %[[VAR1]], %[[VAR5]], %[[VAR4]], %[[VAR8]] : tensor<4x![[PF]]>
36+
%xyzz1 = elliptic_curve.point %var1, %var5, %var4, %var8 : !PF -> !xyzz
37+
38+
// CHECK-NOT: elliptic_curve.convert_point_type
39+
%jacobian2 = elliptic_curve.convert_point_type %affine1 : !affine -> !jacobian
40+
%xyzz2 = elliptic_curve.convert_point_type %affine1 : !affine -> !xyzz
41+
%affine2 = elliptic_curve.convert_point_type %jacobian1 : !jacobian -> !affine
42+
%xyzz3 = elliptic_curve.convert_point_type %jacobian1 : !jacobian -> !xyzz
43+
%affine3 = elliptic_curve.convert_point_type %xyzz1 : !xyzz -> !affine
44+
%jacobian3 = elliptic_curve.convert_point_type %xyzz1 : !xyzz -> !jacobian
45+
return
46+
}
47+
48+
// CHECK-LABEL: @test_addition
49+
func.func @test_addition() {
50+
%var1 = field.pf.constant 1 : !PF
51+
%var2 = field.pf.constant 2 : !PF
52+
%var3 = field.pf.constant 3 : !PF
53+
%var4 = field.pf.constant 4 : !PF
54+
%var5 = field.pf.constant 5 : !PF
55+
%var6 = field.pf.constant 6 : !PF
56+
%var8 = field.pf.constant 8 : !PF
57+
58+
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
59+
%affine2 = elliptic_curve.point %var3, %var6 : !PF -> !affine
60+
61+
%jacobian1 = elliptic_curve.point %var1, %var5, %var2 : !PF -> !jacobian
62+
%jacobian2 = elliptic_curve.point %var3, %var6, %var1 : !PF -> !jacobian
63+
64+
%xyzz1 = elliptic_curve.point %var1, %var5, %var4, %var8 : !PF -> !xyzz
65+
%xyzz2 = elliptic_curve.point %var3, %var6, %var1, %var1 : !PF -> !xyzz
66+
67+
// CHECK-NOT: elliptic_curve.add
68+
// affine, affine -> jacobian
69+
%affine3 = elliptic_curve.add %affine1, %affine2 : !affine, !affine -> !jacobian
70+
// affine, jacobian -> jacobian
71+
%jacobian3 = elliptic_curve.add %affine1, %jacobian1 : !affine, !jacobian -> !jacobian
72+
%jacobian4 = elliptic_curve.add %jacobian1, %affine1 : !jacobian, !affine -> !jacobian
73+
// affine, xyzz -> xyzz
74+
%xyzz3 = elliptic_curve.add %affine1, %xyzz1 : !affine, !xyzz -> !xyzz
75+
%xyzz4 = elliptic_curve.add %xyzz1, %affine1 : !xyzz, !affine -> !xyzz
76+
// jacobian, jacobian -> jacobian
77+
%jacobian5 = elliptic_curve.add %jacobian1, %jacobian2 : !jacobian, !jacobian -> !jacobian
78+
// xyzz, xyzz -> xyzz
79+
%xyzz5 = elliptic_curve.add %xyzz1, %xyzz2 : !xyzz, !xyzz -> !xyzz
80+
return
81+
}
82+
83+
// CHECK-LABEL: @test_double
84+
func.func @test_double() {
85+
%var1 = field.pf.constant 1 : !PF
86+
%var2 = field.pf.constant 2 : !PF
87+
%var4 = field.pf.constant 4 : !PF
88+
%var5 = field.pf.constant 5 : !PF
89+
%var8 = field.pf.constant 8 : !PF
90+
91+
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
92+
%jacobian1 = elliptic_curve.point %var1, %var5, %var2 : !PF -> !jacobian
93+
%xyzz1 = elliptic_curve.point %var1, %var5, %var4, %var8 : !PF -> !xyzz
94+
95+
// CHECK-NOT: elliptic_curve.double
96+
%affine2 = elliptic_curve.double %affine1 : !affine -> !jacobian
97+
%jacobian2 = elliptic_curve.double %jacobian1 : !jacobian -> !jacobian
98+
%xyzz2 = elliptic_curve.double %xyzz1 : !xyzz -> !xyzz
99+
return
100+
}
101+
102+
// CHECK-LABEL: @test_negation
103+
func.func @test_negation() {
104+
%var1 = field.pf.constant 1 : !PF
105+
%var2 = field.pf.constant 2 : !PF
106+
%var4 = field.pf.constant 4 : !PF
107+
%var5 = field.pf.constant 5 : !PF
108+
%var8 = field.pf.constant 8 : !PF
109+
110+
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
111+
%jacobian1 = elliptic_curve.point %var1, %var5, %var2 : !PF -> !jacobian
112+
%xyzz1 = elliptic_curve.point %var1, %var5, %var4, %var8 : !PF -> !xyzz
113+
114+
// CHECK-NOT: elliptic_curve.negate
115+
%affine2 = elliptic_curve.negate %affine1 : !affine
116+
%jacobian2 = elliptic_curve.negate %jacobian1 : !jacobian
117+
%xyzz2 = elliptic_curve.negate %xyzz1 : !xyzz
118+
return
119+
}
120+
121+
// CHECK-LABEL: @test_subtraction
122+
func.func @test_subtraction() {
123+
%var1 = field.pf.constant 1 : !PF
124+
%var2 = field.pf.constant 2 : !PF
125+
%var3 = field.pf.constant 3 : !PF
126+
%var4 = field.pf.constant 4 : !PF
127+
%var5 = field.pf.constant 5 : !PF
128+
%var6 = field.pf.constant 6 : !PF
129+
%var8 = field.pf.constant 8 : !PF
130+
131+
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
132+
%affine2 = elliptic_curve.point %var3, %var6 : !PF -> !affine
133+
134+
%jacobian1 = elliptic_curve.point %var1, %var5, %var2 : !PF -> !jacobian
135+
%jacobian2 = elliptic_curve.point %var3, %var6, %var1 : !PF -> !jacobian
136+
137+
%xyzz1 = elliptic_curve.point %var1, %var5, %var4, %var8 : !PF -> !xyzz
138+
%xyzz2 = elliptic_curve.point %var3, %var6, %var1, %var1 : !PF -> !xyzz
139+
140+
// CHECK-NOT: elliptic_curve.sub
141+
// affine, affine -> jacobian
142+
%affine3 = elliptic_curve.sub %affine1, %affine2 : !affine, !affine -> !jacobian
143+
// affine, jacobian -> jacobian
144+
%jacobian3 = elliptic_curve.sub %affine1, %jacobian1 : !affine, !jacobian -> !jacobian
145+
%jacobian4 = elliptic_curve.sub %jacobian1, %affine1 : !jacobian, !affine -> !jacobian
146+
// affine, xyzz -> xyzz
147+
%xyzz3 = elliptic_curve.sub %affine1, %xyzz1 : !affine, !xyzz -> !xyzz
148+
%xyzz4 = elliptic_curve.sub %xyzz1, %affine1 : !xyzz, !affine -> !xyzz
149+
// jacobian, jacobian -> jacobian
150+
%jacobian5 = elliptic_curve.sub %jacobian1, %jacobian2 : !jacobian, !jacobian -> !jacobian
151+
// xyzz, xyzz -> xyzz
152+
%xyzz5 = elliptic_curve.sub %xyzz1, %xyzz2 : !xyzz, !xyzz -> !xyzz
153+
return
154+
}
155+
156+
// CHECK-LABEL: @test_scalar_mul
157+
func.func @test_scalar_mul() {
158+
%var1 = field.pf.constant 1 : !PF
159+
%var2 = field.pf.constant 2 : !PF
160+
%var4 = field.pf.constant 4 : !PF
161+
%var5 = field.pf.constant 5 : !PF
162+
%var8 = field.pf.constant 8 : !PF
163+
164+
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
165+
%jacobian1 = elliptic_curve.point %var1, %var5, %var2 : !PF -> !jacobian
166+
%xyzz1 = elliptic_curve.point %var1, %var5, %var4, %var8 : !PF -> !xyzz
167+
168+
// CHECK-NOT: elliptic_curve.scalar_mul
169+
%jacobian2 = elliptic_curve.scalar_mul %var1, %affine1 : !PF, !affine -> !jacobian
170+
%jacobian3 = elliptic_curve.scalar_mul %var8, %affine1 : !PF, !affine -> !jacobian
171+
172+
%jacobian4 = elliptic_curve.scalar_mul %var1, %jacobian1 : !PF, !jacobian -> !jacobian
173+
%jacobian5 = elliptic_curve.scalar_mul %var8, %jacobian1 : !PF, !jacobian -> !jacobian
174+
175+
%xyzz2 = elliptic_curve.scalar_mul %var1, %xyzz1 : !PF, !xyzz -> !xyzz
176+
%xyzz3 = elliptic_curve.scalar_mul %var8, %xyzz1 : !PF, !xyzz -> !xyzz
177+
return
178+
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// RUN: zkir-opt %s -elliptic-curve-to-llvm \
2+
// RUN: | mlir-runner -e test_ops_in_order -entry-point-result=void \
3+
// RUN: --shared-libs="%mlir_lib_dir/libmlir_runner_utils%shlibext" > %t
4+
// RUN: FileCheck %s --check-prefix=CHECK_TEST_OPS_IN_ORDER < %t
5+
6+
!PF = !field.pf<11:i32>
7+
8+
#1 = #field.pf_elem<1:i32> : !PF
9+
#2 = #field.pf_elem<2:i32> : !PF
10+
#3 = #field.pf_elem<3:i32> : !PF
11+
#4 = #field.pf_elem<4:i32> : !PF
12+
13+
#curve = #elliptic_curve.sw<#1, #2, (#3, #4)>
14+
!affine = !elliptic_curve.affine<#curve>
15+
!jacobian = !elliptic_curve.jacobian<#curve>
16+
!xyzz = !elliptic_curve.xyzz<#curve>
17+
18+
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }
19+
20+
// CHECK-LABEL: @test_ops_in_order
21+
func.func @test_ops_in_order() {
22+
%var1 = field.pf.constant 1 : !PF
23+
%var2 = field.pf.constant 2 : !PF
24+
%var3 = field.pf.constant 3 : !PF
25+
%var5 = field.pf.constant 5 : !PF
26+
%var7 = field.pf.constant 7 : !PF
27+
28+
%affine1 = elliptic_curve.point %var1, %var2 : !PF -> !affine
29+
%jacobian1 = elliptic_curve.point %var5, %var3, %var2 : !PF -> !jacobian
30+
31+
%jacobian2 = elliptic_curve.add %affine1, %jacobian1 : !affine, !jacobian -> !jacobian
32+
%extract_point1 = elliptic_curve.extract %jacobian2 : !jacobian -> tensor<3x!PF>
33+
%extract1 = field.pf.extract %extract_point1 : tensor<3x!PF> -> tensor<3xi32>
34+
%1 = bufferization.to_memref %extract1 : tensor<3xi32> to memref<3xi32>
35+
%U1 = memref.cast %1 : memref<3xi32> to memref<*xi32>
36+
func.call @printMemrefI32(%U1) : (memref<*xi32>) -> ()
37+
38+
%jacobian3 = elliptic_curve.sub %affine1, %jacobian2 : !affine, !jacobian -> !jacobian
39+
%extract_point2 = elliptic_curve.extract %jacobian3 : !jacobian -> tensor<3x!PF>
40+
%extract2 = field.pf.extract %extract_point2 : tensor<3x!PF> -> tensor<3xi32>
41+
%2 = bufferization.to_memref %extract2 : tensor<3xi32> to memref<3xi32>
42+
%U2 = memref.cast %2 : memref<3xi32> to memref<*xi32>
43+
func.call @printMemrefI32(%U2) : (memref<*xi32>) -> ()
44+
45+
%jacobian4 = elliptic_curve.negate %jacobian3 : !jacobian
46+
%extract_point3 = elliptic_curve.extract %jacobian4 : !jacobian -> tensor<3x!PF>
47+
%extract3 = field.pf.extract %extract_point3 : tensor<3x!PF> -> tensor<3xi32>
48+
%3 = bufferization.to_memref %extract3 : tensor<3xi32> to memref<3xi32>
49+
%U3 = memref.cast %3 : memref<3xi32> to memref<*xi32>
50+
func.call @printMemrefI32(%U3) : (memref<*xi32>) -> ()
51+
52+
%jacobian5 = elliptic_curve.double %jacobian4 : !jacobian -> !jacobian
53+
%extract_point4 = elliptic_curve.extract %jacobian5 : !jacobian -> tensor<3x!PF>
54+
%extract4 = field.pf.extract %extract_point4 : tensor<3x!PF> -> tensor<3xi32>
55+
%4 = bufferization.to_memref %extract4 : tensor<3xi32> to memref<3xi32>
56+
%U4 = memref.cast %4 : memref<3xi32> to memref<*xi32>
57+
func.call @printMemrefI32(%U4) : (memref<*xi32>) -> ()
58+
59+
%xyzz1 = elliptic_curve.convert_point_type %jacobian5 : !jacobian -> !xyzz
60+
%extract_point5 = elliptic_curve.extract %xyzz1 : !xyzz -> tensor<4x!PF>
61+
%extract5 = field.pf.extract %extract_point5 : tensor<4x!PF> -> tensor<4xi32>
62+
%5 = bufferization.to_memref %extract5 : tensor<4xi32> to memref<4xi32>
63+
%U5 = memref.cast %5 : memref<4xi32> to memref<*xi32>
64+
func.call @printMemrefI32(%U5) : (memref<*xi32>) -> ()
65+
66+
%affine2 = elliptic_curve.convert_point_type %xyzz1 : !xyzz -> !affine
67+
%extract_point6 = elliptic_curve.extract %affine2 : !affine -> tensor<2x!PF>
68+
%extract6 = field.pf.extract %extract_point6 : tensor<2x!PF> -> tensor<2xi32>
69+
%6 = bufferization.to_memref %extract6 : tensor<2xi32> to memref<2xi32>
70+
%U6 = memref.cast %6 : memref<2xi32> to memref<*xi32>
71+
func.call @printMemrefI32(%U6) : (memref<*xi32>) -> ()
72+
73+
%jacobian6 = elliptic_curve.scalar_mul %var7, %affine2 : !PF, !affine -> !jacobian
74+
%extract_point7 = elliptic_curve.extract %jacobian6 : !jacobian -> tensor<3x!PF>
75+
%extract7 = field.pf.extract %extract_point7 : tensor<3x!PF> -> tensor<3xi32>
76+
%7 = bufferization.to_memref %extract7 : tensor<3xi32> to memref<3xi32>
77+
%U7 = memref.cast %7 : memref<3xi32> to memref<*xi32>
78+
func.call @printMemrefI32(%U7) : (memref<*xi32>) -> ()
79+
80+
%affine3 = elliptic_curve.convert_point_type %jacobian6 : !jacobian -> !affine
81+
%extract_point8 = elliptic_curve.extract %affine3 : !affine -> tensor<2x!PF>
82+
%extract8 = field.pf.extract %extract_point8 : tensor<2x!PF> -> tensor<2xi32>
83+
%8 = bufferization.to_memref %extract8 : tensor<2xi32> to memref<2xi32>
84+
%U8 = memref.cast %8 : memref<2xi32> to memref<*xi32>
85+
func.call @printMemrefI32(%U8) : (memref<*xi32>) -> ()
86+
87+
%xyzz2 = elliptic_curve.add %affine3, %xyzz1 : !affine, !xyzz -> !xyzz
88+
%extract_point9 = elliptic_curve.extract %xyzz2 : !xyzz -> tensor<4x!PF>
89+
%extract9 = field.pf.extract %extract_point9 : tensor<4x!PF> -> tensor<4xi32>
90+
%9 = bufferization.to_memref %extract9 : tensor<4xi32> to memref<4xi32>
91+
%U9 = memref.cast %9 : memref<4xi32> to memref<*xi32>
92+
func.call @printMemrefI32(%U9) : (memref<*xi32>) -> ()
93+
return
94+
}
95+
96+
// CHECK_TEST_OPS_IN_ORDER: [2, 8, 7]
97+
// CHECK_TEST_OPS_IN_ORDER: [5, 3, 9]
98+
// CHECK_TEST_OPS_IN_ORDER: [5, 8, 9]
99+
// CHECK_TEST_OPS_IN_ORDER: [1, 10, 1]
100+
// CHECK_TEST_OPS_IN_ORDER: [1, 10, 1, 1]
101+
// CHECK_TEST_OPS_IN_ORDER: [1, 10]
102+
// CHECK_TEST_OPS_IN_ORDER: [0, 0, 0]
103+
// CHECK_TEST_OPS_IN_ORDER: [1, 1]
104+
// CHECK_TEST_OPS_IN_ORDER: [4, 3, 0, 0]

0 commit comments

Comments
 (0)