Skip to content

Commit 6145f33

Browse files
authored
Merge pull request #324 from Xilinx/matthias.onnx_options
Use pass options instead of global variables to drive Decompose
2 parents b55d7c1 + e9c8ad6 commit 6145f33

File tree

6 files changed

+77
-23
lines changed

6 files changed

+77
-23
lines changed

src/Compiler/CompilerPasses.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ void configurePasses() {
6767
}
6868

6969
void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
70-
bool donotScrubDisposableElementsAttr,
71-
bool enableQuarkQuantizedLegalization) {
70+
bool donotScrubDisposableElementsAttr, OnnxToMlirOptions opts) {
7271
// This is a transition from previous static passes to full dynamic passes
7372
// Static passes are kept and the dynamic pass is added as IF-THEN
7473
// with the static iteration.
@@ -85,12 +84,14 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
8584
std::make_unique<DisposableGarbageCollector>(pm.getContext()));
8685

8786
// Decompose first. Eliminates some unsupported ops without shape inference.
88-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createDecomposeONNXToONNXPass());
87+
pm.addNestedPass<func::FuncOp>(onnx_mlir::createDecomposeONNXToONNXPass(
88+
/*target=*/"", opts.enableConvTransposeDecompose,
89+
opts.enableConvTranposeDecomposeToPhasedConv));
8990
if (!disableRecomposeOption)
9091
pm.addNestedPass<func::FuncOp>(onnx_mlir::createRecomposeONNXToONNXPass());
9192
if (enableONNXHybridPass) {
9293
pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass(
93-
!disableRecomposeOption, enableQuarkQuantizedLegalization));
94+
!disableRecomposeOption, opts.enableQuarkQuantizedLegalization));
9495
// Convolution Optimization for CPU: enable when there are no accelerators.
9596
if (targetCPU && enableConvOptPass) {
9697
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
@@ -130,13 +131,13 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
130131

131132
// Simplify shape-related ops.
132133
pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass(
133-
enableQuarkQuantizedLegalization));
134+
opts.enableQuarkQuantizedLegalization));
134135

135136
// One more call to ONNX shape inference/canonicalization/... to update
136137
// shape if possible.
137138
if (enableONNXHybridPass) {
138139
pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass(
139-
!disableRecomposeOption, enableQuarkQuantizedLegalization));
140+
!disableRecomposeOption, opts.enableQuarkQuantizedLegalization));
140141
} else {
141142
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
142143
pm.addPass(mlir::createCanonicalizerPass());
@@ -335,10 +336,15 @@ void addPasses(mlir::OwningOpRef<ModuleOp> &module, mlir::PassManager &pm,
335336

336337
// NOTE: FlexML sets the targetCPU flag to false, as we do not want to run
337338
// the CPU specific transformations.
338-
if (inputIRLevel <= ONNXLevel && emissionTarget >= EmitONNXIR)
339+
if (inputIRLevel <= ONNXLevel && emissionTarget >= EmitONNXIR) {
340+
OnnxToMlirOptions opts;
341+
opts.enableQuarkQuantizedLegalization = enableQuarkQuantizedLegalization;
342+
opts.enableConvTransposeDecompose = enableConvTransposeDecomposeOption;
343+
opts.enableConvTranposeDecomposeToPhasedConv =
344+
enableConvTranposeDecomposeToPhasedConv;
339345
addONNXToMLIRPasses(pm, /*target CPU*/ false,
340-
/*donotScrubDisposableElementsAttr=*/false,
341-
/*enableQuarkQuantizedLegalization=*/enableQuarkQuantizedLegalization);
346+
/*donotScrubDisposableElementsAttr=*/false, opts);
347+
}
342348

343349
if (emissionTarget >= EmitMLIR) {
344350
if (inputIRLevel <= ONNXLevel)

src/Compiler/CompilerPasses.hpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,25 @@
1414

1515
#ifndef ONNX_MLIR_COMPILER_PASSES_H
1616
#define ONNX_MLIR_COMPILER_PASSES_H
17+
#include "mlir/IR/OwningOpRef.h"
1718
#include "mlir/Pass/PassManager.h"
19+
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
20+
namespace mlir {
21+
class ModuleOp;
22+
}
1823

1924
namespace onnx_mlir {
2025
// Configures passes up front based on command line options.
2126
void configurePasses();
2227

28+
struct OnnxToMlirOptions {
29+
bool enableQuarkQuantizedLegalization = false;
30+
bool enableConvTransposeDecompose = false;
31+
bool enableConvTranposeDecomposeToPhasedConv = false;
32+
};
33+
2334
void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
24-
bool donotScrubDisposableElementsAttr = false,
25-
bool enableQuarkQuantizedLegalization = false);
35+
bool donotScrubDisposableElementsAttr = false, OnnxToMlirOptions opts = {});
2636
void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
2737
std::string instrumentSignatureString, std::string ONNXOpsStatFilename);
2838
void addKrnlToAffinePasses(mlir::PassManager &pm);

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
#include "mlir/Support/LogicalResult.h"
3434
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3535

36-
#include "src/Compiler/CompilerOptions.hpp"
3736
#include "src/Dialect/ONNX/DialectBuilder.hpp"
3837
#include "src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp"
3938
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -49,6 +48,10 @@
4948

5049
using namespace mlir;
5150

51+
namespace {
52+
thread_local bool localEnableConvTransposeDecompose = false;
53+
thread_local bool localEnableConvTranposeDecomposeToPhasedConv = false;
54+
} // namespace
5255
namespace onnx_mlir {
5356

5457
// Create an DenseElementsAttr of ArrayAttr.
@@ -659,7 +662,7 @@ Value replaceSequenceAt(
659662
}
660663

661664
bool shouldDecomposeConvTransposeOp(Value convTransposeResult) {
662-
if (!onnx_mlir::enableConvTransposeDecomposeOption) {
665+
if (!localEnableConvTransposeDecompose) {
663666
// Disable the ONNXConvTransposeOp decomposition patterns.
664667
return false;
665668
}
@@ -714,7 +717,7 @@ bool hasNoActivationConsumer(Value convTransposeResult) {
714717
bool ShouldDecomposeConvTransposeOpToPhasedConvs(Value convTransposeResult,
715718
ArrayAttr kernelShapeAttr, ArrayAttr padsShapeAttr,
716719
ArrayAttr stridesShapeAttr, ArrayAttr outputShapeAttr) {
717-
if (!onnx_mlir::enableConvTranposeDecomposeToPhasedConv) {
720+
if (!localEnableConvTranposeDecomposeToPhasedConv) {
718721
// Disable the ONNXConvTransposeOp to Conv decomposition patterns.
719722
return false;
720723
}
@@ -2738,11 +2741,23 @@ struct DecomposeONNXToONNXPass
27382741
: public PassWrapper<DecomposeONNXToONNXPass, OperationPass<func::FuncOp>> {
27392742
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DecomposeONNXToONNXPass)
27402743

2741-
DecomposeONNXToONNXPass(const std::string &target) { this->target = target; }
2744+
DecomposeONNXToONNXPass(const std::string &target,
2745+
bool enableConvTransposeDecompose = false,
2746+
bool enableConvTranposeDecomposeToPhasedConv = false) {
2747+
this->target = target;
2748+
this->enableConvTransposeDecompose = enableConvTransposeDecompose;
2749+
this->enableConvTranposeDecomposeToPhasedConv =
2750+
enableConvTranposeDecomposeToPhasedConv;
2751+
}
2752+
27422753
DecomposeONNXToONNXPass(const DecomposeONNXToONNXPass &pass)
27432754
: mlir::PassWrapper<DecomposeONNXToONNXPass,
27442755
OperationPass<func::FuncOp>>() {
27452756
this->target = pass.target.getValue();
2757+
this->enableConvTransposeDecompose =
2758+
pass.enableConvTransposeDecompose.getValue();
2759+
this->enableConvTranposeDecomposeToPhasedConv =
2760+
pass.enableConvTranposeDecomposeToPhasedConv.getValue();
27462761
}
27472762

27482763
StringRef getArgument() const override { return "decompose-onnx"; }
@@ -2755,6 +2770,16 @@ struct DecomposeONNXToONNXPass
27552770
Option<std::string> target{*this, "target",
27562771
llvm::cl::desc("Target Dialect to decompose into"), ::llvm::cl::init("")};
27572772

2773+
Option<bool> enableConvTransposeDecompose{*this, "enable-convtranspose",
2774+
llvm::cl::desc("Enable decomposition of ConvTranspose"),
2775+
::llvm::cl::init(false)};
2776+
2777+
Option<bool> enableConvTranposeDecomposeToPhasedConv{*this,
2778+
"enable-convtranspose-phased",
2779+
llvm::cl::desc("Enable decomposition of ONNX ConvTranspose operator to 4 "
2780+
"phased Conv"),
2781+
::llvm::cl::init(false)};
2782+
27582783
void runOnOperation() final;
27592784

27602785
typedef PassWrapper<DecomposeONNXToONNXPass, OperationPass<func::FuncOp>>
@@ -2774,7 +2799,17 @@ void DecomposeONNXToONNXPass::runOnOperation() {
27742799
}
27752800
#endif
27762801

2777-
if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns))))
2802+
// Set thread locals to affect native functions called by .td patterns.
2803+
localEnableConvTransposeDecompose = enableConvTransposeDecompose;
2804+
localEnableConvTranposeDecomposeToPhasedConv =
2805+
enableConvTranposeDecomposeToPhasedConv;
2806+
2807+
auto status = applyPatternsAndFoldGreedily(function, std::move(patterns));
2808+
2809+
localEnableConvTransposeDecompose = false;
2810+
localEnableConvTranposeDecomposeToPhasedConv = false;
2811+
2812+
if (failed(status))
27782813
signalPassFailure();
27792814
}
27802815

@@ -2809,6 +2844,8 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
28092844
* Create a DecomposeONNX pass.
28102845
*/
28112846
std::unique_ptr<mlir::Pass> onnx_mlir::createDecomposeONNXToONNXPass(
2812-
const std::string &target) {
2813-
return std::make_unique<DecomposeONNXToONNXPass>(target);
2814-
}
2847+
const std::string &target, bool enableConvTransposeDecompose,
2848+
bool enableConvTranposeDecomposeToPhasedConv) {
2849+
return std::make_unique<DecomposeONNXToONNXPass>(target,
2850+
enableConvTransposeDecompose, enableConvTranposeDecomposeToPhasedConv);
2851+
}

src/Pass/Passes.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ std::unique_ptr<mlir::Pass> createONNXOpTransformPass(int threshold,
3838

3939
/// Pass for rewriting inside frontend dialect.
4040
std::unique_ptr<mlir::Pass> createDecomposeONNXToONNXPass(
41-
const std::string &target = "");
41+
const std::string &target = "", bool enableConvTransposeDecompose = false,
42+
bool enableConvTranposeDecomposeToPhasedConv = false);
4243
std::unique_ptr<mlir::Pass> createRecomposeONNXToONNXPass(
4344
const std::string &target = "");
4445

@@ -132,4 +133,4 @@ std::unique_ptr<mlir::Pass> createConvertKrnlToLLVMPass(bool verifyInputTensors,
132133
/// Pass for lowering Onnx ops to TOSA dialect
133134
std::unique_ptr<mlir::Pass> createConvertONNXToTOSAPass();
134135
} // namespace onnx_mlir
135-
#endif
136+
#endif

test/mlir/onnx/onnx_decompose_convtranspose.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: onnx-mlir-opt --shape-inference --decompose-onnx --enable-convtranspose-decompose %s -split-input-file | FileCheck %s
1+
// RUN: onnx-mlir-opt --shape-inference --decompose-onnx=enable-convtranspose %s -split-input-file | FileCheck %s
22
// RUN: onnx-mlir-opt --shape-inference --decompose-onnx %s -split-input-file | FileCheck %s --check-prefix=DISABLED
33

44
// -----

test/mlir/onnx/onnx_decompose_convtranspose_phased_conv.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: onnx-mlir-opt --shape-inference --decompose-onnx --enable-convtranspose-decompose-phased-conv %s -split-input-file | FileCheck %s
1+
// RUN: onnx-mlir-opt --shape-inference --decompose-onnx=enable-convtranspose-phased %s -split-input-file | FileCheck %s
22

33
func.func @test_convtrans_4phase_kernel_shape_66(%arg0: tensor<1x512x10x16xf32>, %arg1: tensor<512x256x6x6xf32>) -> tensor<1x256x20x32xf32> {
44
%0 = "onnx.Constant" () { value= dense<0.02> : tensor<256xf32>} : ()-> tensor<256xf32>

0 commit comments

Comments
 (0)