Skip to content

Commit 7c358dd

Browse files
authored
Merge pull request #20 from zk-rabbit/feat/naive-msm
feat: impl naive msm
2 parents 40f7aa3 + 9b05917 commit 7c358dd

File tree

15 files changed

+551
-165
lines changed

15 files changed

+551
-165
lines changed

tests/Dialect/EllipticCurve/elliptic_curve_syntax.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,41 @@ func.func @test_scalar_mul() {
250250
%xyzz3 = elliptic_curve.scalar_mul %var8, %xyzz1 : !PF, !xyzz -> !xyzz
251251
return
252252
}
253+
254+
// CHECK-LABEL: @test_point_set
255+
func.func @test_point_set() {
256+
// CHECK: %[[VAR1:.*]] = field.pf.constant 1 : ![[PF:.*]]
257+
%var1 = field.pf.constant 1 : !PF
258+
// CHECK: %[[VAR5:.*]] = field.pf.constant 5 : ![[PF]]
259+
%var5 = field.pf.constant 5 : !PF
260+
261+
// CHECK: %[[AFFINE1:.*]] = elliptic_curve.point %[[VAR1]], %[[VAR5]] : ![[PF]] -> ![[AF:.*]]
262+
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
263+
// CHECK: %[[POINTS:.*]] = elliptic_curve.point_set.from_elements %[[AFFINE1]], %[[AFFINE1]], %[[AFFINE1]] : [[TAF:.*]]
264+
%points = elliptic_curve.point_set.from_elements %affine1, %affine1, %affine1 : tensor<3x!affine>
265+
// CHECK: %[[IDX1:.*]] = arith.constant 1 : index
266+
%idx1 = arith.constant 1 : index
267+
// CHECK: %[[POINT1:.*]] = elliptic_curve.point_set.extract %[[POINTS]], %[[IDX1]] : [[TAF]] -> ![[AF]]
268+
%point1 = elliptic_curve.point_set.extract %points, %idx1 : tensor<3x!affine> -> !affine
269+
// CHECK: %[[AFFINE2:.*]] = elliptic_curve.double %[[POINT1]] : ![[AF]] -> ![[JA:.*]]
270+
%doubled = elliptic_curve.double %point1 : !affine -> !jacobian
271+
return
272+
}
273+
274+
// CHECK-LABEL: @test_msm
275+
func.func @test_msm() {
276+
// CHECK: %[[VAR1:.*]] = field.pf.constant 1 : ![[PF:.*]]
277+
%var1 = field.pf.constant 1 : !PF
278+
// CHECK: %[[VAR5:.*]] = field.pf.constant 5 : ![[PF]]
279+
%var5 = field.pf.constant 5 : !PF
280+
281+
// CHECK: %[[SCALARS:.*]] = tensor.from_elements %[[VAR1]], %[[VAR5]], %[[VAR5]] : [[TPF:.*]]
282+
%scalars = tensor.from_elements %var1, %var5, %var5 : tensor<3x!PF>
283+
// CHECK: %[[AFFINE1:.*]] = elliptic_curve.point %[[VAR1]], %[[VAR5]] : ![[PF]] -> ![[AF:.*]]
284+
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
285+
// CHECK: %[[POINTS:.*]] = elliptic_curve.point_set.from_elements %[[AFFINE1]], %[[AFFINE1]], %[[AFFINE1]] : [[TAF:.*]]
286+
%points = elliptic_curve.point_set.from_elements %affine1, %affine1, %affine1 : tensor<3x!affine>
287+
// CHECK: %[[MSM_RESULT:.*]] = elliptic_curve.msm %[[SCALARS]], %[[POINTS]] : [[TPF]], [[TAF]] -> ![[JA:.*]]
288+
%msm_result = elliptic_curve.msm %scalars, %points : tensor<3x!PF>, tensor<3x!affine> -> !jacobian
289+
return
290+
}

tests/Dialect/EllipticCurve/elliptic_curve_to_field.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,30 @@ func.func @test_scalar_mul() {
176176
%xyzz3 = elliptic_curve.scalar_mul %var8, %xyzz1 : !PF, !xyzz -> !xyzz
177177
return
178178
}
179+
180+
// CHECK-LABEL: @test_point_set
181+
func.func @test_point_set() {
182+
%var1 = field.pf.constant 1 : !PF
183+
%var2 = field.pf.constant 2 : !PF
184+
%var4 = field.pf.constant 4 : !PF
185+
%var5 = field.pf.constant 5 : !PF
186+
%var8 = field.pf.constant 8 : !PF
187+
188+
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
189+
%points = elliptic_curve.point_set.from_elements %affine1, %affine1, %affine1 : tensor<3x!affine>
190+
%idx1 = arith.constant 1 : index
191+
%point1 = elliptic_curve.point_set.extract %points, %idx1 : tensor<3x!affine> -> !affine
192+
%doubled = elliptic_curve.double %point1 : !affine -> !jacobian
193+
return
194+
}
195+
196+
func.func @test_msm() {
197+
%var1 = field.pf.constant 1 : !PF
198+
%var5 = field.pf.constant 5 : !PF
199+
200+
%scalars = tensor.from_elements %var1, %var5, %var5 : tensor<3x!PF>
201+
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
202+
%points = elliptic_curve.point_set.from_elements %affine1, %affine1, %affine1 : tensor<3x!affine>
203+
%msm_result = elliptic_curve.msm %scalars, %points : tensor<3x!PF>, tensor<3x!affine> -> !jacobian
204+
return
205+
}

tests/Dialect/EllipticCurve/elliptic_curve_to_field_runner.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
// RUN: --shared-libs="%mlir_lib_dir/libmlir_runner_utils%shlibext" > %t
44
// RUN: FileCheck %s --check-prefix=CHECK_TEST_OPS_IN_ORDER < %t
55

6+
// RUN: zkir-opt %s -elliptic-curve-to-llvm \
7+
// RUN: | mlir-runner -e test_msm -entry-point-result=void \
8+
// RUN: --shared-libs="%mlir_lib_dir/libmlir_runner_utils%shlibext" > %t
9+
// RUN: FileCheck %s --check-prefix=CHECK_TEST_MSM < %t
10+
11+
612
!PF = !field.pf<11:i32>
713

814
#1 = #field.pf_elem<1:i32> : !PF
@@ -102,3 +108,50 @@ func.func @test_ops_in_order() {
102108
// CHECK_TEST_OPS_IN_ORDER: [0, 0, 0]
103109
// CHECK_TEST_OPS_IN_ORDER: [1, 1]
104110
// CHECK_TEST_OPS_IN_ORDER: [4, 3, 0, 0]
111+
112+
113+
// CHECK-LABEL: @test_msm
114+
func.func @test_msm() {
115+
// 5*(5,3,2) + 7*(1,2,2) + 3*(7,5,1) + 2*(3,2,7)
116+
%var1 = field.pf.constant 1 : !PF
117+
%var2 = field.pf.constant 2 : !PF
118+
%var3 = field.pf.constant 3 : !PF
119+
%var5 = field.pf.constant 5 : !PF
120+
%var7 = field.pf.constant 7 : !PF
121+
122+
%jacobian1 = elliptic_curve.point %var5, %var3, %var2 : !PF -> !jacobian
123+
%jacobian2 = elliptic_curve.point %var1, %var2, %var2 : !PF -> !jacobian
124+
%jacobian3 = elliptic_curve.point %var7, %var5, %var1 : !PF -> !jacobian
125+
%jacobian4 = elliptic_curve.point %var3, %var2, %var7 : !PF -> !jacobian
126+
127+
// CALCULATING TRUE VALUE OF MSM
128+
%scalar_mul1 = elliptic_curve.scalar_mul %var5, %jacobian1 : !PF, !jacobian -> !jacobian
129+
%scalar_mul2 = elliptic_curve.scalar_mul %var7, %jacobian2 : !PF, !jacobian -> !jacobian
130+
%scalar_mul3 = elliptic_curve.scalar_mul %var3, %jacobian3 : !PF, !jacobian -> !jacobian
131+
%scalar_mul4 = elliptic_curve.scalar_mul %var2, %jacobian4 : !PF, !jacobian -> !jacobian
132+
133+
%add1 = elliptic_curve.add %scalar_mul1, %scalar_mul2 : !jacobian, !jacobian -> !jacobian
134+
%add2 = elliptic_curve.add %scalar_mul3, %scalar_mul4 : !jacobian, !jacobian -> !jacobian
135+
%msm_true = elliptic_curve.add %add1, %add2 : !jacobian, !jacobian -> !jacobian
136+
137+
%extract_point = elliptic_curve.extract %msm_true : !jacobian -> tensor<3x!PF>
138+
%extract = field.pf.extract %extract_point : tensor<3x!PF> -> tensor<3xi32>
139+
%mem = bufferization.to_memref %extract : tensor<3xi32> to memref<3xi32>
140+
%U = memref.cast %mem : memref<3xi32> to memref<*xi32>
141+
func.call @printMemrefI32(%U) : (memref<*xi32>) -> ()
142+
143+
// RUNNING MSM
144+
%scalars = tensor.from_elements %var5, %var7, %var3, %var2 : tensor<4x!PF>
145+
%points = elliptic_curve.point_set.from_elements %jacobian1, %jacobian2, %jacobian3, %jacobian4 : tensor<4x!jacobian>
146+
%msm_test = elliptic_curve.msm %scalars , %points: tensor<4x!PF>, tensor<4x!jacobian> -> !jacobian
147+
148+
%extract_point1 = elliptic_curve.extract %msm_test : !jacobian -> tensor<3x!PF>
149+
%extract1 = field.pf.extract %extract_point1 : tensor<3x!PF> -> tensor<3xi32>
150+
%mem1 = bufferization.to_memref %extract1 : tensor<3xi32> to memref<3xi32>
151+
%U1 = memref.cast %mem1 : memref<3xi32> to memref<*xi32>
152+
func.call @printMemrefI32(%U1) : (memref<*xi32>) -> ()
153+
return
154+
}
155+
156+
// CHECK_TEST_MSM: [0, 0, 0]
157+
// CHECK_TEST_MSM: [0, 0, 0]

zkir/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ static mod_arith::ModArithType convertArithType(Type type) {
3636
}
3737

3838
static Type convertArithLikeType(ShapedType type) {
39-
if (auto arithType = llvm::dyn_cast<IntegerType>(type.getElementType())) {
39+
if (auto arithType = dyn_cast<IntegerType>(type.getElementType())) {
4040
return type.cloneWith(type.getShape(), convertArithType(arithType));
4141
}
4242
return type;

0 commit comments

Comments
 (0)