Skip to content

Commit e9243ad

Browse files
authored
Merge pull request #15 from a41-official/feat/impl-mont-mul
feat: implement montgomery multiplication
2 parents e65a9ff + f1f25f5 commit e9243ad

30 files changed

+1081
-120
lines changed

benchmark/benchmark.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,15 @@ def zkir_benchmark_test(name, mlir_src, test_src, zkir_opt_flags = [], data = []
124124
name = llvmir_target,
125125
src = generated_zkir_opt_name,
126126
pass_flags = ["--mlir-to-llvmir"],
127+
tags = tags,
127128
generated_filename = generated_llvmir_name,
128129
)
129130

130131
llc(
131132
name = obj_name,
132133
src = generated_llvmir_name,
133134
pass_flags = ["-relocation-model=pic", "-filetype=obj"],
135+
tags = tags,
134136
generated_filename = generated_obj_name,
135137
)
136138
cc_import(

benchmark/field/BUILD.bazel

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
load("//benchmark:benchmark.bzl", "zkir_benchmark_test")
2+
3+
# Benchmark target for the kernels in mul_benchmark.mlir, driven by
# mul_benchmark_test.cc and lowered with the -poly-to-llvm flag.
# The "manual" tag keeps it out of wildcard target patterns, so it only runs
# when it is requested explicitly.
zkir_benchmark_test(
4+
name = "mul_benchmark_test",
5+
mlir_src = "mul_benchmark.mlir",
6+
tags = ["manual"],
7+
test_src = ["mul_benchmark_test.cc"],
8+
zkir_opt_flags = ["-poly-to-llvm"],
9+
deps = [
10+
"//benchmark:BenchmarkUtils",
11+
],
12+
)

benchmark/field/mul_benchmark.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
!mod = !mod_arith.int<21888242871839275222246405745257275088548364400416034343698204186575808495617 : i256>
2+
!F = !field.pf<21888242871839275222246405745257275088548364400416034343698204186575808495617 : i256>
3+
#mont = #mod_arith.montgomery<!mod>
4+
5+
// Wraps %arg0 as an !F element, multiplies it by itself twice with
// `field.pf.mul` (i.e. computes its cube in !F), and returns the raw i256.
func.func @mul(%arg0 : i256) -> i256 attributes { llvm.emit_c_interface } {
6+
%0 = field.pf.encapsulate %arg0 : i256 -> !F
7+
%1 = field.pf.mul %0, %0 : !F
8+
%2 = field.pf.mul %0, %1 : !F
9+
%3 = field.pf.extract %2 : !F -> i256
10+
return %3 : i256
11+
}
12+
13+
// Same shape as @mul, but both products use `field.pf.mont_mul` with the
// #mont attribute.
// NOTE(review): %arg0 is encapsulated without a to_mont conversion and the
// result is extracted without from_mont, so the value is presumably not
// comparable to @mul's output — confirm this is intended for timing only.
func.func @mont_mul(%arg0 : i256) -> i256 attributes { llvm.emit_c_interface } {
14+
%0 = field.pf.encapsulate %arg0 : i256 -> !F
15+
%1 = field.pf.mont_mul %0, %0 {montgomery = #mont} : !F
16+
%2 = field.pf.mont_mul %0, %1 {montgomery = #mont} : !F
17+
%3 = field.pf.extract %2 : !F -> i256
18+
return %3 : i256
19+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#include "benchmark/BenchmarkUtils.h"
2+
#include "benchmark/benchmark.h"
3+
#include "gtest/gtest.h"
4+
5+
namespace zkir {
6+
namespace {
7+
8+
using ::zkir::benchmark::Memref;
9+
10+
// 256-bit value stored as four 64-bit limbs; used as the element type of the
// Memref buffers passed to the generated kernels in this file.
// NOTE(review): limb order (least-significant first?) is not stated here —
// confirm against the layout the lowered i256 uses.
struct i256 {
11+
uint64_t limbs[4]; // 4 x 64 = 256 bits
12+
};
13+
14+
// Prototypes for the generated kernels; presumably these are the C-interface
// wrappers of `@mul` and `@mont_mul` from mul_benchmark.mlir, which both carry
// `llvm.emit_c_interface`.
// NOTE(review): those MLIR functions are declared as `(i256) -> i256`, while
// these prototypes pass `Memref<i256> *` for both output and input — confirm
// that the calling convention really matches.
extern "C" void _mlir_ciface_mul(Memref<i256> *output, Memref<i256> *input);
15+
extern "C" void _mlir_ciface_mont_mul(Memref<i256> *output,
16+
Memref<i256> *input);
17+
18+
// Benchmark body for `_mlir_ciface_mul`: fills one input element with a fixed
// limb pattern, then calls the kernel once per benchmark iteration.
void BM_mul_benchmark(::benchmark::State &state) {
19+
Memref<i256> input(1, 1);
20+
21+
// All four limbs get the same constant, so every run sees identical input.
input.pget(0, 0)->limbs[0] = 0x0032131ffffffffff;
22+
input.pget(0, 0)->limbs[1] = 0x0032131ffffffffff;
23+
input.pget(0, 0)->limbs[2] = 0x0032131ffffffffff;
24+
input.pget(0, 0)->limbs[3] = 0x0032131ffffffffff;
25+
26+
// The output buffer is constructed once, outside the timed loop.
Memref<i256> output(1, 1);
27+
for (auto _ : state) {
28+
_mlir_ciface_mul(&output, &input);
29+
}
30+
}
31+
32+
BENCHMARK(BM_mul_benchmark);
33+
34+
// Benchmark body for `_mlir_ciface_mont_mul`; uses the same fixed input limbs
// as BM_mul_benchmark and calls the kernel once per benchmark iteration.
void BM_mont_mul_benchmark(::benchmark::State &state) {
35+
Memref<i256> input(1, 1);
36+
37+
// All four limbs get the same constant, so every run sees identical input.
input.pget(0, 0)->limbs[0] = 0x0032131ffffffffff;
38+
input.pget(0, 0)->limbs[1] = 0x0032131ffffffffff;
39+
input.pget(0, 0)->limbs[2] = 0x0032131ffffffffff;
40+
input.pget(0, 0)->limbs[3] = 0x0032131ffffffffff;
41+
42+
// The output buffer is constructed once, outside the timed loop.
Memref<i256> mont_output(1, 1);
43+
for (auto _ : state) {
44+
_mlir_ciface_mont_mul(&mont_output, &input);
45+
}
46+
}
47+
48+
BENCHMARK(BM_mont_mul_benchmark);
49+
50+
} // namespace
51+
} // namespace zkir
52+
53+
// Run on (14 X 24 MHz CPU s)
54+
// CPU Caches:
55+
// L1 Data 64 KiB
56+
// L1 Instruction 128 KiB
57+
// L2 Unified 4096 KiB (x14)
58+
// Load Average: 7.70, 6.06, 6.06
59+
// ----------------------------------------------------------------
60+
// Benchmark Time CPU Iterations
61+
// ----------------------------------------------------------------
62+
// BM_mul_benchmark 2575 ns 2457 ns 294375
63+
// BM_mont_mul_benchmark 30.9 ns 30.2 ns 23041778

benchmark/ntt/ntt_benchmark.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
#root_elem = #field.pf_elem<17220337697351015657950521176323262483320249231368149235373741788599650842711:i256> : !coeff_ty
77
#root = #poly.primitive_root<root=#root_elem, degree=1048576:i256>
88

9+
!mod = !mod_arith.int<21888242871839275222246405745257275088548364400416034343698204186575808495617 : i256>
10+
#mont = #mod_arith.montgomery<!mod>
11+
#root_mont = #poly.primitive_root<root=#root_elem, degree=1048576:i256, montgomery=#mont>
12+
913
func.func @input_generation() -> !poly_ty attributes { llvm.emit_c_interface } {
1014
%c42 = arith.constant 6420 : i256
1115
%full = tensor.splat %c42 : !intt_ty
@@ -25,3 +29,15 @@ func.func @intt(%arg0 : !intt_ty) -> !poly_ty attributes { llvm.emit_c_interface
2529
%1 = poly.intt %0 {root=#root} : !coefft_ty -> !poly_ty
2630
return %1 :!poly_ty
2731
}
32+
33+
// Applies `poly.ntt` with #root_mont (the primitive root that carries
// montgomery=#mont), then extracts the result to !intt_ty.
func.func @ntt_mont(%arg0 : !poly_ty) -> !intt_ty attributes { llvm.emit_c_interface } {
34+
%0 = poly.ntt %arg0 {root=#root_mont} : !poly_ty -> !coefft_ty
35+
%1 = field.pf.extract %0 : !coefft_ty -> !intt_ty
36+
return %1 : !intt_ty
37+
}
38+
39+
// Counterpart of @ntt_mont: encapsulates the !intt_ty input as !coefft_ty and
// applies `poly.intt` with #root_mont to get back a !poly_ty.
func.func @intt_mont(%arg0 : !intt_ty) -> !poly_ty attributes { llvm.emit_c_interface } {
40+
%0 = field.pf.encapsulate %arg0 : !intt_ty -> !coefft_ty
41+
%1 = poly.intt %0 {root=#root_mont} : !coefft_ty -> !poly_ty
42+
return %1 :!poly_ty
43+
}

benchmark/ntt/ntt_benchmark_test.cc

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ extern "C" void _mlir_ciface_input_generation(Memref<i256> *output);
1717
extern "C" void _mlir_ciface_ntt(Memref<i256> *output, Memref<i256> *input);
1818
extern "C" void _mlir_ciface_intt(Memref<i256> *output, Memref<i256> *input);
1919

20+
// Prototypes for the Montgomery NTT/INTT kernels; presumably the C-interface
// wrappers of `@ntt_mont` and `@intt_mont` from ntt_benchmark.mlir. They use
// the same (output, input) Memref pointer signature as the non-Montgomery
// `_mlir_ciface_ntt` / `_mlir_ciface_intt` declarations in this file.
extern "C" void _mlir_ciface_ntt_mont(Memref<i256> *output,
21+
Memref<i256> *input);
22+
extern "C" void _mlir_ciface_intt_mont(Memref<i256> *output,
23+
Memref<i256> *input);
24+
2025
void BM_ntt_benchmark(::benchmark::State &state) {
2126
Memref<i256> input(1, DEGREE);
2227
_mlir_ciface_input_generation(&input);
@@ -61,17 +66,66 @@ void BM_intt_benchmark(::benchmark::State &state) {
6166
// modifying the input. But I am not sure why ;(
6267
BENCHMARK(BM_intt_benchmark)->Iterations(1)->Unit(::benchmark::kSecond);
6368

69+
// Times `_mlir_ciface_ntt_mont` on a generated input, then checks the result
// by running `_mlir_ciface_intt_mont` once and comparing against the input.
void BM_ntt_mont_benchmark(::benchmark::State &state) {
70+
Memref<i256> input(1, DEGREE);
71+
_mlir_ciface_input_generation(&input);
72+
73+
Memref<i256> ntt(1, DEGREE);
74+
// Timed region: only the forward transform is measured.
for (auto _ : state) {
75+
_mlir_ciface_ntt_mont(&ntt, &input);
76+
}
77+
78+
Memref<i256> intt(1, DEGREE);
79+
_mlir_ciface_intt_mont(&intt, &ntt);
80+
81+
// Round-trip check: every limb of every element must equal the input.
for (int i = 0; i < DEGREE; i++) {
82+
for (int j = 0; j < 4; j++) {
83+
EXPECT_EQ(intt.pget(0, i)->limbs[j], input.pget(0, i)->limbs[j]);
84+
}
85+
}
86+
}
87+
88+
BENCHMARK(BM_ntt_mont_benchmark)->Unit(::benchmark::kSecond);
89+
90+
// Times `_mlir_ciface_intt_mont` on the output of one untimed
// `_mlir_ciface_ntt_mont` call, then checks the result against the input.
void BM_intt_mont_benchmark(::benchmark::State &state) {
91+
Memref<i256> input(1, DEGREE);
92+
_mlir_ciface_input_generation(&input);
93+
94+
Memref<i256> ntt(1, DEGREE);
95+
// Setup only: the forward transform runs once, outside the timed loop.
_mlir_ciface_ntt_mont(&ntt, &input);
96+
97+
Memref<i256> intt(1, DEGREE);
98+
// Timed region: only the inverse transform is measured.
for (auto _ : state) {
99+
_mlir_ciface_intt_mont(&intt, &ntt);
100+
}
101+
102+
// Round-trip check: every limb of every element must equal the input.
for (int i = 0; i < DEGREE; i++) {
103+
for (int j = 0; j < 4; j++) {
104+
EXPECT_EQ(intt.pget(0, i)->limbs[j], input.pget(0, i)->limbs[j]);
105+
}
106+
}
107+
}
108+
109+
// FIXME(batzor): This fails with more than 1 iteration, so something seems to
110+
// be modifying the input. The root cause is still unknown.
111+
BENCHMARK(BM_intt_mont_benchmark)->Iterations(1)->Unit(::benchmark::kSecond);
112+
64113
} // namespace
65114
} // namespace zkir
66115

116+
// clang-format off
117+
// NOLINTBEGIN(whitespace/line_length)
67118
// Run on (14 X 24 MHz CPU s)
68119
// CPU Caches:
69-
// L1 Data 64 KiB
70-
// L1 Instruction 128 KiB
71-
// L2 Unified 4096 KiB (x14)
72-
// Load Average: 22.54, 38.87, 26.62
73-
// -------------------------------------------------------------------------
74-
// Benchmark Time CPU Iterations
75-
// -------------------------------------------------------------------------
76-
// BM_ntt_benchmark 0.321 s 0.320 s 2
77-
// BM_intt_benchmark/iterations:1 0.475 s 0.473 s 1
120+
// L1 Data 64 KiB
121+
// L1 Instruction 128 KiB
122+
// L2 Unified 4096 KiB (x14)
123+
// Load Average: 27.66, 13.59, 9.67
124+
// ------------------------------------------------------------------------------
125+
// Benchmark Time CPU Iterations
126+
// ------------------------------------------------------------------------------
127+
// BM_ntt_benchmark 0.190 s 0.183 s 4
128+
// BM_intt_benchmark/iterations:1 0.381 s 0.368 s 1
129+
// BM_ntt_mont_benchmark 0.221 s 0.214 s 3
130+
// BM_intt_mont_benchmark/iterations:1 0.415 s 0.396 s 1
131+
// NOLINTEND()

tests/Dialect/Field/prime_field_to_mod_arith.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
!PFv = tensor<4x!PF1>
44
#elem = #field.pf_elem<31:i32> : !PF1
55

6+
!mod = !mod_arith.int<3 : i32>
7+
#mont = #mod_arith.montgomery<!mod>
8+
69
// CHECK-LABEL: @test_lower_constant
710
// CHECK-SAME: () -> [[T:.*]] {
811
func.func @test_lower_constant() -> !PF1 {
@@ -52,6 +55,46 @@ func.func @test_lower_extract_vec(%lhs : tensor<4x!PF1>) -> tensor<4xi32> {
5255
return %res : tensor<4xi32>
5356
}
5457

58+
// CHECK-LABEL: @test_lower_to_mont
59+
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
60+
// Scalar case: `field.pf.to_mont` on !PF1 must be replaced by
// `mod_arith.to_mont` carrying a montgomery attribute.
func.func @test_lower_to_mont(%lhs : !PF1) -> !PF1 {
61+
// CHECK-NOT: field.pf.to_mont
62+
// CHECK: %[[RES:.*]] = mod_arith.to_mont %[[LHS]] {montgomery = #mod_arith.montgomery<[[T]]>} : [[T]]
63+
%res = field.pf.to_mont %lhs {montgomery=#mont} : !PF1
64+
// CHECK: return %[[RES]] : [[T]]
65+
return %res : !PF1
66+
}
67+
68+
// CHECK-LABEL: @test_lower_to_mont_vec
69+
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
70+
// Tensor case: same lowering as the scalar test, applied to the !PFv operand.
func.func @test_lower_to_mont_vec(%lhs : !PFv) -> !PFv {
71+
// CHECK-NOT: field.pf.to_mont
72+
// CHECK: %[[RES:.*]] = mod_arith.to_mont %[[LHS]] {montgomery = #mod_arith.montgomery<[[E:.*]]>} : [[T]]
73+
%res = field.pf.to_mont %lhs {montgomery=#mont} : !PFv
74+
// CHECK: return %[[RES]] : [[T]]
75+
return %res : !PFv
76+
}
77+
78+
// CHECK-LABEL: @test_lower_from_mont
79+
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
80+
// Scalar case: `field.pf.from_mont` on !PF1 must be replaced by
// `mod_arith.from_mont` carrying a montgomery attribute.
func.func @test_lower_from_mont(%lhs : !PF1) -> !PF1 {
81+
// CHECK-NOT: field.pf.from_mont
82+
// CHECK: %[[RES:.*]] = mod_arith.from_mont %[[LHS]] {montgomery = #mod_arith.montgomery<[[T]]>} : [[T]]
83+
%res = field.pf.from_mont %lhs {montgomery=#mont} : !PF1
84+
// CHECK: return %[[RES]] : [[T]]
85+
return %res : !PF1
86+
}
87+
88+
// CHECK-LABEL: @test_lower_from_mont_vec
89+
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
90+
// Tensor case: same lowering as the scalar test, applied to the !PFv operand.
func.func @test_lower_from_mont_vec(%lhs : !PFv) -> !PFv {
91+
// CHECK-NOT: field.pf.from_mont
92+
// CHECK: %[[RES:.*]] = mod_arith.from_mont %[[LHS]] {montgomery = #mod_arith.montgomery<[[E:.*]]>} : [[T]]
93+
%res = field.pf.from_mont %lhs {montgomery=#mont} : !PFv
94+
// CHECK: return %[[RES]] : [[T]]
95+
return %res : !PFv
96+
}
97+
5598
// CHECK-LABEL: @test_lower_inverse
5699
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
57100
func.func @test_lower_inverse(%lhs : !PF1) -> !PF1 {
@@ -135,6 +178,27 @@ func.func @test_lower_mul_vec(%lhs : !PFv, %rhs : !PFv) -> !PFv {
135178
return %res : !PFv
136179
}
137180

181+
// CHECK-LABEL: @test_lower_mont_mul
182+
// CHECK-SAME: () -> [[T:.*]] {
183+
// Scalar case: a `field.pf.constant` squared with `field.pf.mont_mul` must
// lower to `mod_arith.constant` followed by `mod_arith.mont_mul`.
func.func @test_lower_mont_mul() -> !PF1 {
184+
// CHECK: %[[C0:.*]] = mod_arith.constant 2 : [[T]]
185+
%c0 = field.pf.constant 2 : !PF1
186+
// CHECK: %[[RES:.*]] = mod_arith.mont_mul %[[C0]], %[[C0]] {montgomery = #mod_arith.montgomery<[[T]]>} : [[T]]
187+
%res = field.pf.mont_mul %c0, %c0 {montgomery = #mont} : !PF1
188+
// CHECK: return %[[RES]] : [[T]]
189+
return %res : !PF1
190+
}
191+
192+
// CHECK-LABEL: @test_lower_mont_mul_vec
193+
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] {
194+
// Tensor case: `field.pf.mont_mul` on two !PFv operands must be replaced by
// `mod_arith.mont_mul` on the corresponding lowered operands.
func.func @test_lower_mont_mul_vec(%lhs : !PFv, %rhs : !PFv) -> !PFv {
195+
// CHECK-NOT: field.pf.mont_mul
196+
// CHECK: %[[RES:.*]] = mod_arith.mont_mul %[[LHS]], %[[RHS]] {montgomery = #mod_arith.montgomery<[[E:.*]]>} : [[T]]
197+
%res = field.pf.mont_mul %lhs, %rhs {montgomery = #mont} : !PFv
198+
// CHECK: return %[[RES]] : [[T]]
199+
return %res : !PFv
200+
}
201+
138202
// CHECK-LABEL: @test_lower_constant_tensor
139203
// CHECK-SAME: () -> [[T:.*]] {
140204
func.func @test_lower_constant_tensor() -> !PFv {

tests/Dialect/ModArith/mod_arith_runner.mlir

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
1-
// RUN: zkir-opt %s --mod-arith-to-arith -convert-elementwise-to-linalg --one-shot-bufferize --convert-scf-to-cf --convert-cf-to-llvm --convert-to-llvm \
1+
// RUN: zkir-opt %s --mod-arith-to-arith -convert-elementwise-to-linalg --one-shot-bufferize --convert-scf-to-cf --convert-cf-to-llvm --convert-to-llvm --convert-vector-to-llvm \
22
// RUN: | mlir-runner -e test_lower_inverse -entry-point-result=void \
33
// RUN: --shared-libs="%mlir_lib_dir/libmlir_runner_utils%shlibext" > %t
44
// RUN: FileCheck %s --check-prefix=CHECK_TEST_INVERSE < %t
55

6+
// RUN: zkir-opt %s --mod-arith-to-arith -convert-elementwise-to-linalg --one-shot-bufferize --convert-scf-to-cf --convert-cf-to-llvm --convert-to-llvm --convert-vector-to-llvm \
7+
// RUN: | mlir-runner -e test_lower_mont_reduce -entry-point-result=void \
8+
// RUN: --shared-libs="%mlir_lib_dir/libmlir_runner_utils%shlibext" > %t
9+
// RUN: FileCheck %s --check-prefix=CHECK_TEST_MONT_REDUCE < %t
10+
11+
// RUN: zkir-opt %s --mod-arith-to-arith -convert-elementwise-to-linalg --one-shot-bufferize --convert-scf-to-cf --convert-cf-to-llvm --convert-to-llvm --convert-vector-to-llvm \
12+
// RUN: | mlir-runner -e test_lower_mont_mul -entry-point-result=void \
13+
// RUN: --shared-libs="%mlir_lib_dir/libmlir_runner_utils%shlibext" > %t
14+
// RUN: FileCheck %s --check-prefix=CHECK_TEST_MONT_MUL < %t
15+
616
!Fr = !mod_arith.int<2147483647:i32>
717

818
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }
@@ -20,3 +30,44 @@ func.func @test_lower_inverse() {
2030
}
2131

2232
// CHECK_TEST_INVERSE: [1324944920]
33+
34+
!Fq = !mod_arith.int<21888242871839275222246405745257275088548364400416034343698204186575808495617 : i256>
35+
#Fq_mont = #mod_arith.montgomery<!Fq>
36+
37+
// Runs `mod_arith.mont_reduce` on the i512 constant 3723 and prints the !Fq
// result as eight i32 words through printMemrefI32.
func.func @test_lower_mont_reduce() {
38+
%p = arith.constant 3723 : i512
39+
%p_mont = mod_arith.mont_reduce %p {montgomery=#Fq_mont} : i512 -> !Fq
40+
41+
// Reinterpret the single i256 as 8 x i32 so the i32 printer can be used.
%2 = mod_arith.extract %p_mont : !Fq -> i256
42+
%3 = vector.from_elements %2 : vector<1xi256>
43+
%4 = vector.bitcast %3 : vector<1xi256> to vector<8xi32>
44+
%mem = memref.alloc() : memref<8xi32>
45+
%idx_0 = arith.constant 0 : index
46+
vector.store %4, %mem[%idx_0] : memref<8xi32>, vector<8xi32>
47+
48+
%U = memref.cast %mem : memref<8xi32> to memref<*xi32>
49+
func.call @printMemrefI32(%U) : (memref<*xi32>) -> ()
50+
return
51+
}
52+
53+
// CHECK_TEST_MONT_REDUCE: [-1635059004, -1772563805, -2074116324, -156049350, 156881531, -524227392, -1359481138, 438709201]
54+
55+
// Squares an !Fq constant through the Montgomery path (to_mont, mont_mul,
// from_mont) and prints the result as eight i32 words through printMemrefI32.
func.func @test_lower_mont_mul() {
56+
%p = mod_arith.constant 17221657567640823606390383439573883756117969501024189775361 : !Fq
57+
%p_mont = mod_arith.to_mont %p {montgomery=#Fq_mont} : !Fq
58+
%p_mont_sq = mod_arith.mont_mul %p_mont, %p_mont {montgomery=#Fq_mont} : !Fq
59+
%p_sq = mod_arith.from_mont %p_mont_sq {montgomery=#Fq_mont} : !Fq
60+
61+
// Reinterpret the single i256 as 8 x i32 so the i32 printer can be used.
%2 = mod_arith.extract %p_sq : !Fq -> i256
62+
%3 = vector.from_elements %2 : vector<1xi256>
63+
%4 = vector.bitcast %3 : vector<1xi256> to vector<8xi32>
64+
%mem = memref.alloc() : memref<8xi32>
65+
%idx_0 = arith.constant 0 : index
66+
vector.store %4, %mem[%idx_0] : memref<8xi32>, vector<8xi32>
67+
68+
%U = memref.cast %mem : memref<8xi32> to memref<*xi32>
69+
func.call @printMemrefI32(%U) : (memref<*xi32>) -> ()
70+
return
71+
}
72+
73+
// CHECK_TEST_MONT_MUL: [-1717936988, -857005375, 1976922116, -1939796685, 1587159113, 557631023, 126776667, 742573744]

0 commit comments

Comments
 (0)