|
12 | 12 | // |
13 | 13 | //===----------------------------------------------------------------------===// |
14 | 14 |
|
| 15 | +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" |
| 16 | +#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" |
| 17 | +#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" |
15 | 18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | 19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | 20 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 21 | +#include "mlir/Dialect/GPU/Transforms/Passes.h" |
18 | 22 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
19 | 23 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
20 | 24 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
| 25 | +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| 26 | +#include "mlir/Dialect/SPIRV/Transforms/Passes.h" |
21 | 27 | #include "mlir/ExecutionEngine/JitRunner.h" |
22 | 28 | #include "mlir/ExecutionEngine/OptUtils.h" |
23 | 29 | #include "mlir/Pass/Pass.h" |
@@ -69,13 +75,31 @@ convertMLIRModule(Operation *op, llvm::LLVMContext &context) { |
69 | 75 | return mainModule; |
70 | 76 | } |
71 | 77 |
|
| 78 | +static LogicalResult runMLIRPasses(Operation *module, |
| 79 | + JitRunnerOptions &options) { |
| 80 | + PassManager passManager(module->getContext(), |
| 81 | + module->getName().getStringRef()); |
| 82 | + if (failed(applyPassManagerCLOptions(passManager))) |
| 83 | + return failure(); |
| 84 | + passManager.addPass(createGpuKernelOutliningPass()); |
| 85 | + passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true)); |
| 86 | + |
| 87 | + OpPassManager &nestedPM = passManager.nest<spirv::ModuleOp>(); |
| 88 | + nestedPM.addPass(spirv::createSPIRVLowerABIAttributesPass()); |
| 89 | + nestedPM.addPass(spirv::createSPIRVUpdateVCEPass()); |
| 90 | + passManager.addPass(createLowerHostCodeToLLVMPass()); |
| 91 | + passManager.addPass(createConvertSPIRVToLLVMPass()); |
| 92 | + return passManager.run(module); |
| 93 | +} |
| 94 | + |
72 | 95 | int main(int argc, char **argv) { |
73 | 96 | llvm::InitLLVM y(argc, argv); |
74 | 97 |
|
75 | 98 | llvm::InitializeNativeTarget(); |
76 | 99 | llvm::InitializeNativeTargetAsmPrinter(); |
77 | 100 |
|
78 | 101 | mlir::JitRunnerConfig jitRunnerConfig; |
| 102 | + jitRunnerConfig.mlirTransformer = runMLIRPasses; |
79 | 103 | jitRunnerConfig.llvmModuleBuilder = convertMLIRModule; |
80 | 104 |
|
81 | 105 | mlir::DialectRegistry registry; |
|
0 commit comments