Skip to content

Commit 6b8f5b0

Browse files
committed
feat: lower to omp dialect
1 parent abdb706 commit 6b8f5b0

File tree

7 files changed

+35
-8
lines changed

7 files changed

+35
-8
lines changed

benchmark/benchmark.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,10 @@ def zkir_benchmark_test(name, mlir_src, test_src, zkir_opt_flags = [], data = []
146146
":" + import_name,
147147
"@google_benchmark//:benchmark_main",
148148
"@googletest//:gtest",
149+
"@local_config_omp//:omp",
149150
],
151+
copts = ["-Xclang -fopenmp"],
152+
linkopts = ["-Xclang -fopenmp"],
150153
tags = tags,
151154
data = data + [generated_obj_name],
152155
**kwargs

benchmark/ntt/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ zkir_benchmark_test(
55
mlir_src = "ntt_benchmark.mlir",
66
tags = ["manual"],
77
test_src = ["ntt_benchmark_test.cc"],
8-
zkir_opt_flags = ["-poly-to-llvm"],
8+
zkir_opt_flags = ["-poly-to-omp"],
99
deps = [
1010
"//benchmark:BenchmarkUtils",
1111
],

benchmark/ntt/ntt_benchmark_test.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,12 @@ BENCHMARK(BM_intt_mont_benchmark)->Iterations(1)->Unit(::benchmark::kSecond);
140140
// L1 Data 64 KiB
141141
// L1 Instruction 128 KiB
142142
// L2 Unified 4096 KiB (x14)
143-
// Load Average: 3.97, 3.11, 2.96
143+
// Load Average: 1.82, 2.22, 2.39
144144
// ------------------------------------------------------------------------------
145145
// Benchmark Time CPU Iterations
146146
// ------------------------------------------------------------------------------
147-
// BM_ntt_benchmark 10.2 s 10.1 s 1
148-
// BM_intt_benchmark/iterations:1 11.1 s 11.1 s 1
149-
// BM_ntt_mont_benchmark 0.190 s 0.190 s 3
150-
// BM_intt_mont_benchmark/iterations:1 0.316 s 0.304 s 1
147+
// BM_ntt_benchmark 10.1 s 10.1 s 1
148+
// BM_intt_benchmark/iterations:1 10.1 s 10.0 s 1
149+
// BM_ntt_mont_benchmark 0.183 s 0.183 s 4
150+
// BM_intt_mont_benchmark/iterations:1 0.266 s 0.214 s 1
151151
// NOLINTEND()

tools/zkir-opt.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ int main(int argc, char **argv) {
3232

3333
mlir::PassPipelineRegistration<>(
3434
"poly-to-llvm", "Run passes to lower the polynomial dialect to LLVM",
35-
mlir::zkir::pipelines::polyToLLVMPipelineBuilder);
35+
mlir::zkir::pipelines::polyToLLVMPipelineBuilder<false>);
36+
mlir::PassPipelineRegistration<>(
37+
"poly-to-omp",
38+
"Run passes to lower the polynomial dialect to OpenMP + LLVM",
39+
mlir::zkir::pipelines::polyToLLVMPipelineBuilder<true>);
3640

3741
return failed(mlir::MlirOptMain(argc, argv, "ZKIR optimizer\n", registry));
3842
}

zkir/Pipelines/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@ cc_library(
2121
"@llvm-project//mlir:FuncDialect",
2222
"@llvm-project//mlir:LinalgTransforms",
2323
"@llvm-project//mlir:MemRefTransforms",
24+
"@llvm-project//mlir:OpenMPToLLVM",
2425
"@llvm-project//mlir:Pass",
2526
"@llvm-project//mlir:SCFToControlFlow",
27+
"@llvm-project//mlir:SCFToOpenMP",
2628
"@llvm-project//mlir:Transforms",
2729
],
2830
)

zkir/Pipelines/PipelineRegistration.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
44
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
55
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
6+
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
67
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
8+
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
79
#include "mlir/Dialect/Affine/Passes.h"
810
#include "mlir/Dialect/Arith/Transforms/Passes.h"
911
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
@@ -41,6 +43,7 @@ void oneShotBufferize(OpPassManager &manager) {
4143
manager.addPass(createCanonicalizerPass());
4244
}
4345

46+
template <bool allowOpenMP>
4447
void polyToLLVMPipelineBuilder(OpPassManager &manager) {
4548
manager.addPass(zkir::poly::createPolyToField());
4649
manager.addPass(zkir::field::createPrimeFieldToModArith());
@@ -49,9 +52,11 @@ void polyToLLVMPipelineBuilder(OpPassManager &manager) {
4952

5053
// Linalg
5154
manager.addNestedPass<FuncOp>(createConvertElementwiseToLinalgPass());
55+
manager.addNestedPass<FuncOp>(createLinalgElementwiseOpFusionPass());
5256
// Needed to lower affine.map and affine.apply
5357
manager.addNestedPass<FuncOp>(affine::createAffineExpandIndexOpsPass());
5458
manager.addNestedPass<FuncOp>(affine::createSimplifyAffineStructuresPass());
59+
manager.addPass(affine::createAffineParallelize());
5560
manager.addPass(createLowerAffinePass());
5661
manager.addNestedPass<FuncOp>(memref::createExpandOpsPass());
5762
manager.addNestedPass<FuncOp>(memref::createExpandStridedMetadataPass());
@@ -61,7 +66,7 @@ void polyToLLVMPipelineBuilder(OpPassManager &manager) {
6166

6267
// Linalg must be bufferized before it can be lowered
6368
// But lowering to loops also re-introduces affine.apply, so re-lower that
64-
manager.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
69+
manager.addNestedPass<FuncOp>(createConvertLinalgToParallelLoopsPass());
6570
manager.addPass(createLowerAffinePass());
6671
manager.addPass(createBufferizationToMemRefPass());
6772

@@ -73,13 +78,17 @@ void polyToLLVMPipelineBuilder(OpPassManager &manager) {
7378

7479
// ToLLVM
7580
manager.addPass(arith::createArithExpandOpsPass());
81+
if constexpr (allowOpenMP) {
82+
manager.addPass(createConvertSCFToOpenMPPass());
83+
}
7684
manager.addPass(createConvertSCFToCFPass());
7785
manager.addNestedPass<FuncOp>(memref::createExpandStridedMetadataPass());
7886

7987
// Expanding strided metadata creates affine maps. Needed to lower affine.map
8088
// and affine.apply
8189
manager.addNestedPass<FuncOp>(affine::createAffineExpandIndexOpsPass());
8290
manager.addNestedPass<FuncOp>(affine::createSimplifyAffineStructuresPass());
91+
manager.addPass(affine::createAffineParallelize());
8392
manager.addPass(createLowerAffinePass());
8493
manager.addPass(createConvertToLLVMPass());
8594

@@ -90,4 +99,7 @@ void polyToLLVMPipelineBuilder(OpPassManager &manager) {
9099
manager.addPass(createSymbolDCEPass());
91100
}
92101

102+
template void polyToLLVMPipelineBuilder<false>(mlir::OpPassManager &manager);
103+
template void polyToLLVMPipelineBuilder<true>(mlir::OpPassManager &manager);
104+
93105
} // namespace mlir::zkir::pipelines

zkir/Pipelines/PipelineRegistration.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,14 @@ namespace mlir::zkir::pipelines {
99

1010
void oneShotBufferize(OpPassManager &manager);
1111

12+
template <bool allowOpenMP>
1213
void polyToLLVMPipelineBuilder(OpPassManager &manager);
1314

15+
extern template void polyToLLVMPipelineBuilder<false>(
16+
mlir::OpPassManager &manager);
17+
extern template void polyToLLVMPipelineBuilder<true>(
18+
mlir::OpPassManager &manager);
19+
1420
} // namespace mlir::zkir::pipelines
1521

1622
#endif // ZKIR_PIPELINES_PIPELINEREGISTRATION_H_

0 commit comments

Comments
 (0)