Skip to content

Commit 8ca44d6

Browse files
committed
perf(poly): parallelize inner loops in NTT
1 parent 6b8f5b0 commit 8ca44d6

File tree

9 files changed

+157
-134
lines changed

9 files changed

+157
-134
lines changed

benchmark/benchmark.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def zkir_benchmark_test(name, mlir_src, test_src, zkir_opt_flags = [], data = []
146146
":" + import_name,
147147
"@google_benchmark//:benchmark_main",
148148
"@googletest//:gtest",
149+
"@llvm-project//mlir:mlir_runner_utils",
149150
"@local_config_omp//:omp",
150151
],
151152
copts = ["-Xclang -fopenmp"],

benchmark/ntt/ntt_benchmark_test.cc

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ void BM_ntt_benchmark(::benchmark::State &state) {
5858
}
5959
}
6060

61-
BENCHMARK(BM_ntt_benchmark)->Unit(::benchmark::kSecond);
61+
BENCHMARK(BM_ntt_benchmark)->Unit(::benchmark::kMillisecond);
6262

6363
void BM_intt_benchmark(::benchmark::State &state) {
6464
Memref<i256> input(1, NUM_COEFFS);
@@ -80,9 +80,7 @@ void BM_intt_benchmark(::benchmark::State &state) {
8080
}
8181
}
8282

83-
// FIXME(batzor): It fails for more than 1 iteration so it seems like it is
84-
// modifying the input. But I am not sure why ;(
85-
BENCHMARK(BM_intt_benchmark)->Iterations(1)->Unit(::benchmark::kSecond);
83+
BENCHMARK(BM_intt_benchmark)->Unit(::benchmark::kMillisecond);
8684

8785
void BM_ntt_mont_benchmark(::benchmark::State &state) {
8886
Memref<i256> input(1, NUM_COEFFS);
@@ -104,7 +102,7 @@ void BM_ntt_mont_benchmark(::benchmark::State &state) {
104102
}
105103
}
106104

107-
BENCHMARK(BM_ntt_mont_benchmark)->Unit(::benchmark::kSecond);
105+
BENCHMARK(BM_ntt_mont_benchmark)->Unit(::benchmark::kMillisecond);
108106

109107
void BM_intt_mont_benchmark(::benchmark::State &state) {
110108
Memref<i256> input(1, NUM_COEFFS);
@@ -126,9 +124,7 @@ void BM_intt_mont_benchmark(::benchmark::State &state) {
126124
}
127125
}
128126

129-
// FIXME(batzor): It fails for more than 1 iteration so it seems like it is
130-
// modifying the input. But I am not sure why ;(
131-
BENCHMARK(BM_intt_mont_benchmark)->Iterations(1)->Unit(::benchmark::kSecond);
127+
BENCHMARK(BM_intt_mont_benchmark)->Unit(::benchmark::kMillisecond);
132128

133129
} // namespace
134130
} // namespace zkir
@@ -140,12 +136,12 @@ BENCHMARK(BM_intt_mont_benchmark)->Iterations(1)->Unit(::benchmark::kSecond);
140136
// L1 Data 64 KiB
141137
// L1 Instruction 128 KiB
142138
// L2 Unified 4096 KiB (x14)
143-
// Load Average: 1.82, 2.22, 2.39
144-
// ------------------------------------------------------------------------------
145-
// Benchmark Time CPU Iterations
146-
// ------------------------------------------------------------------------------
147-
// BM_ntt_benchmark 10.1 s 10.1 s 1
148-
// BM_intt_benchmark/iterations:1 10.1 s 10.0 s 1
149-
// BM_ntt_mont_benchmark 0.183 s 0.183 s 4
150-
// BM_intt_mont_benchmark/iterations:1 0.266 s 0.214 s 1
139+
// Load Average: 6.49, 5.64, 5.49
140+
// -------------------------------------------------------------------------
141+
// Benchmark Time CPU Iterations
142+
// -------------------------------------------------------------------------
143+
// BM_ntt_benchmark 1656 ms 1050 ms 1
144+
// BM_intt_benchmark 1791 ms 1090 ms 1
145+
// BM_ntt_mont_benchmark 38.6 ms 18.6 ms 40
146+
// BM_intt_mont_benchmark 99.4 ms 56.4 ms 11
151147
// NOLINTEND()

zkir/Dialect/Field/Conversions/FieldToModArith/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ cc_library(
1818
"//zkir/Utils:ConversionUtils",
1919
"@llvm-project//llvm:Support",
2020
"@llvm-project//mlir:AffineDialect",
21+
"@llvm-project//mlir:BufferizationDialect",
2122
"@llvm-project//mlir:IR",
2223
"@llvm-project//mlir:LinalgDialect",
2324
"@llvm-project//mlir:Pass",

zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.cpp

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "llvm/Support/Casting.h"
66
#include "mlir/Dialect/Affine/IR/AffineOps.h"
7+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
78
#include "mlir/Dialect/Linalg/IR/Linalg.h"
89
#include "mlir/Dialect/Tensor/IR/Tensor.h"
910
#include "mlir/IR/BuiltinAttributeInterfaces.h"
@@ -55,6 +56,14 @@ class PrimeFieldToModArithTypeConverter : public TypeConverter {
5556
addConversion([](ShapedType type) -> Type {
5657
return convertPrimeFieldLikeType(type);
5758
});
59+
addConversion([](MemRefType type) -> Type {
60+
if (auto primeFieldType =
61+
llvm::dyn_cast<PrimeFieldType>(type.getElementType())) {
62+
return type.cloneWith(type.getShape(),
63+
convertPrimeFieldType(primeFieldType));
64+
}
65+
return type;
66+
});
5867
}
5968
};
6069

@@ -262,22 +271,26 @@ void PrimeFieldToModArith::runOnOperation() {
262271

263272
RewritePatternSet patterns(context);
264273
rewrites::populateWithGenerated(patterns);
265-
patterns
266-
.add<ConvertConstant, ConvertEncapsulate, ConvertExtract, ConvertToMont,
267-
ConvertFromMont, ConvertInverse, ConvertAdd, ConvertSub, ConvertMul,
268-
ConvertMontMul, ConvertAny<affine::AffineForOp>,
269-
ConvertAny<affine::AffineYieldOp>, ConvertAny<linalg::GenericOp>,
270-
ConvertAny<linalg::YieldOp>, ConvertAny<tensor::CastOp>,
271-
ConvertAny<tensor::ExtractOp>, ConvertAny<tensor::FromElementsOp>,
272-
ConvertAny<tensor::InsertOp>>(typeConverter, context);
274+
patterns.add<
275+
ConvertConstant, ConvertEncapsulate, ConvertExtract, ConvertToMont,
276+
ConvertFromMont, ConvertInverse, ConvertAdd, ConvertSub, ConvertMul,
277+
ConvertMontMul, ConvertAny<affine::AffineForOp>,
278+
ConvertAny<affine::AffineParallelOp>, ConvertAny<affine::AffineLoadOp>,
279+
ConvertAny<affine::AffineStoreOp>, ConvertAny<affine::AffineYieldOp>,
280+
ConvertAny<linalg::GenericOp>, ConvertAny<linalg::YieldOp>,
281+
ConvertAny<tensor::CastOp>, ConvertAny<tensor::ExtractOp>,
282+
ConvertAny<tensor::FromElementsOp>, ConvertAny<bufferization::ToMemrefOp>,
283+
ConvertAny<bufferization::ToTensorOp>, ConvertAny<tensor::InsertOp>>(
284+
typeConverter, context);
273285

274286
addStructuralConversionPatterns(typeConverter, patterns, target);
275287

276-
target.addDynamicallyLegalOp<affine::AffineForOp, affine::AffineYieldOp,
277-
linalg::GenericOp, linalg::YieldOp,
278-
tensor::CastOp, tensor::ExtractOp,
279-
tensor::FromElementsOp, tensor::InsertOp>(
280-
[&](auto op) { return typeConverter.isLegal(op); });
288+
target.addDynamicallyLegalOp<
289+
affine::AffineForOp, affine::AffineParallelOp, affine::AffineLoadOp,
290+
affine::AffineStoreOp, affine::AffineYieldOp, bufferization::ToMemrefOp,
291+
bufferization::ToTensorOp, linalg::GenericOp, linalg::YieldOp,
292+
tensor::CastOp, tensor::ExtractOp, tensor::FromElementsOp,
293+
tensor::InsertOp>([&](auto op) { return typeConverter.isLegal(op); });
281294

282295
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
283296
signalPassFailure();

zkir/Dialect/ModArith/Conversions/ModArithToArith/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ cc_library(
1818
"@llvm-project//llvm:Support",
1919
"@llvm-project//mlir:AffineDialect",
2020
"@llvm-project//mlir:ArithDialect",
21+
"@llvm-project//mlir:BufferizationDialect",
2122
"@llvm-project//mlir:IR",
2223
"@llvm-project//mlir:LinalgDialect",
2324
"@llvm-project//mlir:Pass",

zkir/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "llvm/Support/Casting.h"
66
#include "mlir/Dialect/Affine/IR/AffineOps.h"
77
#include "mlir/Dialect/Arith/IR/Arith.h"
8+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
89
#include "mlir/Dialect/Linalg/IR/Linalg.h"
910
#include "mlir/Dialect/SCF/IR/SCF.h"
1011
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -571,18 +572,23 @@ void ModArithToArith::runOnOperation() {
571572
ConvertNegate, ConvertEncapsulate, ConvertExtract, ConvertReduce,
572573
ConvertMontReduce, ConvertToMont, ConvertFromMont, ConvertAdd, ConvertSub,
573574
ConvertMul, ConvertMontMul, ConvertMac, ConvertConstant, ConvertInverse,
574-
ConvertAny<affine::AffineForOp>, ConvertAny<affine::AffineYieldOp>,
575-
ConvertAny<linalg::GenericOp>, ConvertAny<linalg::YieldOp>,
576-
ConvertAny<tensor::CastOp>, ConvertAny<tensor::ExtractOp>,
577-
ConvertAny<tensor::FromElementsOp>, ConvertAny<tensor::InsertOp>>(
578-
typeConverter, context);
575+
ConvertAny<affine::AffineForOp>, ConvertAny<affine::AffineParallelOp>,
576+
ConvertAny<affine::AffineLoadOp>, ConvertAny<affine::AffineApplyOp>,
577+
ConvertAny<affine::AffineStoreOp>, ConvertAny<affine::AffineYieldOp>,
578+
ConvertAny<bufferization::ToMemrefOp>,
579+
ConvertAny<bufferization::ToTensorOp>, ConvertAny<linalg::GenericOp>,
580+
ConvertAny<linalg::YieldOp>, ConvertAny<tensor::CastOp>,
581+
ConvertAny<tensor::ExtractOp>, ConvertAny<tensor::FromElementsOp>,
582+
ConvertAny<tensor::InsertOp>>(typeConverter, context);
579583

580584
addStructuralConversionPatterns(typeConverter, patterns, target);
581585

582-
target.addDynamicallyLegalOp<affine::AffineForOp, affine::AffineYieldOp,
583-
linalg::GenericOp, linalg::YieldOp,
584-
tensor::CastOp, tensor::ExtractOp,
585-
tensor::FromElementsOp, tensor::InsertOp>(
586+
target.addDynamicallyLegalOp<
587+
affine::AffineForOp, affine::AffineParallelOp, affine::AffineLoadOp,
588+
affine::AffineApplyOp, affine::AffineStoreOp, affine::AffineYieldOp,
589+
bufferization::ToMemrefOp, bufferization::ToTensorOp, linalg::GenericOp,
590+
linalg::YieldOp, tensor::CastOp, tensor::ExtractOp,
591+
tensor::FromElementsOp, tensor::InsertOp>(
586592
[&](auto op) { return typeConverter.isLegal(op); });
587593

588594
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {

zkir/Dialect/Poly/Conversions/PolyToField/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ cc_library(
1717
"//zkir/Dialect/Poly/IR:Poly",
1818
"//zkir/Utils:ConversionUtils",
1919
"@llvm-project//mlir:AffineDialect",
20+
"@llvm-project//mlir:BufferizationDialect",
2021
"@llvm-project//mlir:IR",
2122
"@llvm-project//mlir:LinalgDialect",
2223
"@llvm-project//mlir:Pass",

0 commit comments

Comments
 (0)