Skip to content

Commit fa4794d

Browse files
[MLIR][TORCH] Add torch-onnx-to-torch-backend pipeline (#3801)
This commit adds the `torch-onnx-to-torch-backend-pipeline`, which converts Torch Onnx IR to Torch Backend IR. It also moves the `ScalarizeShapes` pass from the `torch-backend-to-linalg-on-tensors-backend-pipeline` to the new `torch-onnx-to-torch-backend-pipeline`, since the primary goal of that pass is to scalarize the shapes in IR coming from Onnx models.
1 parent d2330df commit fa4794d

File tree

5 files changed

+117
-18
lines changed

5 files changed

+117
-18
lines changed

include/torch-mlir/Dialect/Torch/Transforms/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ void createTorchDynamoExportToTorchBackendPipeline(
8484
void createTorchFunctionToTorchBackendPipeline(
8585
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
8686

87+
/// Creates a pipeline that lowers the torch Onnx IR that is produced by
88+
/// Onnx import into the form expected by torch-verify-backend-contract.
89+
void createTorchOnnxToTorchBackendPipeline(
90+
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
91+
8792
/// Creates a pipeline that simplifies the computations in the program.
8893
/// This pass does not do any global program restructuring -- it works entirely
8994
/// within a single semantic model of a `builtin.module` with

lib/Dialect/Torch/Transforms/Passes.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
1111
#include "mlir/Pass/PassManager.h"
1212
#include "mlir/Transforms/Passes.h"
13+
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"
1314

1415
void mlir::torch::registerTorchPasses() {
1516
mlir::torch::registerPasses();
@@ -25,6 +26,10 @@ void mlir::torch::registerTorchPasses() {
2526
"torch-function-to-torch-backend-pipeline",
2627
"Pipeline lowering a Torch function to Torch backend form.",
2728
mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
29+
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
30+
"torch-onnx-to-torch-backend-pipeline",
31+
"Pipeline lowering Torch Onnx IR to Torch backend form.",
32+
mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline);
2833
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
2934
"torch-simplification-pipeline",
3035
"Pipeline simplifying computations in the program.",
@@ -86,6 +91,37 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
8691
options.backendLegalOps, options.extraLibrary));
8792
}
8893

94+
void mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline(
95+
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
96+
pm.addNestedPass<func::FuncOp>(onnx_c::createTorchOnnxToTorchPass());
97+
// The above pass just converts the torch onnx IR to torch, hence the given
98+
// pipeline will make sure that the IR is transformed such that it satisfies
99+
// the backend contract.
100+
if (options.decompose) {
101+
pm.addNestedPass<func::FuncOp>(
102+
Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
103+
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
104+
}
105+
// TODO: Move the combination of two passes i.e., ScalarizeShapes and
106+
// TorchShapeRefinementPipeline out of here and create an onnx shape
107+
// refinement pipeline which runs iteratively over the IR.
108+
createTorchShapeRefinementPipeline(pm, options);
109+
// This pass scalarizes the tensor shape computations.
110+
pm.addNestedPass<mlir::func::FuncOp>(
111+
mlir::torch::Torch::createScalarizeShapesPass());
112+
createTorchShapeRefinementPipeline(pm, options);
113+
pm.addPass(Torch::createRefinePublicReturnPass());
114+
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
115+
// The decompose pass is run again here since the scalarize shapes pass and
116+
// shape refinement pipeline might create some ops for which decomposition
117+
// exists.
118+
if (options.decompose) {
119+
pm.addNestedPass<func::FuncOp>(
120+
Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
121+
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
122+
}
123+
}
124+
89125
// A simplification pipeline to establish the invariants of the backend
90126
// contract (see `satisfiedBackendContract` in `LowerToBackendContract`).
91127
//

lib/Dialect/TorchConversion/Transforms/Passes.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
7070

7171
// We want to fuse quantized operations together before lowering to linalg.
7272
pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
73-
pm.addNestedPass<func::FuncOp>(Torch::createScalarizeShapesPass());
7473

7574
// Lower to linalg + guards which is the input to codegen backends.
7675
// We do this first as it tends to involve pattern-matching against constants,

projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -100,33 +100,25 @@ def _module_lowering(
100100
print("ONNX RAW IR")
101101
print(torch_mod)
102102

103-
# Lower from ONNX to Torch
104-
run_pipeline_with_repro_report(
105-
torch_mod,
106-
# The importer may produce additional MLIR functions corresponding to
107-
# ONNX operators that are functions. In some cases they need to be
108-
# inlined to avoid the backend choking on them.
109-
f"builtin.module(inline, func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
110-
"Lowering Onnx backend contract to Linalg-on-Tensors backend contract",
111-
)
112-
113-
if verbose:
114-
print("\n====================")
115-
print("TorchFX IR")
116-
print(torch_mod)
117-
118103
backend_legal_ops = [
119104
"aten.flatten.using_ints",
120105
"aten.adaptive_avg_pool1d",
121106
"aten.unflatten.int",
122107
]
123108
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
109+
110+
# Lower from ONNX to Torch
124111
run_pipeline_with_repro_report(
125112
torch_mod,
126-
f"builtin.module(torch-lower-to-backend-contract{option_string})",
127-
"Lowering TorchFX IR -> Torch Backend IR",
113+
f"builtin.module(torch-onnx-to-torch-backend-pipeline{option_string})",
114+
"Lowering Onnx Raw IR -> Torch Backend IR",
128115
)
129116

117+
if verbose:
118+
print("\n====================")
119+
print("Torch IR")
120+
print(torch_mod)
121+
130122
return lower_mlir_module(verbose, output_type, torch_mod)
131123

132124

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-onnx-to-torch-backend-pipeline{backend-legal-ops=aten.flatten.using_ints,aten.unflatten.int})' -split-input-file %s | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @test_reshape_negative_dim_decompose
4+
func.func @test_reshape_negative_dim_decompose(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
5+
// CHECK: %[[INT2:.+]] = torch.constant.int 2
6+
// CHECK: %[[INT6:.+]] = torch.constant.int 6
7+
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT6]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
8+
// CHECK: torch.aten.view %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,6,2],f32>
9+
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32>
10+
return %0 : !torch.vtensor<[2,6,2],f32>
11+
}
12+
13+
// -----
14+
15+
// CHECK-LABEL: func.func @test_triu_decompose
16+
func.func @test_triu_decompose(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
17+
// CHECK: %[[ZERO_TENSOR:.+]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
18+
// CHECK: %[[INT0:.+]] = torch.constant.int 0
19+
// CHECK: %[[INT1:.+]] = torch.constant.int 1
20+
// CHECK: %[[NONE:.+]] = torch.constant.none
21+
// CHECK: %[[INT4:.+]] = torch.constant.int 4
22+
// CHECK: %[[INT5:.+]] = torch.constant.int 5
23+
// CHECK: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[INT0]], %[[INT4]], %[[INT1]], %[[INT4]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4],si64>
24+
// CHECK: %[[ARANGE_0:.+]] = torch.aten.arange.start_step %[[INT0]], %[[INT5]], %[[INT1]], %[[INT4]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[5],si64>
25+
// CHECK: %[[UNSQUEEZE:.+]] = torch.aten.unsqueeze %[[ARANGE]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
26+
// CHECK: %[[UNSQUEEZE_0:.+]] = torch.aten.unsqueeze %[[ARANGE_0]], %[[INT0]] : !torch.vtensor<[5],si64>, !torch.int -> !torch.vtensor<[1,5],si64>
27+
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[UNSQUEEZE]], %[[INT0]], %[[INT1]] : !torch.vtensor<[4,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1],si64>
28+
// CHECK: %[[COND:.+]] = torch.aten.ge.Tensor %[[UNSQUEEZE_0]], %[[ADD]] : !torch.vtensor<[1,5],si64>, !torch.vtensor<[4,1],si64> -> !torch.vtensor<[4,5],i1>
29+
// CHECK: %[[RESULT:.+]] = torch.aten.where.self %[[COND]], %arg0, %[[ZERO_TENSOR]] : !torch.vtensor<[4,5],i1>, !torch.vtensor<[4,5],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[4,5],si64>
30+
%0 = torch.operator "onnx.Trilu"(%arg0) : (!torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64>
31+
return %0 : !torch.vtensor<[4,5],si64>
32+
}
33+
34+
// -----
35+
36+
module {
37+
// CHECK-LABEL: func.func @test_scalarize
38+
func.func @test_scalarize(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.11.0"} {
39+
// CHECK: %[[INT2:.+]] = torch.constant.int 2
40+
// CHECK: %[[INT3:.+]] = torch.constant.int 3
41+
// CHECK: %[[ADD:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT2]], %[[INT3]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
42+
%0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64>
43+
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__21> : tensor<si64>} : () -> !torch.vtensor<[],si64>
44+
%2 = torch.operator "onnx.Gather"(%0, %1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
45+
%3 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64>
46+
%4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__22> : tensor<si64>} : () -> !torch.vtensor<[],si64>
47+
%5 = torch.operator "onnx.Gather"(%3, %4) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
48+
%6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
49+
%7 = torch.operator "onnx.Unsqueeze"(%2, %6) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64>
50+
%8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
51+
%9 = torch.operator "onnx.Unsqueeze"(%5, %8) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64>
52+
%10 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Concat_3209> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
53+
%11 = torch.operator "onnx.Concat"(%7, %9, %10) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3],si64>
54+
%12 = torch.operator "onnx.Reshape"(%arg0, %11) : (!torch.vtensor<[?,?,16,64],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32>
55+
return %12 : !torch.vtensor<[?,?,?],f32>
56+
}
57+
}
58+
59+
{-#
60+
dialect_resources: {
61+
builtin: {
62+
__21: "0x080000000000000000000000",
63+
__22: "0x080000000100000000000000",
64+
_onnx__Concat_3209: "0x080000000004000000000000"
65+
}
66+
}
67+
#-}

0 commit comments

Comments
 (0)