
Commit 68c35d7

Pass some decompose-complex-ops options in torch-to-iree (iree-org#19076)
Some problematic decompositions in `torch-decompose-complex-ops` are now disabled by default in the `torch-to-iree` pipeline. This change also adds an option to turn off all decompositions.

Signed-off-by: zjgarvey <[email protected]>
1 parent 4b15edd commit 68c35d7

3 files changed: +28 −14 lines

compiler/plugins/input/Torch/InputConversion/Passes.cpp

Lines changed: 3 additions & 4 deletions
@@ -34,8 +34,6 @@ void createTorchToIREEPipeline(
   // backends. We do this first as it tends to involve pattern-matching against
   // constants, (e.g. dimensions which must be constant in a ranked programming
   // model) and those constants get somewhat obscured by TorchToArith.
-  llvm::ArrayRef<std::string> emptyArrayRef;
-
   // Dynamic shape bindings add a lot of structure to the IR which we prefer to
   // leverage and eliminate prior to any other activity, so do this first.
   pm.addNestedPass<func::FuncOp>(createBindSymbolicShapesPass());
@@ -51,8 +49,9 @@
       torch::Torch::createReduceOpVariantsPass(llvm::StringRef()));
   pm.addNestedPass<func::FuncOp>(
       mlir::torch::TorchConversion::createConvertCustomQuantOpPass());
-  pm.addNestedPass<func::FuncOp>(
-      torch::Torch::createDecomposeComplexOpsPass(emptyArrayRef));
+  if (options.decompose)
+    pm.addNestedPass<func::FuncOp>(
+        torch::Torch::createDecomposeComplexOpsPass(BackendLegalOps::get()));
   pm.addNestedPass<func::FuncOp>(torch::Torch::createFuseQuantizedOpsPass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(torch::Torch::createScalarizeShapesPass());
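With the new option, decomposition can also be skipped when assembling the pipeline programmatically. A minimal sketch (not part of this commit), assuming the plugin's Passes.h include path and mirroring how the pipeline is invoked from PluginRegistration.cpp:

#include "compiler/plugins/input/Torch/InputConversion/Passes.h" // include path assumed
#include "mlir/Pass/PassManager.h"

// Sketch: build the torch-to-iree lowering pipeline with all complex-op
// decompositions disabled via the new `decompose` pipeline option.
void buildTorchToIREEWithoutDecomposition(mlir::OpPassManager &pm) {
  mlir::iree_compiler::TorchInput::TorchToIREELoweringPipelineOptions options;
  options.decompose = false; // skip torch-decompose-complex-ops entirely
  mlir::iree_compiler::TorchInput::createTorchToIREEPipeline(pm, options);
}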

compiler/plugins/input/Torch/InputConversion/Passes.h

Lines changed: 17 additions & 0 deletions
@@ -13,11 +13,28 @@
 
 namespace mlir::iree_compiler::TorchInput {
 
+// The following is a hard-coded list of ops we don't want to decompose in the
+// torch dialect, since they have disadvantageous decompositions for the
+// torch-to-linalg path. For example, decomposing `aten.flatten.using_ints` to
+// `aten.view` simply destroys useful information about what kind of reshape is
+// being performed, and hinders our ability, in some cases, to lower this to a
+// collapse instead of a generic reshape.
+struct BackendLegalOps {
+  static const llvm::SmallVector<std::string> get() {
+    return {"aten.flatten.using_ints", "aten.unflatten.int",
+            "aten.adaptive_avg_pool1d", "aten.adaptive_avg_pool2d",
+            "aten.adaptive_max_pool1d"};
+  };
+};
+
 struct TorchToIREELoweringPipelineOptions
     : public PassPipelineOptions<TorchToIREELoweringPipelineOptions> {
   Option<bool> strictSymbolicShapes{
       *this, "strict-symbolic-shapes",
       llvm::cl::desc("Use strict symbolic shapes."), llvm::cl::init(true)};
+  Option<bool> decompose{*this, "decompose",
+                         llvm::cl::desc("Decompose complex torch operations."),
+                         llvm::cl::init(true)};
 };
 
 // Creates a pipeline that lowers from the torch backend contract to IREE.
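For reference, the shared list can be handed straight to torch-mlir's decomposition pass so these ops survive and keep their direct torch-to-linalg lowerings. A minimal sketch mirroring the call site in Passes.cpp; the include paths are illustrative:

#include "compiler/plugins/input/Torch/InputConversion/Passes.h" // include path assumed
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

// Sketch: run DecomposeComplexOps while treating the ops from
// BackendLegalOps::get() as backend-legal (i.e. left undecomposed).
void addGuardedDecomposition(mlir::OpPassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::torch::Torch::createDecomposeComplexOpsPass(
          mlir::iree_compiler::TorchInput::BackendLegalOps::get()));
}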

compiler/plugins/input/Torch/PluginRegistration.cpp

Lines changed: 8 additions & 10 deletions
@@ -23,12 +23,16 @@ namespace {
 
 struct TorchOptions {
   bool strictSymbolicShapes = true;
+  bool decompose = true;
   void bindOptions(OptionsBinder &binder) {
     static llvm::cl::OptionCategory category("Torch Input");
     binder.opt<bool>(
         "iree-torch-use-strict-symbolic-shapes", strictSymbolicShapes,
         llvm::cl::cat(category),
         llvm::cl::desc("Forces dynamic shapes to be treated as strict"));
+    binder.opt<bool>("iree-torch-decompose-complex-ops", decompose,
+                     llvm::cl::cat(category),
+                     llvm::cl::desc("Decompose complex torch operations."));
   }
 };
 
@@ -58,23 +62,17 @@ struct TorchSession
     if (typeMnemonic == "onnx") {
       // ONNX input is a pre-processing step to torch.
       mlir::torch::Torch::TorchLoweringPipelineOptions torchOnnxPipelineOptions;
-      // The `aten.flatten.using_ints` and `aten.unflatten.int` are added to the
-      // list of backend legal ops so that they are not decomposed into the
-      // `aten.view` op during the run of `DecomposeComplexOps` pass. The issue
-      // with this is that the `aten.view` op eventually lowers to
-      // `tensor.reshape` op while there exists a direct torch->linalg lowering
-      // for both the flatten/unflatten ops which lowers to
-      // `tensor.collapse_shape/expand_shape` op, and this is a more preferred
-      // path for the downstream pipeline.
-      torchOnnxPipelineOptions.backendLegalOps = {"aten.flatten.using_ints",
-                                                  "aten.unflatten.int"};
+      torchOnnxPipelineOptions.decompose = options.decompose;
+      torchOnnxPipelineOptions.backendLegalOps =
+          TorchInput::BackendLegalOps::get();
       mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline(
           passManager, torchOnnxPipelineOptions);
     }
 
     if (typeMnemonic == "torch" || typeMnemonic == "onnx") {
       TorchInput::TorchToIREELoweringPipelineOptions torchOptions;
       torchOptions.strictSymbolicShapes = options.strictSymbolicShapes;
+      torchOptions.decompose = options.decompose;
       TorchInput::createTorchToIREEPipeline(passManager, torchOptions);
       return true;
     }
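Usage note (not part of the diff): with the plugin option registered above, decomposition should be controllable from the compiler driver via `--iree-torch-decompose-complex-ops=false`; the flag name comes from the OptionsBinder registration, while how it surfaces on a given tool is an assumption. At its default of `true`, decomposition still runs, but the ops returned by `BackendLegalOps::get()` are kept backend-legal in both the ONNX pre-processing pipeline and the `torch-to-iree` pipeline.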
