Skip to content

Commit bfc106e

Browse files
authored
Triggering auto torch input conversion based on func arg/result types. (#21067)
Fixes #21066.
1 parent 454b68c commit bfc106e

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,20 @@
22

33
// Check that the auto input conversion pipeline uses this plugin.
44

5-
65
// CHECK-LABEL: util.func public @simple_add_onnx
76
// CHECK: arith.addi
87
func.func @simple_add_onnx(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
98
%0 = torch.operator "onnx.Add"(%arg0, %arg1) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
109
return %0 : !torch.vtensor<[],si64>
1110
}
11+
12+
// -----
13+
14+
// Tests that a function using torch types but not containing any ops is still
15+
// handled by the torch input pipeline.
16+
17+
// CHECK: util.func public @nop$async
18+
// CHECK: util.func public @nop(%{{.+}}: !hal.buffer_view) -> !hal.buffer_view
19+
func.func @nop(%arg0: !torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32> attributes {torch.assume_strict_symbolic_shapes} {
20+
return %arg0 : !torch.vtensor<[5],f32>
21+
}

compiler/plugins/input/Torch/PluginRegistration.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,18 @@ struct TorchSession
129129
return WalkResult::advance();
130130
});
131131

132+
auto isTorchType = [&](Type type) {
133+
return &type.getDialect() == torchDialect;
134+
};
132135
for (auto funcOp : module.getOps<func::FuncOp>()) {
136+
auto funcType = funcOp.getFunctionType();
137+
if (llvm::any_of(funcType.getInputs(), isTorchType) ||
138+
llvm::any_of(funcType.getResults(), isTorchType)) {
139+
hasTorch = true;
140+
}
133141
if (funcOp->getAttrOfType<mlir::IntegerAttr>(
134142
"torch.onnx_meta.opset_version")) {
135143
hasOnnx = true;
136-
break;
137144
}
138145
}
139146

0 commit comments

Comments
 (0)