Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/DisposableGarbageCollector.hpp"
#include "src/Compiler/OnnxToMlirPasses.hpp"
#include "src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp"
#include "src/Dialect/Mlir/VectorMachineSupport.hpp"
Expand Down Expand Up @@ -66,6 +67,139 @@ void configurePasses() {
!disableSimdOption);
}

void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
bool donotScrubDisposableElementsAttr, OnnxToMlirOptions opts) {
// This is a transition from previous static passes to full dynamic passes
// Static passes are kept and the dynamic pass is added as IF-THEN
// with the static iteration.
// The reasons are
// 1. The debug flag, --print-ir-after/befor-all, can display IR for each
// static pass, but the dynamic pipeline will be viewed as one. MLIR
// may have solution that I am not aware of yet.
// 2. Easy to compare two approaches.
// In future, only the dynamic pass, ONNXOpTransformPass, will be used for
// this function.

if (!donotScrubDisposableElementsAttr)
pm.addInstrumentation(
std::make_unique<DisposableGarbageCollector>(pm.getContext()));

// Decompose first. Eliminates some unsupported ops without shape inference.
pm.addNestedPass<func::FuncOp>(onnx_mlir::createDecomposeONNXToONNXPass(
/*target=*/"", opts.enableConvTransposeDecompose,
opts.enableConvTransposeDecomposeToPhasedConv,
opts.enableConvTranspose1dDecomposeToPhasedConv));
if (!disableRecomposeOption)
pm.addNestedPass<func::FuncOp>(onnx_mlir::createRecomposeONNXToONNXPass());
if (enableONNXHybridPass) {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass(
!disableRecomposeOption, opts.enableQuarkQuantizedLegalization,
opts.enableConvTransposeDecompose,
opts.enableConvTransposeDecomposeToPhasedConv,
opts.enableConvTranspose1dDecomposeToPhasedConv));
// Convolution Optimization for CPU: enable when there are no accelerators.
if (targetCPU && enableConvOptPass) {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
enableSimdDataLayout && !disableSimdOption));
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createONNXHybridTransformPass(!disableRecomposeOption,
/*enableQuarkQuantizedOpsLegalization=*/false,
opts.enableConvTransposeDecompose,
opts.enableConvTransposeDecomposeToPhasedConv,
opts.enableConvTranspose1dDecomposeToPhasedConv));
}
} else {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
// Convolution Optimization for CPU: enable when there are no accelerators.
if (targetCPU && enableConvOptPass) {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
enableSimdDataLayout && !disableSimdOption));
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
}
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createLegalizeQuarkQuantizedOpsPass());
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass());
if (onnxOpTransformThreshold > 0) {
// Dynamic iterate in ONNXOpTransformPass
pm.addPass(onnx_mlir::createONNXOpTransformPass(onnxOpTransformThreshold,
onnxOpTransformReport, targetCPU,
enableSimdDataLayout && !disableSimdOption, enableConvOptPass,
!disableRecomposeOption));
} else {
// Statically add extra passes
for (int i = 0; i < repeatOnnxTransform; i++) {
pm.addPass(mlir::createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createConstPropONNXToONNXPass());
}
}
}

// Simplify shape-related ops.
pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass(
opts.enableQuarkQuantizedLegalization));

// Passes for removing redundant concat, slice and cast QDQ Ops
if (opts.enableRemoveDqQOp)
pm.addPass(createQDQOptONNXToONNXPass());
if (opts.enableRemoveBinary)
pm.addPass(createFoldDQBinaryQPass());

// One more call to ONNX shape inference/canonicalization/... to update
// shape if possible.
if (enableONNXHybridPass) {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass(
!disableRecomposeOption, opts.enableQuarkQuantizedLegalization,
opts.enableConvTransposeDecompose,
opts.enableConvTransposeDecomposeToPhasedConv,
opts.enableConvTranspose1dDecomposeToPhasedConv));
} else {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
}

// Replace ONNXReturnOp with func::ReturnOp.
pm.addPass(onnx_mlir::createStandardFuncReturnPass());

// Clean dead code.
pm.addPass(mlir::createSymbolDCEPass());

// Replace every DisposableElementsAttr with DenseElementsAttr.
if (!donotScrubDisposableElementsAttr)
pm.addPass(createScrubDisposablePass());

// Set onnx_node_name if it is missing. Keep this pass at the end of this
// function and just before instrumentation.
pm.addPass(createSetONNXNodeNamePass());

// Add instrumentation for Onnx Ops
// Keep this pass at the end of this function.
unsigned instrumentActions = instrumentControlBits;
if (profileIR == onnx_mlir::ProfileIRs::Onnx) {
instrumentStage = onnx_mlir::InstrumentStages::Onnx;
instrumentOps = "onnx.*";
// Enable the first three bits for InstrumentBeforOp, InstrumentAfterOp
// and InstrumentReportTime. Disable the last bit for
// InstrumentReportMemory because of its big overhead. Users can
// optionally enable the last bit by using
// --InstrumentReportMemory option.
instrumentActions |= (1 << 3) - 1;
}
if (instrumentStage == onnx_mlir::InstrumentStages::Onnx)
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));
// Print Signatures of each op at runtime if enabled. Should not run
// signature and instrument passes at the same time as time may include printf
// overheads.
if (instrumentSignatures != "NONE" || instrumentOnnxNode != "NONE")
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentONNXSignaturePass(
instrumentSignatures, instrumentOnnxNode));
}

void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
std::string ONNXOpsStatFormat) {
if (enableCSE)
Expand Down
13 changes: 13 additions & 0 deletions src/Compiler/CompilerPasses.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ namespace onnx_mlir {
// Configures passes up front based on command line options.
void configurePasses();

/*
struct OnnxToMlirOptions {
bool enableQuarkQuantizedLegalization = false;
bool enableConvTransposeDecompose = false;
bool enableConvTransposeDecomposeToPhasedConv = false;
bool enableConvTranspose1dDecomposeToPhasedConv = false;
bool enableRemoveDqQOp = false;
bool enableRemoveBinary = false;
};

void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
bool donotScrubDisposableElementsAttr = false, OnnxToMlirOptions opts = {});
*/
void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
std::string ONNXOpsStatFilename);
void addKrnlToAffinePasses(mlir::PassManager &pm);
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/OnnxToMlirPasses.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ struct OnnxToMlirOptions {
bool enableConvTranspose1dDecomposeToPhasedConv = false;
bool enableRemoveDqQOp = true;
bool enableRemoveDqQAroundOp = true;
bool enableRemoveBinary = true;

bool disableRecomposeOption = false;
bool enableONNXHybridPass = true;
Expand Down
7 changes: 7 additions & 0 deletions src/Dialect/ONNX/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ add_onnx_mlir_rewriter(DecomposeConvTranspose1dPhased)
add_onnx_mlir_rewriter(ConstProp)
add_onnx_mlir_rewriter(ConvOpt)

add_onnx_mlir_rewriter(QDQAroundOpOpt)
add_onnx_mlir_rewriter(QDQOpt)
add_onnx_mlir_rewriter(DQBinaryQOpt)



add_onnx_mlir_library(OMShapeInference
ShapeInference.cpp

Expand Down Expand Up @@ -44,6 +50,7 @@ add_onnx_mlir_library(OMONNXRewrite
ConstProp.cpp
QDQAroundOpOpt.cpp
QDQOpt.cpp
DQBinaryQOpt.cpp
ConvOpt.cpp
Decompose.cpp
DecomposeEinsum.cpp
Expand Down
Loading