|
33 | 33 |
|
34 | 34 | #include "src/Compiler/CompilerOptions.hpp" |
35 | 35 | #include "src/Compiler/CompilerPasses.hpp" |
36 | | -#include "src/Compiler/DisposableGarbageCollector.hpp" |
| 36 | +#include "src/Compiler/OnnxToMlirPasses.hpp" |
37 | 37 | #include "src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp" |
38 | 38 | #include "src/Dialect/Mlir/VectorMachineSupport.hpp" |
39 | 39 | #include "src/Dialect/ONNX/ONNXDialect.hpp" |
@@ -66,137 +66,6 @@ void configurePasses() { |
66 | 66 | !disableSimdOption); |
67 | 67 | } |
68 | 68 |
|
69 | | -void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU, |
70 | | - bool donotScrubDisposableElementsAttr, OnnxToMlirOptions opts) { |
71 | | - // This is a transition from the previous static passes to fully dynamic passes. |
72 | | - // The static passes are kept, and the dynamic pass is added as an IF-THEN |
73 | | - // alternative to the static iteration. |
74 | | - // The reasons are: |
75 | | - // 1. The debug flag, --print-ir-after/before-all, can display IR for each |
76 | | - // static pass, but the dynamic pipeline will be viewed as one. MLIR |
77 | | - // may have a solution that I am not aware of yet. |
78 | | - // 2. It is easy to compare the two approaches. |
79 | | - // In the future, only the dynamic pass, ONNXOpTransformPass, will be used for |
80 | | - // this function. |
81 | | - |
82 | | - if (!donotScrubDisposableElementsAttr) |
83 | | - pm.addInstrumentation( |
84 | | - std::make_unique<DisposableGarbageCollector>(pm.getContext())); |
85 | | - |
86 | | - // Decompose first. This eliminates some unsupported ops without shape inference. |
87 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createDecomposeONNXToONNXPass( |
88 | | - /*target=*/"", opts.enableConvTransposeDecompose, |
89 | | - opts.enableConvTransposeDecomposeToPhasedConv, |
90 | | - opts.enableConvTranspose1dDecomposeToPhasedConv)); |
91 | | - if (!disableRecomposeOption) |
92 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createRecomposeONNXToONNXPass()); |
93 | | - if (enableONNXHybridPass) { |
94 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass( |
95 | | - !disableRecomposeOption, opts.enableQuarkQuantizedLegalization, |
96 | | - opts.enableConvTransposeDecompose, |
97 | | - opts.enableConvTransposeDecomposeToPhasedConv, |
98 | | - opts.enableConvTranspose1dDecomposeToPhasedConv)); |
99 | | - // Convolution optimization for CPU: enabled when there are no accelerators. |
100 | | - if (targetCPU && enableConvOptPass) { |
101 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass( |
102 | | - enableSimdDataLayout && !disableSimdOption)); |
103 | | - pm.addNestedPass<func::FuncOp>( |
104 | | - onnx_mlir::createONNXHybridTransformPass(!disableRecomposeOption, |
105 | | - /*enableQuarkQuantizedOpsLegalization=*/false, |
106 | | - opts.enableConvTransposeDecompose, |
107 | | - opts.enableConvTransposeDecomposeToPhasedConv, |
108 | | - opts.enableConvTranspose1dDecomposeToPhasedConv)); |
109 | | - } |
110 | | - } else { |
111 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass()); |
112 | | - pm.addPass(mlir::createCanonicalizerPass()); |
113 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass()); |
114 | | - // Convolution optimization for CPU: enabled when there are no accelerators. |
115 | | - if (targetCPU && enableConvOptPass) { |
116 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass( |
117 | | - enableSimdDataLayout && !disableSimdOption)); |
118 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass()); |
119 | | - } |
120 | | - pm.addNestedPass<func::FuncOp>( |
121 | | - onnx_mlir::createLegalizeQuarkQuantizedOpsPass()); |
122 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass()); |
123 | | - if (onnxOpTransformThreshold > 0) { |
124 | | - // Iterate dynamically inside ONNXOpTransformPass. |
125 | | - pm.addPass(onnx_mlir::createONNXOpTransformPass(onnxOpTransformThreshold, |
126 | | - onnxOpTransformReport, targetCPU, |
127 | | - enableSimdDataLayout && !disableSimdOption, enableConvOptPass, |
128 | | - !disableRecomposeOption)); |
129 | | - } else { |
130 | | - // Statically add repeatOnnxTransform extra rounds of passes. |
131 | | - for (int i = 0; i < repeatOnnxTransform; i++) { |
132 | | - pm.addPass(mlir::createCanonicalizerPass()); |
133 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass()); |
134 | | - pm.addNestedPass<func::FuncOp>( |
135 | | - onnx_mlir::createConstPropONNXToONNXPass()); |
136 | | - } |
137 | | - } |
138 | | - } |
139 | | - |
140 | | - // Simplify shape-related ops. |
141 | | - pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass( |
142 | | - opts.enableQuarkQuantizedLegalization)); |
143 | | - |
144 | | - // Passes for removing redundant concat, slice and cast QDQ ops. |
145 | | - if (opts.enableRemoveDqQOp) |
146 | | - pm.addPass(createQDQOptONNXToONNXPass()); |
147 | | - |
148 | | - // One more call to ONNX shape inference/canonicalization/... to update |
149 | | - // shapes if possible. |
150 | | - if (enableONNXHybridPass) { |
151 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass( |
152 | | - !disableRecomposeOption, opts.enableQuarkQuantizedLegalization, |
153 | | - opts.enableConvTransposeDecompose, |
154 | | - opts.enableConvTransposeDecomposeToPhasedConv, |
155 | | - opts.enableConvTranspose1dDecomposeToPhasedConv)); |
156 | | - } else { |
157 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass()); |
158 | | - pm.addPass(mlir::createCanonicalizerPass()); |
159 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass()); |
160 | | - } |
161 | | - |
162 | | - // Replace ONNXReturnOp with func::ReturnOp. |
163 | | - pm.addPass(onnx_mlir::createStandardFuncReturnPass()); |
164 | | - |
165 | | - // Clean dead code. |
166 | | - pm.addPass(mlir::createSymbolDCEPass()); |
167 | | - |
168 | | - // Replace every DisposableElementsAttr with DenseElementsAttr. |
169 | | - if (!donotScrubDisposableElementsAttr) |
170 | | - pm.addPass(createScrubDisposablePass()); |
171 | | - |
172 | | - // Set onnx_node_name if it is missing. Keep this pass at the end of this |
173 | | - // function and just before instrumentation. |
174 | | - pm.addPass(createSetONNXNodeNamePass()); |
175 | | - |
176 | | - // Add instrumentation for ONNX ops. |
177 | | - // Keep this pass at the end of this function. |
178 | | - unsigned instrumentActions = instrumentControlBits; |
179 | | - if (profileIR == onnx_mlir::ProfileIRs::Onnx) { |
180 | | - instrumentStage = onnx_mlir::InstrumentStages::Onnx; |
181 | | - instrumentOps = "onnx.*"; |
182 | | - // Enable the first three bits for InstrumentBeforeOp, InstrumentAfterOp |
183 | | - // and InstrumentReportTime. Disable the last bit for |
184 | | - // InstrumentReportMemory because of its big overhead. Users can |
185 | | - // optionally enable the last bit by using the |
186 | | - // --InstrumentReportMemory option. |
187 | | - instrumentActions |= (1 << 3) - 1; |
188 | | - } |
189 | | - if (instrumentStage == onnx_mlir::InstrumentStages::Onnx) |
190 | | - pm.addNestedPass<func::FuncOp>( |
191 | | - onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions)); |
192 | | - // Print signatures of each op at runtime if enabled. Do not run the |
193 | | - // signature and instrument passes at the same time, as the time may include printf |
194 | | - // overheads. |
195 | | - if (instrumentSignatures != "NONE" || instrumentOnnxNode != "NONE") |
196 | | - pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentONNXSignaturePass( |
197 | | - instrumentSignatures, instrumentOnnxNode)); |
198 | | -} |
199 | | - |
200 | 69 | void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, |
201 | 70 | std::string ONNXOpsStatFormat) { |
202 | 71 | if (enableCSE) |
@@ -359,6 +228,21 @@ void addPasses(mlir::OwningOpRef<ModuleOp> &module, mlir::PassManager &pm, |
359 | 228 | enableConvTransposeDecomposeToPhasedConv; |
360 | 229 | opts.enableConvTranspose1dDecomposeToPhasedConv = |
361 | 230 | enableConvTranspose1dDecomposeToPhasedConv; |
| 231 | + opts.disableRecomposeOption = disableRecomposeOption; |
| 232 | + opts.enableONNXHybridPass = enableONNXHybridPass; |
| 233 | + opts.enableConvOptPass = enableConvOptPass; |
| 234 | + opts.enableSimdDataLayout = enableSimdDataLayout; |
| 235 | + opts.disableSimdOption = disableSimdOption; |
| 236 | + opts.onnxOpTransformThreshold = onnxOpTransformThreshold; |
| 237 | + opts.onnxOpTransformReport = onnxOpTransformReport; |
| 238 | + opts.repeatOnnxTransform = repeatOnnxTransform; |
| 239 | + opts.instrumentControlBits = instrumentControlBits; |
| 240 | + opts.instrumentOps = instrumentOps; |
| 241 | + opts.instrumentSignatures = instrumentSignatures; |
| 242 | + opts.instrumentOnnxNode = instrumentOnnxNode; |
| 243 | + opts.profileIR = profileIR; |
| 244 | + opts.instrumentStage = instrumentStage; |
| 245 | + |
362 | 246 | addONNXToMLIRPasses(pm, /*target CPU*/ false, |
363 | 247 | /*donotScrubDisposableElementsAttr=*/false, opts); |
364 | 248 | } |
|
0 commit comments