3333#include " mlir/Support/LogicalResult.h"
3434#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
3535
36- #include " src/Compiler/CompilerOptions.hpp"
3736#include " src/Dialect/ONNX/DialectBuilder.hpp"
3837#include " src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp"
3938#include " src/Dialect/ONNX/ONNXOps.hpp"
4948
5049using namespace mlir ;
5150
51+ namespace {
52+ thread_local bool localEnableConvTransposeDecompose = false ;
53+ thread_local bool localEnableConvTranposeDecomposeToPhasedConv = false ;
54+ } // namespace
5255namespace onnx_mlir {
5356
5457// Create an DenseElementsAttr of ArrayAttr.
@@ -659,7 +662,7 @@ Value replaceSequenceAt(
659662}
660663
661664bool shouldDecomposeConvTransposeOp (Value convTransposeResult) {
662- if (!onnx_mlir::enableConvTransposeDecomposeOption ) {
665+ if (!localEnableConvTransposeDecompose ) {
663666 // Disable the ONNXConvTransposeOp decomposition patterns.
664667 return false ;
665668 }
@@ -714,7 +717,7 @@ bool hasNoActivationConsumer(Value convTransposeResult) {
714717bool ShouldDecomposeConvTransposeOpToPhasedConvs (Value convTransposeResult,
715718 ArrayAttr kernelShapeAttr, ArrayAttr padsShapeAttr,
716719 ArrayAttr stridesShapeAttr, ArrayAttr outputShapeAttr) {
717- if (!onnx_mlir::enableConvTranposeDecomposeToPhasedConv ) {
720+ if (!localEnableConvTranposeDecomposeToPhasedConv ) {
718721 // Disable the ONNXConvTransposeOp to Conv decomposition patterns.
719722 return false ;
720723 }
@@ -2738,11 +2741,23 @@ struct DecomposeONNXToONNXPass
27382741 : public PassWrapper<DecomposeONNXToONNXPass, OperationPass<func::FuncOp>> {
27392742 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (DecomposeONNXToONNXPass)
27402743
2741- DecomposeONNXToONNXPass (const std::string &target) { this ->target = target; }
2744+ DecomposeONNXToONNXPass (const std::string &target,
2745+ bool enableConvTransposeDecompose = false ,
2746+ bool enableConvTranposeDecomposeToPhasedConv = false ) {
2747+ this ->target = target;
2748+ this ->enableConvTransposeDecompose = enableConvTransposeDecompose;
2749+ this ->enableConvTranposeDecomposeToPhasedConv =
2750+ enableConvTranposeDecomposeToPhasedConv;
2751+ }
2752+
27422753 DecomposeONNXToONNXPass (const DecomposeONNXToONNXPass &pass)
27432754 : mlir::PassWrapper<DecomposeONNXToONNXPass,
27442755 OperationPass<func::FuncOp>>() {
27452756 this ->target = pass.target .getValue ();
2757+ this ->enableConvTransposeDecompose =
2758+ pass.enableConvTransposeDecompose .getValue ();
2759+ this ->enableConvTranposeDecomposeToPhasedConv =
2760+ pass.enableConvTranposeDecomposeToPhasedConv .getValue ();
27462761 }
27472762
27482763 StringRef getArgument () const override { return " decompose-onnx" ; }
@@ -2755,6 +2770,16 @@ struct DecomposeONNXToONNXPass
27552770 Option<std::string> target{*this , " target" ,
27562771 llvm::cl::desc (" Target Dialect to decompose into" ), ::llvm::cl::init (" " )};
27572772
2773+ Option<bool > enableConvTransposeDecompose{*this , " enable-convtranspose" ,
2774+ llvm::cl::desc (" Enable decomposition of ConvTranspose" ),
2775+ ::llvm::cl::init (false )};
2776+
2777+ Option<bool > enableConvTranposeDecomposeToPhasedConv{*this ,
2778+ " enable-convtranspose-phased" ,
2779+ llvm::cl::desc (" Enable decomposition of ONNX ConvTranspose operator to 4 "
2780+ " phased Conv" ),
2781+ ::llvm::cl::init (false )};
2782+
27582783 void runOnOperation () final ;
27592784
27602785 typedef PassWrapper<DecomposeONNXToONNXPass, OperationPass<func::FuncOp>>
@@ -2774,7 +2799,17 @@ void DecomposeONNXToONNXPass::runOnOperation() {
27742799 }
27752800#endif
27762801
2777- if (failed (applyPatternsAndFoldGreedily (function, std::move (patterns))))
2802+ // Set thread locals to affect native functions called by .td patterns.
2803+ localEnableConvTransposeDecompose = enableConvTransposeDecompose;
2804+ localEnableConvTranposeDecomposeToPhasedConv =
2805+ enableConvTranposeDecomposeToPhasedConv;
2806+
2807+ auto status = applyPatternsAndFoldGreedily (function, std::move (patterns));
2808+
2809+ localEnableConvTransposeDecompose = false ;
2810+ localEnableConvTranposeDecomposeToPhasedConv = false ;
2811+
2812+ if (failed (status))
27782813 signalPassFailure ();
27792814}
27802815
@@ -2809,6 +2844,8 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
28092844 * Create a DecomposeONNX pass.
28102845 */
28112846std::unique_ptr<mlir::Pass> onnx_mlir::createDecomposeONNXToONNXPass (
2812- const std::string &target) {
2813- return std::make_unique<DecomposeONNXToONNXPass>(target);
2814- }
2847+ const std::string &target, bool enableConvTransposeDecompose,
2848+ bool enableConvTranposeDecomposeToPhasedConv) {
2849+ return std::make_unique<DecomposeONNXToONNXPass>(target,
2850+ enableConvTransposeDecompose, enableConvTranposeDecomposeToPhasedConv);
2851+ }
0 commit comments