1010#include " torch-mlir/Dialect/Torch/Transforms/Passes.h"
1111#include " mlir/Pass/PassManager.h"
1212#include " mlir/Transforms/Passes.h"
13+ #include " torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"
1314
1415void mlir::torch::registerTorchPasses () {
1516 mlir::torch::registerPasses ();
@@ -25,6 +26,10 @@ void mlir::torch::registerTorchPasses() {
2526 " torch-function-to-torch-backend-pipeline" ,
2627 " Pipeline lowering a Torch function to Torch backend form." ,
2728 mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
29+ mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
30+ " torch-onnx-to-torch-backend-pipeline" ,
31+ " Pipeline lowering Torch Onnx IR to Torch backend form." ,
32+ mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline);
2833 mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
2934 " torch-simplification-pipeline" ,
3035 " Pipeline simplifying computations in the program." ,
@@ -86,6 +91,37 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
8691 options.backendLegalOps , options.extraLibrary ));
8792}
8893
94+ void mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline (
95+ OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
96+ pm.addNestedPass <func::FuncOp>(onnx_c::createTorchOnnxToTorchPass ());
97+ // The above pass just converts the torch onnx IR to torch, hence the given
98+ // pipeline will make sure that the IR is transformed such that it satisfies
99+ // the backend contract.
100+ if (options.decompose ) {
101+ pm.addNestedPass <func::FuncOp>(
102+ Torch::createDecomposeComplexOpsPass (options.backendLegalOps ));
103+ pm.addNestedPass <func::FuncOp>(createCanonicalizerPass ());
104+ }
105+ // TODO: Move the combination of two passes i.e., ScalarizeShapes and
106+ // TorchShapeRefinementPipeline out of here and create an onnx shape
107+ // refinement pipeline which runs iteratively over the IR.
108+ createTorchShapeRefinementPipeline (pm, options);
109+ // This pass scalarizes the tensor shape computations.
110+ pm.addNestedPass <mlir::func::FuncOp>(
111+ mlir::torch::Torch::createScalarizeShapesPass ());
112+ createTorchShapeRefinementPipeline (pm, options);
113+ pm.addPass (Torch::createRefinePublicReturnPass ());
114+ pm.addNestedPass <func::FuncOp>(createCanonicalizerPass ());
115+ // The decompose pass is run again here since the scalarize shapes pass and
116+ // shape refinement pipeline might create some ops for which decomposition
117+ // exists.
118+ if (options.decompose ) {
119+ pm.addNestedPass <func::FuncOp>(
120+ Torch::createDecomposeComplexOpsPass (options.backendLegalOps ));
121+ pm.addNestedPass <func::FuncOp>(createCanonicalizerPass ());
122+ }
123+ }
124+
89125// A simplification pipeline to establish the invariants of the backend
90126// contract (see `satisfiedBackendContract` in `LowerToBackendContract`).
91127//
0 commit comments