Skip to content

Commit 4af58a4

Browse files
authored
Merge pull request #18 from a41-official/perf/optimize-memory-usage
perf: optimize memory usage
2 parents dc44ad9 + bd3ab9b commit 4af58a4

File tree

18 files changed

+316
-213
lines changed

18 files changed

+316
-213
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ jobs:
2828
with:
2929
bazelisk-cache: true
3030
disk-cache: ${{ runner.os }}-zkir_bazelbuild
31-
repository-cache: false
32-
external-cache: false
31+
repository-cache: true
32+
external-cache: true
3333

3434
- name: Run `bazel build`
3535
run: |

benchmark/ntt/ntt_benchmark.mlir

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
!coeff_ty = !field.pf<21888242871839275222246405745257275088548364400416034343698204186575808495617 : i256>
2-
!poly_ty = !poly.polynomial<!coeff_ty, 1048575>
32
!coefft_ty = tensor<1048576x!coeff_ty>
4-
!intt_ty = tensor<1048576xi256>
3+
!memref_ty = memref<1048576x!coeff_ty>
54

65
#root_elem = #field.pf_elem<17220337697351015657950521176323262483320249231368149235373741788599650842711:i256> : !coeff_ty
76
#root = #poly.primitive_root<root=#root_elem, degree=1048576:i256>
@@ -10,34 +9,26 @@
109
#mont = #mod_arith.montgomery<!mod>
1110
#root_mont = #poly.primitive_root<root=#root_elem, degree=1048576:i256, montgomery=#mont>
1211

13-
func.func @input_generation() -> !poly_ty attributes { llvm.emit_c_interface } {
14-
%c42 = arith.constant 6420 : i256
15-
%full = tensor.splat %c42 : !intt_ty
16-
%coeffs = field.pf.encapsulate %full : !intt_ty -> !coefft_ty
17-
%poly = poly.from_tensor %coeffs : !coefft_ty -> !poly_ty
18-
return %poly : !poly_ty
12+
func.func @ntt(%arg0 : !memref_ty) attributes { llvm.emit_c_interface } {
13+
%t = bufferization.to_tensor %arg0 restrict writable : !memref_ty to !coefft_ty
14+
poly.ntt %t {root=#root} : !coefft_ty
15+
return
1916
}
2017

21-
func.func @ntt(%arg0 : !poly_ty) -> !intt_ty attributes { llvm.emit_c_interface } {
22-
%0 = poly.ntt %arg0 {root=#root} : !poly_ty -> !coefft_ty
23-
%1 = field.pf.extract %0 : !coefft_ty -> !intt_ty
24-
return %1 : !intt_ty
18+
func.func @intt(%arg0 : !memref_ty) attributes { llvm.emit_c_interface } {
19+
%t = bufferization.to_tensor %arg0 restrict writable : !memref_ty to !coefft_ty
20+
poly.intt %t {root=#root} : !coefft_ty
21+
return
2522
}
2623

27-
func.func @intt(%arg0 : !intt_ty) -> !poly_ty attributes { llvm.emit_c_interface } {
28-
%0 = field.pf.encapsulate %arg0 : !intt_ty -> !coefft_ty
29-
%1 = poly.intt %0 {root=#root} : !coefft_ty -> !poly_ty
30-
return %1 :!poly_ty
24+
func.func @ntt_mont(%arg0 : !memref_ty) attributes { llvm.emit_c_interface } {
25+
%t = bufferization.to_tensor %arg0 restrict writable : !memref_ty to !coefft_ty
26+
poly.ntt %t {root=#root_mont} : !coefft_ty
27+
return
3128
}
3229

33-
func.func @ntt_mont(%arg0 : !poly_ty) -> !intt_ty attributes { llvm.emit_c_interface } {
34-
%0 = poly.ntt %arg0 {root=#root_mont} : !poly_ty -> !coefft_ty
35-
%1 = field.pf.extract %0 : !coefft_ty -> !intt_ty
36-
return %1 : !intt_ty
37-
}
38-
39-
func.func @intt_mont(%arg0 : !intt_ty) -> !poly_ty attributes { llvm.emit_c_interface } {
40-
%0 = field.pf.encapsulate %arg0 : !intt_ty -> !coefft_ty
41-
%1 = poly.intt %0 {root=#root_mont} : !coefft_ty -> !poly_ty
42-
return %1 :!poly_ty
30+
func.func @intt_mont(%arg0 : !memref_ty) attributes { llvm.emit_c_interface } {
31+
%t = bufferization.to_tensor %arg0 restrict writable : !memref_ty to !coefft_ty
32+
poly.intt %t {root=#root_mont} : !coefft_ty
33+
return
4334
}

benchmark/ntt/ntt_benchmark_test.cc

Lines changed: 49 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -25,101 +25,106 @@ static void fillWithRandom(Memref<i256> *input, const i256 &kPrime) {
2525
std::mt19937_64 rng(std::random_device{}()); // NOLINT(whitespace/braces)
2626
std::uniform_int_distribution<uint64_t> dist(0, UINT64_MAX);
2727
for (int i = 0; i < NUM_COEFFS; i++) {
28-
*input->pget(0, i) = i256::randomLT(kPrime, rng, dist);
28+
*input->pget(i, 0) = i256::randomLT(kPrime, rng, dist);
2929
}
3030
}
3131

32-
extern "C" void _mlir_ciface_input_generation(Memref<i256> *output);
33-
extern "C" void _mlir_ciface_ntt(Memref<i256> *output, Memref<i256> *input);
34-
extern "C" void _mlir_ciface_intt(Memref<i256> *output, Memref<i256> *input);
32+
extern "C" void _mlir_ciface_ntt(Memref<i256> *buffer);
33+
extern "C" void _mlir_ciface_intt(Memref<i256> *buffer);
3534

36-
extern "C" void _mlir_ciface_ntt_mont(Memref<i256> *output,
37-
Memref<i256> *input);
38-
extern "C" void _mlir_ciface_intt_mont(Memref<i256> *output,
39-
Memref<i256> *input);
35+
extern "C" void _mlir_ciface_ntt_mont(Memref<i256> *buffer);
36+
extern "C" void _mlir_ciface_intt_mont(Memref<i256> *buffer);
4037

4138
void BM_ntt_benchmark(::benchmark::State &state) {
42-
Memref<i256> input(1, NUM_COEFFS);
43-
_mlir_ciface_input_generation(&input);
39+
Memref<i256> input(NUM_COEFFS, 1);
4440
fillWithRandom(&input, kPrime);
4541

46-
Memref<i256> ntt(1, NUM_COEFFS);
42+
Memref<i256> ntt(NUM_COEFFS, 1);
4743
for (auto _ : state) {
48-
_mlir_ciface_ntt(&ntt, &input);
44+
state.PauseTiming();
45+
memcpy(ntt.pget(0, 0), input.pget(0, 0), sizeof(i256) * NUM_COEFFS);
46+
state.ResumeTiming();
47+
_mlir_ciface_ntt(&ntt);
4948
}
5049

51-
Memref<i256> intt(1, NUM_COEFFS);
52-
_mlir_ciface_intt(&intt, &ntt);
50+
_mlir_ciface_intt(&ntt);
5351

5452
for (int i = 0; i < NUM_COEFFS; i++) {
5553
for (int j = 0; j < 4; j++) {
56-
EXPECT_EQ(intt.pget(0, i)->limbs[j], input.pget(0, i)->limbs[j]);
54+
EXPECT_EQ(ntt.pget(i, 0)->limbs[j], input.pget(i, 0)->limbs[j]);
5755
}
5856
}
5957
}
6058

6159
BENCHMARK(BM_ntt_benchmark)->Unit(::benchmark::kMillisecond);
6260

6361
void BM_intt_benchmark(::benchmark::State &state) {
64-
Memref<i256> input(1, NUM_COEFFS);
65-
_mlir_ciface_input_generation(&input);
62+
Memref<i256> input(NUM_COEFFS, 1);
6663
fillWithRandom(&input, kPrime);
6764

68-
Memref<i256> ntt(1, NUM_COEFFS);
69-
_mlir_ciface_ntt(&ntt, &input);
65+
Memref<i256> ntt(NUM_COEFFS, 1);
66+
memcpy(ntt.pget(0, 0), input.pget(0, 0), sizeof(i256) * NUM_COEFFS);
67+
_mlir_ciface_ntt(&ntt);
7068

71-
Memref<i256> intt(1, NUM_COEFFS);
69+
Memref<i256> intt(NUM_COEFFS, 1);
7270
for (auto _ : state) {
73-
_mlir_ciface_intt(&intt, &ntt);
71+
state.PauseTiming();
72+
memcpy(intt.pget(0, 0), ntt.pget(0, 0), sizeof(i256) * NUM_COEFFS);
73+
state.ResumeTiming();
74+
_mlir_ciface_intt(&ntt);
7475
}
7576

7677
for (int i = 0; i < NUM_COEFFS; i++) {
7778
for (int j = 0; j < 4; j++) {
78-
EXPECT_EQ(intt.pget(0, i)->limbs[j], input.pget(0, i)->limbs[j]);
79+
EXPECT_EQ(ntt.pget(i, 0)->limbs[j], input.pget(i, 0)->limbs[j]);
7980
}
8081
}
8182
}
8283

8384
BENCHMARK(BM_intt_benchmark)->Unit(::benchmark::kMillisecond);
8485

8586
void BM_ntt_mont_benchmark(::benchmark::State &state) {
86-
Memref<i256> input(1, NUM_COEFFS);
87-
_mlir_ciface_input_generation(&input);
87+
Memref<i256> input(NUM_COEFFS, 1);
8888
fillWithRandom(&input, kPrime);
8989

90-
Memref<i256> ntt(1, NUM_COEFFS);
90+
Memref<i256> ntt(NUM_COEFFS, 1);
9191
for (auto _ : state) {
92-
_mlir_ciface_ntt_mont(&ntt, &input);
92+
state.PauseTiming();
93+
memcpy(ntt.pget(0, 0), input.pget(0, 0), sizeof(i256) * NUM_COEFFS);
94+
state.ResumeTiming();
95+
_mlir_ciface_ntt_mont(&ntt);
9396
}
9497

95-
Memref<i256> intt(1, NUM_COEFFS);
96-
_mlir_ciface_intt_mont(&intt, &ntt);
98+
_mlir_ciface_intt_mont(&ntt);
9799

98100
for (int i = 0; i < NUM_COEFFS; i++) {
99101
for (int j = 0; j < 4; j++) {
100-
EXPECT_EQ(intt.pget(0, i)->limbs[j], input.pget(0, i)->limbs[j]);
102+
EXPECT_EQ(ntt.pget(i, 0)->limbs[j], input.pget(i, 0)->limbs[j]);
101103
}
102104
}
103105
}
104106

105107
BENCHMARK(BM_ntt_mont_benchmark)->Unit(::benchmark::kMillisecond);
106108

107109
void BM_intt_mont_benchmark(::benchmark::State &state) {
108-
Memref<i256> input(1, NUM_COEFFS);
109-
_mlir_ciface_input_generation(&input);
110+
Memref<i256> input(NUM_COEFFS, 1);
110111
fillWithRandom(&input, kPrime);
111112

112-
Memref<i256> ntt(1, NUM_COEFFS);
113-
_mlir_ciface_ntt_mont(&ntt, &input);
113+
Memref<i256> ntt(NUM_COEFFS, 1);
114+
memcpy(ntt.pget(0, 0), input.pget(0, 0), sizeof(i256) * NUM_COEFFS);
115+
_mlir_ciface_ntt_mont(&ntt);
114116

115-
Memref<i256> intt(1, NUM_COEFFS);
117+
Memref<i256> intt(NUM_COEFFS, 1);
116118
for (auto _ : state) {
117-
_mlir_ciface_intt_mont(&intt, &ntt);
119+
state.PauseTiming();
120+
memcpy(intt.pget(0, 0), ntt.pget(0, 0), sizeof(i256) * NUM_COEFFS);
121+
state.ResumeTiming();
122+
_mlir_ciface_intt_mont(&intt);
118123
}
119124

120125
for (int i = 0; i < NUM_COEFFS; i++) {
121126
for (int j = 0; j < 4; j++) {
122-
EXPECT_EQ(intt.pget(0, i)->limbs[j], input.pget(0, i)->limbs[j]);
127+
EXPECT_EQ(intt.pget(i, 0)->limbs[j], input.pget(i, 0)->limbs[j]);
123128
}
124129
}
125130
}
@@ -136,12 +141,12 @@ BENCHMARK(BM_intt_mont_benchmark)->Unit(::benchmark::kMillisecond);
136141
// L1 Data 64 KiB
137142
// L1 Instruction 128 KiB
138143
// L2 Unified 4096 KiB (x14)
139-
// Load Average: 6.49, 5.64, 5.49
140-
// -------------------------------------------------------------------------
141-
// Benchmark Time CPU Iterations
142-
// -------------------------------------------------------------------------
143-
// BM_ntt_benchmark 1656 ms 1050 ms 1
144-
// BM_intt_benchmark/iterations:1 1791 ms 1090 ms 1
145-
// BM_ntt_mont_benchmark 38.6 ms 18.6 ms 40
146-
// BM_intt_mont_benchmark 99.4 ms 56.4 ms 11
144+
// Load Average: 8.66, 7.19, 7.37
145+
// -----------------------------------------------------------------
146+
// Benchmark Time CPU Iterations
147+
// -----------------------------------------------------------------
148+
// BM_ntt_benchmark 1603 ms 1085 ms 1
149+
// BM_intt_benchmark 1585 ms 1120 ms 1
150+
// BM_ntt_mont_benchmark 34.7 ms 16.8 ms 42
151+
// BM_intt_mont_benchmark 33.8 ms 16.6 ms 42
147152
// NOLINTEND()

tests/Dialect/Poly/poly_canonicalization.mlir

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !poly_ty) -> !poly_ty {
1212
// CHECK-NOT: poly.ntt
1313
// CHECK-NOT: poly.intt
1414
// CHECK: %[[RESULT:.*]] = poly.add %[[P]], %[[P]] : [[T]]
15-
%t0 = poly.ntt %p0 {root=#root} : !poly_ty -> !tensor_ty
16-
%p1 = poly.intt %t0 {root=#root} : !tensor_ty -> !poly_ty
15+
%coeffs = poly.to_tensor %p0 : !poly_ty -> !tensor_ty
16+
%evals = poly.ntt %coeffs {root=#root} : !tensor_ty
17+
%coeffs1 = poly.intt %evals {root=#root} : !tensor_ty
18+
%p1 = poly.from_tensor %coeffs1 : !tensor_ty -> !poly_ty
1719
%p2 = poly.add %p1, %p1 : !poly_ty
1820
// CHECK: return %[[RESULT]] : [[T]]
1921
return %p2 : !poly_ty
@@ -25,9 +27,9 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
2527
// CHECK-NOT: poly.intt
2628
// CHECK-NOT: poly.ntt
2729
// CHECK: %[[RESULT:.*]] = field.pf.add %[[X]], %[[X]] : [[T]]
28-
%p0 = poly.intt %t0 {root=#root} : !tensor_ty -> !poly_ty
29-
%t1 = poly.ntt %p0 {root=#root} : !poly_ty -> !tensor_ty
30-
%t2 = field.pf.add %t1, %t1 : !tensor_ty
30+
%coeffs = poly.intt %t0 {root=#root} : !tensor_ty
31+
%evals = poly.ntt %coeffs {root=#root} : !tensor_ty
32+
%evals2 = field.pf.add %evals, %evals : !tensor_ty
3133
// CHECK: return %[[RESULT]] : [[T]]
32-
return %t2 : !tensor_ty
34+
return %evals2 : !tensor_ty
3335
}

tests/Dialect/Poly/poly_ntt_runner.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interf
1414
func.func @test_poly_ntt() {
1515
%coeffsRaw = arith.constant dense<[1,2,3,4]> : tensor<4xi32>
1616
%coeffs = field.pf.encapsulate %coeffsRaw : tensor<4xi32> -> tensor<4x!coeff_ty>
17-
%poly = poly.from_tensor %coeffs : tensor<4x!coeff_ty> -> !poly_ty
18-
%res = poly.ntt %poly {root=#root} : !poly_ty -> tensor<4x!coeff_ty>
17+
%res = poly.ntt %coeffs {root=#root} : tensor<4x!coeff_ty>
1918

2019
%extract = field.pf.extract %res : tensor<4x!coeff_ty> -> tensor<4xi32>
2120
%1 = bufferization.to_memref %extract : tensor<4xi32> to memref<4xi32>
2221
%U = memref.cast %1 : memref<4xi32> to memref<*xi32>
2322
func.call @printMemrefI32(%U) : (memref<*xi32>) -> ()
2423

25-
%intt = poly.intt %res {root=#root} : tensor<4x!coeff_ty> -> !poly_ty
26-
%res2 = poly.to_tensor %intt : !poly_ty -> tensor<4x!coeff_ty>
24+
%intt = poly.intt %res {root=#root} : tensor<4x!coeff_ty>
25+
%poly = poly.from_tensor %intt : tensor<4x!coeff_ty> -> !poly_ty
26+
%res2 = poly.to_tensor %poly : !poly_ty -> tensor<4x!coeff_ty>
2727
%extract2 = field.pf.extract %res2 : tensor<4x!coeff_ty> -> tensor<4xi32>
2828
%2= bufferization.to_memref %extract2 : tensor<4xi32> to memref<4xi32>
2929
%U2 = memref.cast %2 : memref<4xi32> to memref<*xi32>

tests/Dialect/Poly/poly_to_field.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,17 @@ func.func @test_lower_from_tensor(%t : tensor<4x!PF1>) -> !poly_ty1 {
6767
}
6868

6969
// CHECK-LABEL: @test_lower_ntt
70-
// CHECK-SAME: (%[[INPUT:.*]]: [[P:.*]]) -> [[T:.*]] {
71-
func.func @test_lower_ntt(%input : !poly_ty1) -> tensor<4x!PF1> {
70+
// CHECK-SAME: (%[[INPUT:.*]]: [[T:.*]]) -> [[T]] {
71+
func.func @test_lower_ntt(%input : tensor<4x!PF1>) -> tensor<4x!PF1> {
7272
// CHECK-NOT: poly.ntt
73-
%res = poly.ntt %input {root=#root} : !poly_ty1 -> tensor<4x!PF1>
73+
%res = poly.ntt %input {root=#root} : tensor<4x!PF1>
7474
return %res: tensor<4x!PF1>
7575
}
7676

7777
// CHECK-LABEL: @test_lower_intt
7878
// CHECK-SAME: (%[[INPUT:.*]]: [[T:.*]]) -> [[P:.*]] {
79-
func.func @test_lower_intt(%input : tensor<4x!PF1>) -> !poly_ty1 {
79+
func.func @test_lower_intt(%input : tensor<4x!PF1>) -> tensor<4x!PF1> {
8080
// CHECK-NOT: poly.intt
81-
%res = poly.intt %input {root=#root} : tensor<4x!PF1> -> !poly_ty1
82-
return %res: !poly_ty1
81+
%res = poly.intt %input {root=#root} : tensor<4x!PF1>
82+
return %res: tensor<4x!PF1>
8383
}

zkir/Dialect/Field/Conversions/FieldToModArith/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ cc_library(
2121
"@llvm-project//mlir:BufferizationDialect",
2222
"@llvm-project//mlir:IR",
2323
"@llvm-project//mlir:LinalgDialect",
24+
"@llvm-project//mlir:MemRefDialect",
2425
"@llvm-project//mlir:Pass",
2526
"@llvm-project//mlir:Support",
2627
"@llvm-project//mlir:TensorDialect",

zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "mlir/Dialect/Affine/IR/AffineOps.h"
77
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
88
#include "mlir/Dialect/Linalg/IR/Linalg.h"
9+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
910
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1011
#include "mlir/IR/BuiltinAttributeInterfaces.h"
1112
#include "mlir/IR/BuiltinAttributes.h"
@@ -351,8 +352,10 @@ void PrimeFieldToModArith::runOnOperation() {
351352
ConvertAny<affine::AffineForOp>, ConvertAny<affine::AffineParallelOp>,
352353
ConvertAny<affine::AffineLoadOp>, ConvertAny<affine::AffineStoreOp>,
353354
ConvertAny<affine::AffineYieldOp>, ConvertAny<linalg::GenericOp>,
354-
ConvertAny<linalg::YieldOp>, ConvertAny<tensor::CastOp>,
355-
ConvertAny<tensor::ExtractOp>, ConvertAny<tensor::FromElementsOp>,
355+
ConvertAny<linalg::MapOp>, ConvertAny<memref::LoadOp>,
356+
ConvertAny<memref::StoreOp>, ConvertAny<linalg::YieldOp>,
357+
ConvertAny<tensor::CastOp>, ConvertAny<tensor::ExtractOp>,
358+
ConvertAny<tensor::FromElementsOp>,
356359
ConvertAny<bufferization::MaterializeInDestinationOp>,
357360
ConvertAny<bufferization::ToMemrefOp>,
358361
ConvertAny<bufferization::ToTensorOp>, ConvertAny<tensor::InsertOp>>(
@@ -364,9 +367,10 @@ void PrimeFieldToModArith::runOnOperation() {
364367
affine::AffineForOp, affine::AffineParallelOp, affine::AffineLoadOp,
365368
affine::AffineStoreOp, affine::AffineYieldOp,
366369
bufferization::MaterializeInDestinationOp, bufferization::ToMemrefOp,
367-
bufferization::ToTensorOp, linalg::GenericOp, linalg::YieldOp,
368-
tensor::CastOp, tensor::ExtractOp, tensor::FromElementsOp,
369-
tensor::InsertOp>([&](auto op) { return typeConverter.isLegal(op); });
370+
bufferization::ToTensorOp, linalg::GenericOp, linalg::MapOp,
371+
linalg::YieldOp, memref::LoadOp, memref::StoreOp, tensor::CastOp,
372+
tensor::ExtractOp, tensor::FromElementsOp, tensor::InsertOp>(
373+
[&](auto op) { return typeConverter.isLegal(op); });
370374

371375
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
372376
signalPassFailure();

zkir/Dialect/ModArith/Conversions/ModArithToArith/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ cc_library(
2121
"@llvm-project//mlir:BufferizationDialect",
2222
"@llvm-project//mlir:IR",
2323
"@llvm-project//mlir:LinalgDialect",
24+
"@llvm-project//mlir:MemRefDialect",
2425
"@llvm-project//mlir:Pass",
2526
"@llvm-project//mlir:SCFDialect",
2627
"@llvm-project//mlir:Support",

0 commit comments

Comments
 (0)