Skip to content

Commit ac2f281

Browse files
committed
feat(benchmark): in-place NTT benchmark
1 parent 24615d0 commit ac2f281

File tree

2 files changed

+54
-39
lines changed

2 files changed

+54
-39
lines changed

benchmark/ntt/ntt_benchmark.mlir

Lines changed: 17 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
!coeff_ty = !field.pf<21888242871839275222246405745257275088548364400416034343698204186575808495617 : i256>
22
!coefft_ty = tensor<1048576x!coeff_ty>
3+
!memref_ty = memref<1048576x!coeff_ty>
34

45
#root_elem = #field.pf_elem<17220337697351015657950521176323262483320249231368149235373741788599650842711:i256> : !coeff_ty
56
#root = #poly.primitive_root<root=#root_elem, degree=1048576:i256>
@@ -8,22 +9,26 @@
89
#mont = #mod_arith.montgomery<!mod>
910
#root_mont = #poly.primitive_root<root=#root_elem, degree=1048576:i256, montgomery=#mont>
1011

11-
func.func @ntt(%arg0 : !coefft_ty) -> !coefft_ty attributes { llvm.emit_c_interface } {
12-
%0 = poly.ntt %arg0 {root=#root} : !coefft_ty
13-
return %0 : !coefft_ty
12+
func.func @ntt(%arg0 : !memref_ty) attributes { llvm.emit_c_interface } {
13+
%t = bufferization.to_tensor %arg0 restrict writable : !memref_ty to !coefft_ty
14+
poly.ntt %t {root=#root} : !coefft_ty
15+
return
1416
}
1517

16-
func.func @intt(%arg0 : !coefft_ty) -> !coefft_ty attributes { llvm.emit_c_interface } {
17-
%0 = poly.intt %arg0 {root=#root} : !coefft_ty
18-
return %0 : !coefft_ty
18+
func.func @intt(%arg0 : !memref_ty) attributes { llvm.emit_c_interface } {
19+
%t = bufferization.to_tensor %arg0 restrict writable : !memref_ty to !coefft_ty
20+
poly.intt %t {root=#root} : !coefft_ty
21+
return
1922
}
2023

21-
func.func @ntt_mont(%arg0 : !coefft_ty) -> !coefft_ty attributes { llvm.emit_c_interface } {
22-
%0 = poly.ntt %arg0 {root=#root_mont} : !coefft_ty
23-
return %0 : !coefft_ty
24+
func.func @ntt_mont(%arg0 : !memref_ty) attributes { llvm.emit_c_interface } {
25+
%t = bufferization.to_tensor %arg0 restrict writable : !memref_ty to !coefft_ty
26+
poly.ntt %t {root=#root_mont} : !coefft_ty
27+
return
2428
}
2529

26-
func.func @intt_mont(%arg0 : !coefft_ty) -> !coefft_ty attributes { llvm.emit_c_interface } {
27-
%0 = poly.intt %arg0 {root=#root_mont} : !coefft_ty
28-
return %0 : !coefft_ty
30+
func.func @intt_mont(%arg0 : !memref_ty) attributes { llvm.emit_c_interface } {
31+
%t = bufferization.to_tensor %arg0 restrict writable : !memref_ty to !coefft_ty
32+
poly.intt %t {root=#root_mont} : !coefft_ty
33+
return
2934
}

benchmark/ntt/ntt_benchmark_test.cc

Lines changed: 37 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -29,29 +29,29 @@ static void fillWithRandom(Memref<i256> *input, const i256 &kPrime) {
2929
}
3030
}
3131

32-
extern "C" void _mlir_ciface_ntt(Memref<i256> *output, Memref<i256> *input);
33-
extern "C" void _mlir_ciface_intt(Memref<i256> *output, Memref<i256> *input);
32+
extern "C" void _mlir_ciface_ntt(Memref<i256> *buffer);
33+
extern "C" void _mlir_ciface_intt(Memref<i256> *buffer);
3434

35-
extern "C" void _mlir_ciface_ntt_mont(Memref<i256> *output,
36-
Memref<i256> *input);
37-
extern "C" void _mlir_ciface_intt_mont(Memref<i256> *output,
38-
Memref<i256> *input);
35+
extern "C" void _mlir_ciface_ntt_mont(Memref<i256> *buffer);
36+
extern "C" void _mlir_ciface_intt_mont(Memref<i256> *buffer);
3937

4038
void BM_ntt_benchmark(::benchmark::State &state) {
4139
Memref<i256> input(NUM_COEFFS, 1);
4240
fillWithRandom(&input, kPrime);
4341

4442
Memref<i256> ntt(NUM_COEFFS, 1);
4543
for (auto _ : state) {
46-
_mlir_ciface_ntt(&ntt, &input);
44+
state.PauseTiming();
45+
memcpy(ntt.pget(0, 0), input.pget(0, 0), sizeof(i256) * NUM_COEFFS);
46+
state.ResumeTiming();
47+
_mlir_ciface_ntt(&ntt);
4748
}
4849

49-
Memref<i256> intt(NUM_COEFFS, 1);
50-
_mlir_ciface_intt(&intt, &ntt);
50+
_mlir_ciface_intt(&ntt);
5151

5252
for (int i = 0; i < NUM_COEFFS; i++) {
5353
for (int j = 0; j < 4; j++) {
54-
EXPECT_EQ(intt.pget(i, 0)->limbs[j], input.pget(i, 0)->limbs[j]);
54+
EXPECT_EQ(ntt.pget(i, 0)->limbs[j], input.pget(i, 0)->limbs[j]);
5555
}
5656
}
5757
}
@@ -63,16 +63,20 @@ void BM_intt_benchmark(::benchmark::State &state) {
6363
fillWithRandom(&input, kPrime);
6464

6565
Memref<i256> ntt(NUM_COEFFS, 1);
66-
_mlir_ciface_ntt(&ntt, &input);
66+
memcpy(ntt.pget(0, 0), input.pget(0, 0), sizeof(i256) * NUM_COEFFS);
67+
_mlir_ciface_ntt(&ntt);
6768

6869
Memref<i256> intt(NUM_COEFFS, 1);
6970
for (auto _ : state) {
70-
_mlir_ciface_intt(&intt, &ntt);
71+
state.PauseTiming();
72+
memcpy(intt.pget(0, 0), ntt.pget(0, 0), sizeof(i256) * NUM_COEFFS);
73+
state.ResumeTiming();
74+
_mlir_ciface_intt(&ntt);
7175
}
7276

7377
for (int i = 0; i < NUM_COEFFS; i++) {
7478
for (int j = 0; j < 4; j++) {
75-
EXPECT_EQ(intt.pget(i, 0)->limbs[j], input.pget(i, 0)->limbs[j]);
79+
EXPECT_EQ(ntt.pget(i, 0)->limbs[j], input.pget(i, 0)->limbs[j]);
7680
}
7781
}
7882
}
@@ -85,15 +89,17 @@ void BM_ntt_mont_benchmark(::benchmark::State &state) {
8589

8690
Memref<i256> ntt(NUM_COEFFS, 1);
8791
for (auto _ : state) {
88-
_mlir_ciface_ntt_mont(&ntt, &input);
92+
state.PauseTiming();
93+
memcpy(ntt.pget(0, 0), input.pget(0, 0), sizeof(i256) * NUM_COEFFS);
94+
state.ResumeTiming();
95+
_mlir_ciface_ntt_mont(&ntt);
8996
}
9097

91-
Memref<i256> intt(NUM_COEFFS, 1);
92-
_mlir_ciface_intt_mont(&intt, &ntt);
98+
_mlir_ciface_intt_mont(&ntt);
9399

94100
for (int i = 0; i < NUM_COEFFS; i++) {
95101
for (int j = 0; j < 4; j++) {
96-
EXPECT_EQ(intt.pget(i, 0)->limbs[j], input.pget(i, 0)->limbs[j]);
102+
EXPECT_EQ(ntt.pget(i, 0)->limbs[j], input.pget(i, 0)->limbs[j]);
97103
}
98104
}
99105
}
@@ -105,11 +111,15 @@ void BM_intt_mont_benchmark(::benchmark::State &state) {
105111
fillWithRandom(&input, kPrime);
106112

107113
Memref<i256> ntt(NUM_COEFFS, 1);
108-
_mlir_ciface_ntt_mont(&ntt, &input);
114+
memcpy(ntt.pget(0, 0), input.pget(0, 0), sizeof(i256) * NUM_COEFFS);
115+
_mlir_ciface_ntt_mont(&ntt);
109116

110117
Memref<i256> intt(NUM_COEFFS, 1);
111118
for (auto _ : state) {
112-
_mlir_ciface_intt_mont(&intt, &ntt);
119+
state.PauseTiming();
120+
memcpy(intt.pget(0, 0), ntt.pget(0, 0), sizeof(i256) * NUM_COEFFS);
121+
state.ResumeTiming();
122+
_mlir_ciface_intt_mont(&intt);
113123
}
114124

115125
for (int i = 0; i < NUM_COEFFS; i++) {
@@ -131,12 +141,12 @@ BENCHMARK(BM_intt_mont_benchmark)->Unit(::benchmark::kMillisecond);
131141
// L1 Data 64 KiB
132142
// L1 Instruction 128 KiB
133143
// L2 Unified 4096 KiB (x14)
134-
// Load Average: 6.49, 5.64, 5.49
135-
// -------------------------------------------------------------------------
136-
// Benchmark Time CPU Iterations
137-
// -------------------------------------------------------------------------
138-
// BM_ntt_benchmark 1656 ms 1050 ms 1
139-
// BM_intt_benchmark/iterations:1 1791 ms 1090 ms 1
140-
// BM_ntt_mont_benchmark 38.6 ms 18.6 ms 40
141-
// BM_intt_mont_benchmark 99.4 ms 56.4 ms 11
144+
// Load Average: 8.66, 7.19, 7.37
145+
// -----------------------------------------------------------------
146+
// Benchmark Time CPU Iterations
147+
// -----------------------------------------------------------------
148+
// BM_ntt_benchmark 1603 ms 1085 ms 1
149+
// BM_intt_benchmark 1585 ms 1120 ms 1
150+
// BM_ntt_mont_benchmark 34.7 ms 16.8 ms 42
151+
// BM_intt_mont_benchmark 33.8 ms 16.6 ms 42
142152
// NOLINTEND()

0 commit comments

Comments
 (0)