Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "iree/compiler/Modules/IO/Parameters/Transforms/Passes.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "iree/compiler/Utils/PassUtils.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
Expand Down Expand Up @@ -89,7 +90,8 @@ void buildGlobalOptExprHoistingPassPipeline(
}

void buildGlobalOptimizationPassPipeline(
OpPassManager &mainPassManager, const TransformOptions &transformOptions) {
OpPassManager &mainPassManager, const TransformOptions &transformOptions,
const PreprocessingOptions &preprocessingOptions) {
// Import parameters before any global optimization passes so that the inlined
// parameters are available for folding.
if (!transformOptions.parameterImportPaths.empty()) {
Expand All @@ -116,7 +118,18 @@ void buildGlobalOptimizationPassPipeline(
.addPredicatedPass(transformOptions.stripAssertions,
IREE::Util::createStripDebugOpsPass)
.addPass(IREE::Util::createOptimizeIntArithmeticPass)
.addPass(createLinalgQuantizedConvToConvPass)
.addPass(createLinalgQuantizedConvToConvPass);

// Add Conv2D to Img2Col conversion after QuantizedConvToConvPass
if (preprocessingOptions
.preprocessingEnableConv2dToImg2colAfterQuantizedConv) {
FunctionLikeNest(mainPassManager)
.addPass(Preprocessing::createConvertConv2DToImg2ColPass)
.addPass(createCanonicalizerPass)
.addPass(createCSEPass);
}

FunctionLikeNest(mainPassManager)
.addPass(createLinalgQuantizedMatmulToMatmulPass)
.addPass(IREE::Flow::createCanonicalizePass)
.addPass(createRemoveZeroExtentTensorsPass)
Expand Down Expand Up @@ -298,7 +311,10 @@ void registerGlobalOptimizationPipeline() {
"Runs the IREE global optimization transformation pipeline",
[](OpPassManager &passManager,
const TransformOptions &transformOptions) {
buildGlobalOptimizationPassPipeline(passManager, transformOptions);
// Use default preprocessing options for pipeline registration
PreprocessingOptions defaultPreprocessingOptions;
buildGlobalOptimizationPassPipeline(passManager, transformOptions,
defaultPreprocessingOptions);
});
PassPipelineRegistration<TransformOptions>
globalOptimizationConstantHoistingPassPipeline(
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/iree/compiler/GlobalOptimization/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <functional>

#include "iree/compiler/Pipelines/Options.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
Expand Down Expand Up @@ -145,7 +146,8 @@ struct TransformOptions : public PassPipelineOptions<TransformOptions> {
/// We may ultimately break this out separately so creating a syntactic
/// distinction to keep that as an option.
void buildGlobalOptimizationPassPipeline(
OpPassManager &mainPassManager, const TransformOptions &transformOptions);
OpPassManager &mainPassManager, const TransformOptions &transformOptions,
const PreprocessingOptions &preprocessingOptions);

//------------------------------------------------------------------------------
// Wrappers that not use tablegen options.
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/iree/compiler/Pipelines/Options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,17 @@ void PreprocessingOptions::bindOptions(OptionsBinder &binder) {
"iree-preprocessing-pdl-spec-filename", preprocessingPDLSpecFilename,
llvm::cl::desc("File name of a PDL spec to use for preprocessing."),
llvm::cl::cat(category));
binder.opt<bool>(
"iree-preprocessing-enable-conv2d-to-img2col-after-quantized-conv",
preprocessingEnableConv2dToImg2colAfterQuantizedConv,
llvm::cl::desc("Enable Conv2d to Img2col + matmul after "
"QuantizedConvToConvPass in global optimization."),
llvm::cl::cat(category));
binder.opt<bool>(
"iree-preprocessing-enable-conv2d-to-img2col",
preprocessingEnableConv2dToImg2col,
llvm::cl::desc("Enable Conv2d to Img2col + matmul in preprocessing."),
llvm::cl::cat(category));
}

void ParameterOptions::bindOptions(OptionsBinder &binder) {
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Pipelines/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ struct PreprocessingOptions {
std::string preprocessingTransformSpecFilename;
std::string preprocessingPDLSpecFilename;

bool preprocessingEnableConv2dToImg2col = false;
bool preprocessingEnableConv2dToImg2colAfterQuantizedConv = false;

void bindOptions(OptionsBinder &binder);
using FromFlags = OptionsFromFlags<PreprocessingOptions>;
};
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Pipelines/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ void buildIREEPrecompileTransformPassPipeline(
hooks.beforePhase(IREEVMPipelinePhase::GlobalOptimization, passManager);
}
GlobalOptimization::buildGlobalOptimizationPassPipeline(
passManager, globalTransformOptions);
passManager, globalTransformOptions, preprocessingOptions);
if (hooks.afterPhase) {
hooks.afterPhase(IREEVMPipelinePhase::GlobalOptimization, passManager);
}
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Preprocessing/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ static void buildPreprocessingPassPipelineFromCommandLine(
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());
}
if (preprocessingOptions.preprocessingEnableConv2dToImg2col) {
passManager.addPass(createConvertConv2DToImg2ColPass());
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());
}
}

void buildPreprocessingPassPipeline(
Expand Down