Skip to content

Commit 640fd61

Browse files
authored
Merge pull request #473 from Xilinx/bump_to_d2330df5
[AutoBump] Merge with fixes of d2330df (Oct 21) (89)
2 parents 227c5ee + 678a42e commit 640fd61

File tree

12 files changed

+858
-337
lines changed

12 files changed

+858
-337
lines changed

include/torch-mlir/Dialect/Torch/Transforms/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ void createTorchDynamoExportToTorchBackendPipeline(
8484
void createTorchFunctionToTorchBackendPipeline(
8585
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
8686

87+
/// Creates a pipeline that lowers the torch Onnx IR that is produced by
88+
/// Onnx import into the form expected by torch-verify-backend-contract.
89+
void createTorchOnnxToTorchBackendPipeline(
90+
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
91+
8792
/// Creates a pipeline that simplifies the computations in the program.
8893
/// This pass does not do any global program restructuring -- it works entirely
8994
/// within a single semantic model of a `builtin.module` with

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3715,6 +3715,12 @@ OpFoldResult AtenRemainderScalarOp::fold(FoldAdaptor adaptor) {
37153715
//===----------------------------------------------------------------------===//
37163716

37173717
// Fold `aten.add.int`.
//
// Identities handled without requiring both operands to be constant:
//   x + 0 -> x
//   0 + x -> x
// When both operands are known integer constants, defer to the shared
// binary-int folding helper to compute the sum.
OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
  auto lhsConst = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
  auto rhsConst = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
  // Adding zero is the identity; return the other operand unchanged.
  if (rhsConst && rhsConst.getValue().getSExtValue() == 0)
    return getA();
  if (lhsConst && lhsConst.getValue().getSExtValue() == 0)
    return getB();
  // Full constant folding when both operands are constant integers.
  return atenBinaryIntOperatorFoldHelper(
      adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; });
}
@@ -3724,6 +3730,9 @@ OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
37243730
//===----------------------------------------------------------------------===//
37253731

37263732
// Fold `aten.sub.int`.
//
// `x - x` always yields zero, even when `x` is not a compile-time constant,
// so that case is handled by comparing the SSA operands directly. Otherwise,
// fold only when both operands are known integer constants.
OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
  if (getA() == getB()) {
    // Result type matches the i64 signless convention used by other int folds.
    auto i64Ty = IntegerType::get(getContext(), 64, IntegerType::Signless);
    return IntegerAttr::get(i64Ty, 0);
  }
  auto subtract = [](int64_t a, int64_t b) { return a - b; };
  return atenBinaryIntOperatorFoldHelper(adaptor.getOperands(), subtract);
}
@@ -4590,7 +4599,8 @@ OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
45904599
if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
45914600
return intAttr.getType().isUnsignedInteger()
45924601
? getI64IntegerAttr(getContext(), intAttr.getUInt())
4593-
: getI64IntegerAttr(getContext(), intAttr.getSInt());
4602+
: getI64IntegerAttr(getContext(),
4603+
intAttr.getValue().getSExtValue());
45944604
}
45954605
if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
45964606
return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble());

lib/Dialect/Torch/Transforms/Passes.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
1111
#include "mlir/Pass/PassManager.h"
1212
#include "mlir/Transforms/Passes.h"
13+
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"
1314

1415
void mlir::torch::registerTorchPasses() {
1516
mlir::torch::registerPasses();
@@ -25,6 +26,10 @@ void mlir::torch::registerTorchPasses() {
2526
"torch-function-to-torch-backend-pipeline",
2627
"Pipeline lowering a Torch function to Torch backend form.",
2728
mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
29+
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
30+
"torch-onnx-to-torch-backend-pipeline",
31+
"Pipeline lowering Torch Onnx IR to Torch backend form.",
32+
mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline);
2833
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
2934
"torch-simplification-pipeline",
3035
"Pipeline simplifying computations in the program.",
@@ -86,6 +91,37 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
8691
options.backendLegalOps, options.extraLibrary));
8792
}
8893

94+
// Builds the pipeline that lowers Torch ONNX IR (as produced by the ONNX
// importer) into the form expected by torch-verify-backend-contract.
void mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline(
    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
  // Step 1: convert the Torch ONNX dialect ops into ordinary Torch ops.
  // Everything after this point massages the resulting IR until it
  // satisfies the backend contract.
  pm.addNestedPass<func::FuncOp>(onnx_c::createTorchOnnxToTorchPass());
  if (options.decompose) {
    pm.addNestedPass<func::FuncOp>(
        Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  }
  // TODO: Move the combination of two passes i.e., ScalarizeShapes and
  // TorchShapeRefinementPipeline out of here and create an onnx shape
  // refinement pipeline which runs iteratively over the IR.
  // Step 2: refine shapes, scalarize the tensor shape computations, then
  // refine again so the scalarized results propagate.
  createTorchShapeRefinementPipeline(pm, options);
  pm.addNestedPass<func::FuncOp>(Torch::createScalarizeShapesPass());
  createTorchShapeRefinementPipeline(pm, options);
  // Step 3: tighten the public function signatures and clean up.
  pm.addPass(Torch::createRefinePublicReturnPass());
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  // Step 4: decompose once more — scalarization and shape refinement above
  // may have introduced ops that themselves have decompositions.
  if (options.decompose) {
    pm.addNestedPass<func::FuncOp>(
        Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  }
}
124+
89125
// A simplification pipeline to establish the invariants of the backend
90126
// contract (see `satisfiedBackendContract` in `LowerToBackendContract`).
91127
//

0 commit comments

Comments
 (0)