diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index dcb66a43ca4e..381c4c27db46 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -9,6 +9,7 @@ #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/DispatchCreation/Passes.h" #include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h" +#include "iree/compiler/Preprocessing/Common/Passes.h" #include "iree/compiler/Utils/PassUtils.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" @@ -89,7 +90,8 @@ void buildGlobalOptExprHoistingPassPipeline( } void buildGlobalOptimizationPassPipeline( - OpPassManager &mainPassManager, const TransformOptions &transformOptions) { + OpPassManager &mainPassManager, const TransformOptions &transformOptions, + const PreprocessingOptions &preprocessingOptions) { // Import parameters before any global optimization passes so that the inlined // parameters are available for folding. if (!transformOptions.parameterImportPaths.empty()) { @@ -116,7 +118,18 @@ void buildGlobalOptimizationPassPipeline( .addPredicatedPass(transformOptions.stripAssertions, IREE::Util::createStripDebugOpsPass) .addPass(IREE::Util::createOptimizeIntArithmeticPass) - .addPass(createLinalgQuantizedConvToConvPass) + .addPass(createLinalgQuantizedConvToConvPass); + + // Add Conv2D to Img2Col conversion after QuantizedConvToConvPass + if (preprocessingOptions + .preprocessingEnableConv2dToImg2colAfterQuantizedConv) { + FunctionLikeNest(mainPassManager) + .addPass(Preprocessing::createConvertConv2DToImg2ColPass) + .addPass(createCanonicalizerPass) + .addPass(createCSEPass); + } + + FunctionLikeNest(mainPassManager) .addPass(createLinalgQuantizedMatmulToMatmulPass) .addPass(IREE::Flow::createCanonicalizePass) .addPass(createRemoveZeroExtentTensorsPass) @@ -298,7 +311,10 @@ void registerGlobalOptimizationPipeline() { "Runs the IREE global optimization transformation pipeline", [](OpPassManager &passManager, const TransformOptions &transformOptions) { - buildGlobalOptimizationPassPipeline(passManager, transformOptions); + // Use default preprocessing options for pipeline registration + PreprocessingOptions defaultPreprocessingOptions; + buildGlobalOptimizationPassPipeline(passManager, transformOptions, + defaultPreprocessingOptions); }); PassPipelineRegistration globalOptimizationConstantHoistingPassPipeline( diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h index d1e1c925c1b7..b65cc45e1e37 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h @@ -9,6 +9,7 @@ #include +#include "iree/compiler/Pipelines/Options.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" @@ -145,7 +146,8 @@ struct TransformOptions : public PassPipelineOptions { /// We may ultimately break this out separately so creating a syntactic /// distinction to keep that as an option. void buildGlobalOptimizationPassPipeline( - OpPassManager &mainPassManager, const TransformOptions &transformOptions); + OpPassManager &mainPassManager, const TransformOptions &transformOptions, + const PreprocessingOptions &preprocessingOptions); //------------------------------------------------------------------------------ // Wrappers that not use tablegen options. diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp index 25bf0df2e3d0..42a86d85cae0 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.cpp +++ b/compiler/src/iree/compiler/Pipelines/Options.cpp @@ -160,6 +160,17 @@ void PreprocessingOptions::bindOptions(OptionsBinder &binder) { "iree-preprocessing-pdl-spec-filename", preprocessingPDLSpecFilename, llvm::cl::desc("File name of a PDL spec to use for preprocessing."), llvm::cl::cat(category)); + binder.opt( + "iree-preprocessing-enable-conv2d-to-img2col-after-quantized-conv", + preprocessingEnableConv2dToImg2colAfterQuantizedConv, + llvm::cl::desc("Enable Conv2d to Img2col + matmul after " + "QuantizedConvToConvPass in global optimization."), + llvm::cl::cat(category)); + binder.opt( + "iree-preprocessing-enable-conv2d-to-img2col", + preprocessingEnableConv2dToImg2col, + llvm::cl::desc("Enable Conv2d to Img2col + matmul in preprocessing."), + llvm::cl::cat(category)); } void ParameterOptions::bindOptions(OptionsBinder &binder) { diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h index 811bd5b17f1a..43173417feca 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.h +++ b/compiler/src/iree/compiler/Pipelines/Options.h @@ -101,6 +101,9 @@ struct PreprocessingOptions { std::string preprocessingTransformSpecFilename; std::string preprocessingPDLSpecFilename; + bool preprocessingEnableConv2dToImg2col = false; + bool preprocessingEnableConv2dToImg2colAfterQuantizedConv = false; + void bindOptions(OptionsBinder &binder); using FromFlags = OptionsFromFlags; }; diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp index c8b1cace4003..869cfff83dd8 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp @@ -270,7 +270,7 @@ void buildIREEPrecompileTransformPassPipeline( hooks.beforePhase(IREEVMPipelinePhase::GlobalOptimization, passManager); } GlobalOptimization::buildGlobalOptimizationPassPipeline( - passManager, globalTransformOptions); + passManager, globalTransformOptions, preprocessingOptions); if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::GlobalOptimization, passManager); } diff --git a/compiler/src/iree/compiler/Preprocessing/Passes.cpp b/compiler/src/iree/compiler/Preprocessing/Passes.cpp index 6ef6933803a1..149e0edbb894 100644 --- a/compiler/src/iree/compiler/Preprocessing/Passes.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Passes.cpp @@ -92,6 +92,11 @@ static void buildPreprocessingPassPipelineFromCommandLine( passManager.addPass(createCanonicalizerPass()); passManager.addPass(createCSEPass()); } + if (preprocessingOptions.preprocessingEnableConv2dToImg2col) { + passManager.addPass(createConvertConv2DToImg2ColPass()); + passManager.addPass(createCanonicalizerPass()); + passManager.addPass(createCSEPass()); + } } void buildPreprocessingPassPipeline(