Skip to content

Commit e65a9ff

Browse files
authored
Merge pull request #12 from a41-official/feat/add-elliptic-curve-dialect
feat: add elliptic curve dialect
2 parents 9049ffe + f624f9f commit e65a9ff

21 files changed

+944
-11
lines changed

tests/Dialect/Arith/arith_to_mod_arith.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: zkir-opt --arith-to-mod-arith --split-input-file %s | FileCheck %s --enable-var-scope
22

33
// CHECK-LABEL: @test_lower_add
4-
// CHECK-SAME: (%[[LHS:.*]]: !Z2147483648_i33_, %[[RHS:.*]]: !Z2147483648_i33_) -> [[T:.*]] {
4+
// CHECK-SAME: (%[[LHS:.*]]: !z2147483648_i33_, %[[RHS:.*]]: !z2147483648_i33_) -> [[T:.*]] {
55
func.func @test_lower_add(%lhs : i32, %rhs : i32) -> i32 {
66
// CHECK: %[[ADD:.*]] = mod_arith.add %[[LHS]], %[[RHS]] : [[T]]
77
// CHECK: return %[[ADD:.*]] : [[T]]
@@ -10,7 +10,7 @@ func.func @test_lower_add(%lhs : i32, %rhs : i32) -> i32 {
1010
}
1111

1212
// CHECK-LABEL: @test_lower_add_vec
13-
// CHECK-SAME: (%[[LHS:.*]]: tensor<4x!Z2147483648_i33_>, %[[RHS:.*]]: tensor<4x!Z2147483648_i33_>) -> [[T:.*]] {
13+
// CHECK-SAME: (%[[LHS:.*]]: tensor<4x!z2147483648_i33_>, %[[RHS:.*]]: tensor<4x!z2147483648_i33_>) -> [[T:.*]] {
1414
func.func @test_lower_add_vec(%lhs : tensor<4xi32>, %rhs : tensor<4xi32>) -> tensor<4xi32> {
1515
// CHECK: %[[ADD:.*]] = mod_arith.add %[[LHS]], %[[RHS]] : [[T]]
1616
// CHECK: return %[[ADD:.*]] : [[T]]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
load("//bazel:lit.bzl", "glob_lit_tests")
2+
3+
package(
4+
licenses = ["notice"],
5+
)
6+
7+
glob_lit_tests(
8+
name = "all_tests",
9+
data = ["//tests:test_utilities"],
10+
driver = "//tests:run_lit.sh",
11+
test_file_exts = ["mlir"],
12+
)
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
// RUN: zkir-opt --split-input-file %s | FileCheck %s --enable-var-scope
2+
3+
!PF = !field.pf<35:i32>
4+
5+
#1 = #field.pf_elem<1:i32> : !PF
6+
#2 = #field.pf_elem<2:i32> : !PF
7+
#3 = #field.pf_elem<3:i32> : !PF
8+
#4 = #field.pf_elem<4:i32> : !PF
9+
10+
#curve = #elliptic_curve.sw<#1, #2, (#3, #4)>
11+
!affine = !elliptic_curve.affine<#curve>
12+
!jacobian = !elliptic_curve.jacobian<#curve>
13+
!xyzz = !elliptic_curve.xyzz<#curve>
14+
15+
// CHECK-LABEL: @test_intialization_and_conversion
16+
func.func @test_intialization_and_conversion() {
17+
// CHECK: %[[AFFINE1:.*]] = elliptic_curve.point [[TMP1:.*]] : ![[AF:.*]]
18+
%affine1 = elliptic_curve.point 1, 5 : !affine
19+
// CHECK: %[[JACOBIAN1:.*]] = elliptic_curve.point [[TMP2:.*]] : ![[JA:.*]]
20+
%jacobian1 = elliptic_curve.point 1, 5, 2 : !jacobian
21+
// CHECK: %[[XYZZ1:.*]] = elliptic_curve.point [[TMP3:.*]] : ![[XY:.*]]
22+
%xyzz1 = elliptic_curve.point 1, 5, 4, 3 : !xyzz
23+
24+
// CHECK: %[[JACOBIAN2:.*]] = elliptic_curve.convert_point_type %[[AFFINE1]] : ![[AF]] -> ![[JA]]
25+
%jacobian2 = elliptic_curve.convert_point_type %affine1 : !affine -> !jacobian
26+
// CHECK: %[[XYZZ2:.*]] = elliptic_curve.convert_point_type %[[AFFINE1]] : ![[AF]] -> ![[XY]]
27+
%xyzz2 = elliptic_curve.convert_point_type %affine1 : !affine -> !xyzz
28+
// CHECK: %[[AFFINE2:.*]] = elliptic_curve.convert_point_type %[[JACOBIAN1]] : ![[JA]] -> ![[AF]]
29+
%affine2 = elliptic_curve.convert_point_type %jacobian1 : !jacobian -> !affine
30+
// CHECK: %[[XYZZ3:.*]] = elliptic_curve.convert_point_type %[[JACOBIAN1]] : ![[JA]] -> ![[XY]]
31+
%xyzz3 = elliptic_curve.convert_point_type %jacobian1 : !jacobian -> !xyzz
32+
// CHECK: %[[AFFINE3:.*]] = elliptic_curve.convert_point_type %[[XYZZ1]] : ![[XY]] -> ![[AF]]
33+
%affine3 = elliptic_curve.convert_point_type %xyzz1 : !xyzz -> !affine
34+
// CHECK: %[[JACOBIAN3:.*]] = elliptic_curve.convert_point_type %[[XYZZ1]] : ![[XY]] -> ![[JA]]
35+
%jacobian3 = elliptic_curve.convert_point_type %xyzz1 : !xyzz -> !jacobian
36+
return
37+
}
38+
39+
// CHECK-LABEL: @test_add
40+
func.func @test_add() {
41+
// CHECK: %[[AFFINE1:.*]] = elliptic_curve.point [[TMP1:.*]] : ![[AF:.*]]
42+
%affine1 = elliptic_curve.point 1, 5 : !affine
43+
// CHECK: %[[AFFINE2:.*]] = elliptic_curve.point
44+
%affine2 = elliptic_curve.point 3, 6 : !affine
45+
46+
// CHECK: %[[JACOBIAN1:.*]] = elliptic_curve.point [[TMP2:.*]] : ![[JA:.*]]
47+
%jacobian1 = elliptic_curve.point 1, 5, 2 : !jacobian
48+
// CHECK: %[[JACOBIAN2:.*]] = elliptic_curve.point
49+
%jacobian2 = elliptic_curve.point 3, 6, 1 : !jacobian
50+
51+
// CHECK: %[[XYZZ1:.*]] = elliptic_curve.point [[TMP3:.*]] : ![[XY:.*]]
52+
%xyzz1 = elliptic_curve.point 1, 5, 4, 3 : !xyzz
53+
// CHECK: %[[XYZZ2:.*]] = elliptic_curve.point
54+
%xyzz2 = elliptic_curve.point 3, 6, 1, 2 : !xyzz
55+
56+
// affine, affine -> jacobian
57+
// CHECK: %[[AFFINE3:.*]] = elliptic_curve.add %[[AFFINE1]], %[[AFFINE2]] : ![[AF]], ![[AF]] -> ![[JA]]
58+
%affine3 = elliptic_curve.add %affine1, %affine2 : !affine, !affine -> !jacobian
59+
// affine, jacobian -> jacobian
60+
// CHECK: %[[JACOBIAN3:.*]] = elliptic_curve.add %[[AFFINE1]], %[[JACOBIAN1]] : ![[AF]], ![[JA]] -> ![[JA]]
61+
%jacobian3 = elliptic_curve.add %affine1, %jacobian1 : !affine, !jacobian -> !jacobian
62+
// CHECK: %[[JACOBIAN10:.*]] = elliptic_curve.add %[[JACOBIAN1]], %[[AFFINE1]] : ![[JA]], ![[AF]] -> ![[JA]]
63+
%jacobian10 = elliptic_curve.add %jacobian1, %affine1 : !jacobian, !affine -> !jacobian
64+
// affine, xyzz -> xyzz
65+
// CHECK: %[[XYZZ3:.*]] = elliptic_curve.add %[[AFFINE1]], %[[XYZZ1]] : ![[AF]], ![[XY]] -> ![[XY]]
66+
%xyzz3 = elliptic_curve.add %affine1, %xyzz1 : !affine, !xyzz -> !xyzz
67+
// jacobian, jacobian -> jacobian
68+
// CHECK: %[[JACOBIAN4:.*]] = elliptic_curve.add %[[JACOBIAN1]], %[[JACOBIAN2]] : ![[JA]], ![[JA]] -> ![[JA]]
69+
%jacobian4 = elliptic_curve.add %jacobian1, %jacobian2 : !jacobian, !jacobian -> !jacobian
70+
// xyzz, xyzz -> xyzz
71+
// CHECK: %[[XYZZ4:.*]] = elliptic_curve.add %[[XYZZ1]], %[[XYZZ2]] : ![[XY]], ![[XY]] -> ![[XY]]
72+
%xyzz4 = elliptic_curve.add %xyzz1, %xyzz2 : !xyzz, !xyzz -> !xyzz
73+
return
74+
}
75+
76+
// CHECK-LABEL: @test_sub
77+
func.func @test_sub() {
78+
// CHECK: %[[AFFINE1:.*]] = elliptic_curve.point [[TMP1:.*]] : ![[AF:.*]]
79+
%affine1 = elliptic_curve.point 1, 5 : !affine
80+
// CHECK: %[[AFFINE2:.*]] = elliptic_curve.point
81+
%affine2 = elliptic_curve.point 3, 6 : !affine
82+
83+
// CHECK: %[[JACOBIAN1:.*]] = elliptic_curve.point [[TMP2:.*]] : ![[JA:.*]]
84+
%jacobian1 = elliptic_curve.point 1, 5, 2 : !jacobian
85+
// CHECK: %[[JACOBIAN2:.*]] = elliptic_curve.point
86+
%jacobian2 = elliptic_curve.point 3, 6, 1 : !jacobian
87+
88+
// CHECK: %[[XYZZ1:.*]] = elliptic_curve.point [[TMP3:.*]] : ![[XY:.*]]
89+
%xyzz1 = elliptic_curve.point 1, 5, 4, 3 : !xyzz
90+
// CHECK: %[[XYZZ2:.*]] = elliptic_curve.point
91+
%xyzz2 = elliptic_curve.point 3, 6, 1, 2 : !xyzz
92+
93+
// affine, affine -> jacobian
94+
// CHECK: %[[AFFINE3:.*]] = elliptic_curve.sub %[[AFFINE1]], %[[AFFINE2]] : ![[AF]], ![[AF]] -> ![[JA]]
95+
%affine3 = elliptic_curve.sub %affine1, %affine2 : !affine, !affine -> !jacobian
96+
// affine, jacobian -> jacobian
97+
// CHECK: %[[JACOBIAN3:.*]] = elliptic_curve.sub %[[AFFINE1]], %[[JACOBIAN1]] : ![[AF]], ![[JA]] -> ![[JA]]
98+
%jacobian3 = elliptic_curve.sub %affine1, %jacobian1 : !affine, !jacobian -> !jacobian
99+
// CHECK: %[[JACOBIAN10:.*]] = elliptic_curve.sub %[[JACOBIAN1]], %[[AFFINE1]] : ![[JA]], ![[AF]] -> ![[JA]]
100+
%jacobian10 = elliptic_curve.sub %jacobian1, %affine1 : !jacobian, !affine -> !jacobian
101+
// affine, xyzz -> xyzz
102+
// CHECK: %[[XYZZ3:.*]] = elliptic_curve.sub %[[AFFINE1]], %[[XYZZ1]] : ![[AF]], ![[XY]] -> ![[XY]]
103+
%xyzz3 = elliptic_curve.sub %affine1, %xyzz1 : !affine, !xyzz -> !xyzz
104+
// jacobian, jacobian -> jacobian
105+
// CHECK: %[[JACOBIAN4:.*]] = elliptic_curve.sub %[[JACOBIAN1]], %[[JACOBIAN2]] : ![[JA]], ![[JA]] -> ![[JA]]
106+
%jacobian4 = elliptic_curve.sub %jacobian1, %jacobian2 : !jacobian, !jacobian -> !jacobian
107+
// xyzz, xyzz -> xyzz
108+
// CHECK: %[[XYZZ4:.*]] = elliptic_curve.sub %[[XYZZ1]], %[[XYZZ2]] : ![[XY]], ![[XY]] -> ![[XY]]
109+
%xyzz4 = elliptic_curve.sub %xyzz1, %xyzz2 : !xyzz, !xyzz -> !xyzz
110+
return
111+
}
112+
113+
// CHECK-LABEL: @test_negation
114+
func.func @test_negation() {
115+
// CHECK: %[[AFFINE1:.*]] = elliptic_curve.point [[TMP1:.*]] : ![[AF:.*]]
116+
%affine1 = elliptic_curve.point 1, 5 : !affine
117+
// CHECK: %[[AFFINE2:.*]] = elliptic_curve.neg %[[AFFINE1]] : ![[AF]]
118+
%affine2 = elliptic_curve.neg %affine1 : !affine
119+
120+
// CHECK: %[[JACOBIAN1:.*]] = elliptic_curve.point [[TMP2:.*]] : ![[JA:.*]]
121+
%jacobian1 = elliptic_curve.point 1, 5, 2 : !jacobian
122+
// CHECK: %[[JACOBIAN2:.*]] = elliptic_curve.neg %[[JACOBIAN1]] : ![[JA]]
123+
%jacobian2 = elliptic_curve.neg %jacobian1 : !jacobian
124+
125+
// CHECK: %[[XYZZ1:.*]] = elliptic_curve.point [[TMP3:.*]] : ![[XY:.*]]
126+
%xyzz1 = elliptic_curve.point 1, 5, 4, 3 : !xyzz
127+
// CHECK: %[[XYZZ2:.*]] = elliptic_curve.neg %[[XYZZ1]] : ![[XY]]
128+
%xyzz2 = elliptic_curve.neg %xyzz1 : !xyzz
129+
return
130+
}
131+
132+
// CHECK-LABEL: @test_double
133+
func.func @test_double() {
134+
// CHECK: %[[AFFINE1:.*]] = elliptic_curve.point [[TMP1:.*]] : ![[AF:.*]]
135+
%affine1 = elliptic_curve.point 1, 5 : !affine
136+
// CHECK: %[[AFFINE2:.*]] = elliptic_curve.dbl %[[AFFINE1]] : ![[AF]] -> ![[JA:.*]]
137+
%affine2 = elliptic_curve.dbl %affine1 : !affine -> !jacobian
138+
139+
// CHECK: %[[JACOBIAN1:.*]] = elliptic_curve.point [[TMP2:.*]] : ![[JA]]
140+
%jacobian1 = elliptic_curve.point 1, 5, 2 : !jacobian
141+
// CHECK: %[[JACOBIAN2:.*]] = elliptic_curve.dbl %[[JACOBIAN1]] : ![[JA]] -> ![[JA]]
142+
%jacobian2 = elliptic_curve.dbl %jacobian1 : !jacobian -> !jacobian
143+
144+
// CHECK: %[[XYZZ1:.*]] = elliptic_curve.point [[TMP3:.*]] : ![[XY:.*]]
145+
%xyzz1 = elliptic_curve.point 1, 5, 4, 3 : !xyzz
146+
// CHECK: %[[XYZZ2:.*]] = elliptic_curve.dbl %[[XYZZ1]] : ![[XY]] -> ![[XY]]
147+
%xyzz2 = elliptic_curve.dbl %xyzz1 : !xyzz -> !xyzz
148+
return
149+
}
150+
151+
// CHECK-LABEL: @test_scalar_mul
152+
func.func @test_scalar_mul() {
153+
// CHECK: %[[SCALAR1:.*]] = field.pf.constant 1 : ![[PF:.*]]
154+
%scalar1 = field.pf.constant 1 : !PF
155+
// CHECK: %[[SCALAR2:.*]] = field.pf.constant 7 : ![[PF]]
156+
%scalar2 = field.pf.constant 7 : !PF
157+
158+
// CHECK: %[[AFFINE1:.*]] = elliptic_curve.point [[TMP1:.*]] : ![[AF:.*]]
159+
%affine1 = elliptic_curve.point 1, 5 : !affine
160+
// CHECK: %[[AFFINE2:.*]] = elliptic_curve.scalar_mul %[[AFFINE1]], %[[SCALAR1]] : ![[AF]], ![[PF]] -> ![[JA:.*]]
161+
%affine2 = elliptic_curve.scalar_mul %affine1, %scalar1 : !affine, !PF -> !jacobian
162+
// CHECK: %[[JACOBIAN4:.*]] = elliptic_curve.scalar_mul %[[AFFINE1]], %[[SCALAR2]] : ![[AF]], ![[PF]] -> ![[JA]]
163+
%jacobian4 = elliptic_curve.scalar_mul %affine1, %scalar2 : !affine, !PF -> !jacobian
164+
165+
// CHECK: %[[JACOBIAN1:.*]] = elliptic_curve.point [[TMP2:.*]] : ![[JA]]
166+
%jacobian1 = elliptic_curve.point 1, 5, 2 : !jacobian
167+
// CHECK: %[[JACOBIAN2:.*]] = elliptic_curve.scalar_mul %[[JACOBIAN1]], %[[SCALAR1]] : ![[JA]], ![[PF]] -> ![[JA]]
168+
%jacobian2 = elliptic_curve.scalar_mul %jacobian1, %scalar1 : !jacobian, !PF -> !jacobian
169+
// CHECK: %[[JACOBIAN3:.*]] = elliptic_curve.scalar_mul %[[JACOBIAN1]], %[[SCALAR2]] : ![[JA]], ![[PF]] -> ![[JA]]
170+
%jacobian3 = elliptic_curve.scalar_mul %jacobian1, %scalar2 : !jacobian, !PF -> !jacobian
171+
172+
// CHECK: %[[XYZZ1:.*]] = elliptic_curve.point [[TMP3:.*]] : ![[XY:.*]]
173+
%xyzz1 = elliptic_curve.point 1, 5, 4, 3 : !xyzz
174+
// CHECK: %[[XYZZ2:.*]] = elliptic_curve.scalar_mul %[[XYZZ1]], %[[SCALAR1]] : ![[XY]], ![[PF]] -> ![[XY]]
175+
%xyzz2 = elliptic_curve.scalar_mul %xyzz1, %scalar1 : !xyzz, !PF -> !xyzz
176+
// CHECK: %[[XYZZ3:.*]] = elliptic_curve.scalar_mul %[[XYZZ1]], %[[SCALAR2]] : ![[XY]], ![[PF]] -> ![[XY]]
177+
%xyzz3 = elliptic_curve.scalar_mul %xyzz1, %scalar2 : !xyzz, !PF -> !xyzz
178+
return
179+
}

tests/Dialect/Field/prime_field_to_mod_arith.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func.func @test_lower_add() -> !PF1 {
8585
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] {
8686
func.func @test_lower_add_vec(%lhs : !PFv, %rhs : !PFv) -> !PFv {
8787
// CHECK-NOT: field.pf.add
88-
// CHECK: %[[RES:.*]] = mod_arith.add %[[LHS]], %[[RHS]] : tensor<4x!Z3_i32_>
88+
// CHECK: %[[RES:.*]] = mod_arith.add %[[LHS]], %[[RHS]] : tensor<4x!z3_i32_>
8989
%res = field.pf.add %lhs, %rhs : !PFv
9090
// CHECK: return %[[RES]] : [[T]]
9191
return %res : !PFv

tools/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ cc_binary(
88
name = "zkir-lsp",
99
srcs = ["zkir-lsp.cpp"],
1010
deps = [
11+
"//zkir/Dialect/EllipticCurve/IR:EllipticCurve",
1112
"//zkir/Dialect/Field/IR:Field",
1213
"//zkir/Dialect/ModArith/IR:ModArith",
1314
"//zkir/Dialect/Poly/IR:Poly",
@@ -22,6 +23,7 @@ cc_binary(
2223
srcs = ["zkir-opt.cpp"],
2324
deps = [
2425
"//zkir/Dialect/Arith/Conversions/ArithToModArith",
26+
"//zkir/Dialect/EllipticCurve/IR:EllipticCurve",
2527
"//zkir/Dialect/Field/Conversions/FieldToModArith",
2628
"//zkir/Dialect/Field/IR:Field",
2729
"//zkir/Dialect/ModArith/Conversions/ModArithToArith",

tools/zkir-lsp.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "mlir/InitAllExtensions.h"
33
#include "mlir/InitAllPasses.h"
44
#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"
5+
#include "zkir/Dialect/EllipticCurve/IR/EllipticCurveDialect.h"
56
#include "zkir/Dialect/Field/IR/FieldDialect.h"
67
#include "zkir/Dialect/ModArith/IR/ModArithDialect.h"
78
#include "zkir/Dialect/Poly/IR/PolyDialect.h"
@@ -11,6 +12,7 @@ int main(int argc, char **argv) {
1112
registry.insert<mlir::zkir::mod_arith::ModArithDialect>();
1213
registry.insert<mlir::zkir::field::FieldDialect>();
1314
registry.insert<mlir::zkir::poly::PolyDialect>();
15+
registry.insert<mlir::zkir::elliptic_curve::EllipticCurveDialect>();
1416
mlir::registerAllDialects(registry);
1517
mlir::registerAllExtensions(registry);
1618
return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry));

tools/zkir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "mlir/InitAllPasses.h"
55
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
66
#include "zkir/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h"
7+
#include "zkir/Dialect/EllipticCurve/IR/EllipticCurveDialect.h"
78
#include "zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.h"
89
#include "zkir/Dialect/Field/IR/FieldDialect.h"
910
#include "zkir/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.h"
@@ -17,6 +18,7 @@ int main(int argc, char **argv) {
1718
registry.insert<mlir::zkir::mod_arith::ModArithDialect>();
1819
registry.insert<mlir::zkir::field::FieldDialect>();
1920
registry.insert<mlir::zkir::poly::PolyDialect>();
21+
registry.insert<mlir::zkir::elliptic_curve::EllipticCurveDialect>();
2022
mlir::registerAllDialects(registry);
2123
mlir::registerAllExtensions(registry);
2224

0 commit comments

Comments
 (0)