Skip to content

Commit c78aca7

Browse files
committed
Merge remote-tracking branch 'origin/feature/onnx-to-tosa' into xiao.add_remove_qdq_aroundop
2 parents edd9f5d + 442d942 commit c78aca7

File tree

7 files changed

+254
-159
lines changed

7 files changed

+254
-159
lines changed

src/Compiler/CMakeLists.txt

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,44 @@ add_onnx_mlir_library(OMCompilerDialects
7474
MLIRIR
7575
MLIROpenMPToLLVMIRTranslation
7676
${dialect_libs}
77-
)
77+
)
78+
79+
add_onnx_mlir_library(OMDisposableGarbageCollector
80+
DisposableGarbageCollector.cpp
81+
82+
EXCLUDE_FROM_OM_LIBS
83+
84+
INCLUDE_DIRS PRIVATE
85+
${FILE_GENERATE_DIR}
86+
87+
INCLUDE_DIRS PUBLIC
88+
${ONNX_MLIR_SRC_ROOT}/include
89+
90+
LINK_LIBS PUBLIC
91+
OMONNXOps
92+
)
93+
94+
add_onnx_mlir_library(OMOnnxToMlirPasses
95+
OnnxToMlirPasses.cpp
96+
97+
EXCLUDE_FROM_OM_LIBS
98+
99+
INCLUDE_DIRS PRIVATE
100+
${FILE_GENERATE_DIR}
101+
102+
INCLUDE_DIRS PUBLIC
103+
${ONNX_MLIR_SRC_ROOT}/include
104+
105+
LINK_LIBS PRIVATE
106+
OMHybridTransform
107+
OMONNXStandardFuncReturnPass
108+
OMONNXSimplifyShapeRelatedOps
109+
OMDisposableGarbageCollector
110+
MLIRFuncDialect
111+
)
78112

79113
add_onnx_mlir_library(OMCompilerPasses
80114
CompilerPasses.cpp
81-
DisposableGarbageCollector.cpp
82115

83116
EXCLUDE_FROM_OM_LIBS
84117

@@ -90,6 +123,7 @@ add_onnx_mlir_library(OMCompilerPasses
90123

91124
LINK_LIBS PUBLIC
92125
${OMLibs}
126+
OMOnnxToMlirPasses
93127
OMCompilerOptions
94128
MLIRAffineTransforms
95129
MLIRBufferizationPipelines
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef ONNX_MLIR_COMPILER_OPTION_ENUMS_H
2+
#define ONNX_MLIR_COMPILER_OPTION_ENUMS_H
3+
4+
#include "src/Accelerators/Accelerator.hpp"
5+
6+
namespace onnx_mlir {
7+
8+
typedef enum {
9+
// clang-format off
10+
None,
11+
Onnx
12+
APPLY_TO_ACCELERATORS(ACCEL_INSTRUMENTSTAGE_ENUM)
13+
// clang-format on
14+
} InstrumentStages;
15+
16+
using ProfileIRs = InstrumentStages;
17+
} // namespace onnx_mlir
18+
19+
#endif // ONNX_MLIR_COMPILER_OPTION_ENUMS_H

src/Compiler/CompilerOptions.hpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define ONNX_MLIR_COMPILER_OPTIONS_H
1717

1818
#include "src/Accelerators/Accelerator.hpp"
19+
#include "src/Compiler/CompilerOptionEnums.hpp"
1920
#include "llvm/Support/CommandLine.h"
2021
#include "llvm/Support/FileSystem.h"
2122
#include "llvm/Support/Path.h"
@@ -34,16 +35,6 @@ extern const std::string OnnxMlirEnvOptionName;
3435

3536
namespace onnx_mlir {
3637

37-
typedef enum {
38-
// clang-format off
39-
None,
40-
Onnx
41-
APPLY_TO_ACCELERATORS(ACCEL_INSTRUMENTSTAGE_ENUM)
42-
// clang-format on
43-
} InstrumentStages;
44-
45-
using ProfileIRs = InstrumentStages;
46-
4738
typedef enum {
4839
// clang-format off
4940
small,

src/Compiler/CompilerPasses.cpp

Lines changed: 16 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
#include "src/Compiler/CompilerOptions.hpp"
3535
#include "src/Compiler/CompilerPasses.hpp"
36-
#include "src/Compiler/DisposableGarbageCollector.hpp"
36+
#include "src/Compiler/OnnxToMlirPasses.hpp"
3737
#include "src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp"
3838
#include "src/Dialect/Mlir/VectorMachineSupport.hpp"
3939
#include "src/Dialect/ONNX/ONNXDialect.hpp"
@@ -66,141 +66,6 @@ void configurePasses() {
6666
!disableSimdOption);
6767
}
6868

69-
void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
70-
bool donotScrubDisposableElementsAttr, OnnxToMlirOptions opts) {
71-
// This is a transition from previous static passes to full dynamic passes
72-
// Static passes are kept and the dynamic pass is added as IF-THEN
73-
// with the static iteration.
74-
// The reasons are
75-
// 1. The debug flag, --print-ir-after/befor-all, can display IR for each
76-
// static pass, but the dynamic pipeline will be viewed as one. MLIR
77-
// may have solution that I am not aware of yet.
78-
// 2. Easy to compare two approaches.
79-
// In future, only the dynamic pass, ONNXOpTransformPass, will be used for
80-
// this function.
81-
82-
if (!donotScrubDisposableElementsAttr)
83-
pm.addInstrumentation(
84-
std::make_unique<DisposableGarbageCollector>(pm.getContext()));
85-
86-
// Decompose first. Eliminates some unsupported ops without shape inference.
87-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createDecomposeONNXToONNXPass(
88-
/*target=*/"", opts.enableConvTransposeDecompose,
89-
opts.enableConvTransposeDecomposeToPhasedConv,
90-
opts.enableConvTranspose1dDecomposeToPhasedConv));
91-
if (!disableRecomposeOption)
92-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createRecomposeONNXToONNXPass());
93-
if (enableONNXHybridPass) {
94-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass(
95-
!disableRecomposeOption, opts.enableQuarkQuantizedLegalization,
96-
opts.enableConvTransposeDecompose,
97-
opts.enableConvTransposeDecomposeToPhasedConv,
98-
opts.enableConvTranspose1dDecomposeToPhasedConv));
99-
// Convolution Optimization for CPU: enable when there are no accelerators.
100-
if (targetCPU && enableConvOptPass) {
101-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
102-
enableSimdDataLayout && !disableSimdOption));
103-
pm.addNestedPass<func::FuncOp>(
104-
onnx_mlir::createONNXHybridTransformPass(!disableRecomposeOption,
105-
/*enableQuarkQuantizedOpsLegalization=*/false,
106-
opts.enableConvTransposeDecompose,
107-
opts.enableConvTransposeDecomposeToPhasedConv,
108-
opts.enableConvTranspose1dDecomposeToPhasedConv));
109-
}
110-
} else {
111-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
112-
pm.addPass(mlir::createCanonicalizerPass());
113-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
114-
// Convolution Optimization for CPU: enable when there are no accelerators.
115-
if (targetCPU && enableConvOptPass) {
116-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
117-
enableSimdDataLayout && !disableSimdOption));
118-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
119-
}
120-
pm.addNestedPass<func::FuncOp>(
121-
onnx_mlir::createLegalizeQuarkQuantizedOpsPass());
122-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass());
123-
if (onnxOpTransformThreshold > 0) {
124-
// Dynamic iterate in ONNXOpTransformPass
125-
pm.addPass(onnx_mlir::createONNXOpTransformPass(onnxOpTransformThreshold,
126-
onnxOpTransformReport, targetCPU,
127-
enableSimdDataLayout && !disableSimdOption, enableConvOptPass,
128-
!disableRecomposeOption));
129-
} else {
130-
// Statically add extra passes
131-
for (int i = 0; i < repeatOnnxTransform; i++) {
132-
pm.addPass(mlir::createCanonicalizerPass());
133-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
134-
pm.addNestedPass<func::FuncOp>(
135-
onnx_mlir::createConstPropONNXToONNXPass());
136-
}
137-
}
138-
}
139-
140-
// Simplify shape-related ops.
141-
pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass(
142-
opts.enableQuarkQuantizedLegalization));
143-
144-
// Pass for removing Dq and Q around data movement in Dq->op->Q Ops chain
145-
if (opts.enableRemoveDqQAroundOp)
146-
pm.addPass(createQDQAroundOpOptONNXToONNXPass());
147-
148-
// Pass for removing redundant Dq->Q Ops chain
149-
if (opts.enableRemoveDqQOp)
150-
pm.addPass(createQDQOptONNXToONNXPass());
151-
152-
// One more call to ONNX shape inference/canonicalization/... to update
153-
// shape if possible.
154-
if (enableONNXHybridPass) {
155-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass(
156-
!disableRecomposeOption, opts.enableQuarkQuantizedLegalization,
157-
opts.enableConvTransposeDecompose,
158-
opts.enableConvTransposeDecomposeToPhasedConv,
159-
opts.enableConvTranspose1dDecomposeToPhasedConv));
160-
} else {
161-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
162-
pm.addPass(mlir::createCanonicalizerPass());
163-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
164-
}
165-
166-
// Replace ONNXReturnOp with func::ReturnOp.
167-
pm.addPass(onnx_mlir::createStandardFuncReturnPass());
168-
169-
// Clean dead code.
170-
pm.addPass(mlir::createSymbolDCEPass());
171-
172-
// Replace every DisposableElementsAttr with DenseElementsAttr.
173-
if (!donotScrubDisposableElementsAttr)
174-
pm.addPass(createScrubDisposablePass());
175-
176-
// Set onnx_node_name if it is missing. Keep this pass at the end of this
177-
// function and just before instrumentation.
178-
pm.addPass(createSetONNXNodeNamePass());
179-
180-
// Add instrumentation for Onnx Ops
181-
// Keep this pass at the end of this function.
182-
unsigned instrumentActions = instrumentControlBits;
183-
if (profileIR == onnx_mlir::ProfileIRs::Onnx) {
184-
instrumentStage = onnx_mlir::InstrumentStages::Onnx;
185-
instrumentOps = "onnx.*";
186-
// Enable the first three bits for InstrumentBeforOp, InstrumentAfterOp
187-
// and InstrumentReportTime. Disable the last bit for
188-
// InstrumentReportMemory because of its big overhead. Users can
189-
// optionally enable the last bit by using
190-
// --InstrumentReportMemory option.
191-
instrumentActions |= (1 << 3) - 1;
192-
}
193-
if (instrumentStage == onnx_mlir::InstrumentStages::Onnx)
194-
pm.addNestedPass<func::FuncOp>(
195-
onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));
196-
// Print Signatures of each op at runtime if enabled. Should not run
197-
// signature and instrument passes at the same time as time may include printf
198-
// overheads.
199-
if (instrumentSignatures != "NONE" || instrumentOnnxNode != "NONE")
200-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentONNXSignaturePass(
201-
instrumentSignatures, instrumentOnnxNode));
202-
}
203-
20469
void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
20570
std::string ONNXOpsStatFormat) {
20671
if (enableCSE)
@@ -363,6 +228,21 @@ void addPasses(mlir::OwningOpRef<ModuleOp> &module, mlir::PassManager &pm,
363228
enableConvTransposeDecomposeToPhasedConv;
364229
opts.enableConvTranspose1dDecomposeToPhasedConv =
365230
enableConvTranspose1dDecomposeToPhasedConv;
231+
opts.disableRecomposeOption = disableRecomposeOption;
232+
opts.enableONNXHybridPass = enableONNXHybridPass;
233+
opts.enableConvOptPass = enableConvOptPass;
234+
opts.enableSimdDataLayout = enableSimdDataLayout;
235+
opts.disableSimdOption = disableSimdOption;
236+
opts.onnxOpTransformThreshold = onnxOpTransformThreshold;
237+
opts.onnxOpTransformReport = onnxOpTransformReport;
238+
opts.repeatOnnxTransform = repeatOnnxTransform;
239+
opts.instrumentControlBits = instrumentControlBits;
240+
opts.instrumentOps = instrumentOps;
241+
opts.instrumentSignatures = instrumentSignatures;
242+
opts.instrumentOnnxNode = instrumentOnnxNode;
243+
opts.profileIR = profileIR;
244+
opts.instrumentStage = instrumentStage;
245+
366246
addONNXToMLIRPasses(pm, /*target CPU*/ false,
367247
/*donotScrubDisposableElementsAttr=*/false, opts);
368248
}

src/Compiler/CompilerPasses.hpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,6 @@ namespace onnx_mlir {
2525
// Configures passes up front based on command line options.
2626
void configurePasses();
2727

28-
struct OnnxToMlirOptions {
29-
bool enableQuarkQuantizedLegalization = false;
30-
bool enableConvTransposeDecompose = false;
31-
bool enableConvTransposeDecomposeToPhasedConv = false;
32-
bool enableConvTranspose1dDecomposeToPhasedConv = false;
33-
bool enableRemoveDqQAroundOp = true;
34-
bool enableRemoveDqQOp = true;
35-
};
36-
37-
void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
38-
bool donotScrubDisposableElementsAttr = false, OnnxToMlirOptions opts = {});
3928
void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
4029
std::string ONNXOpsStatFilename);
4130
void addKrnlToAffinePasses(mlir::PassManager &pm);

0 commit comments

Comments
 (0)