Commit aced82d

Merge pull request #433 from Xilinx/jrickert.use_model_type_for_custom_op
Add new option useOnnxModelTypesForCustomOps to allow using types fro…
2 parents d61af31 + 4c5ced6 commit aced82d

10 files changed: +402 −31 lines
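The new flag composes with the existing --useOnnxModelTypes option. A typical invocation, mirroring the RUN lines in the tests below (model.onnx is a placeholder path), keeps model-provided types for custom ops only:

onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --useOnnxModelTypesForCustomOps=true --printIR model.onnx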

src/Builder/FrontendDialectTransformer.cpp

Lines changed: 33 additions & 27 deletions
@@ -414,15 +414,13 @@ class FrontendGenImpl {
 
   std::optional<ErrorOr<Type>> ConvertOnnxType(
       const std::string &onnx_name, std::string &errorMessage) {
-    if (options_.useOnnxModelTypes) {
-      if (const onnx::TypeProto *onnxTypePtr =
-              onnx_type_map.GetByOnnxName(onnx_name)) {
-        ErrorOr<Type> importedType = ImportType(*onnxTypePtr, errorMessage);
-        if (auto ec = importedType.getError()) {
-          return ec;
-        }
-        return *importedType;
+    if (const onnx::TypeProto *onnxTypePtr =
+            onnx_type_map.GetByOnnxName(onnx_name)) {
+      ErrorOr<Type> importedType = ImportType(*onnxTypePtr, errorMessage);
+      if (auto ec = importedType.getError()) {
+        return ec;
       }
+      return *importedType;
     }
     return std::nullopt;
   }
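ConvertOnnxType no longer checks useOnnxModelTypes itself; callers now gate the lookup (see the buildOutputAndOperation hunk below). For readers unfamiliar with the ErrorOr pattern used here, a minimal self-contained sketch with llvm::ErrorOr (the project may use its own ErrorOr alias; this is illustrative only):

#include "llvm/Support/ErrorOr.h"
#include <system_error>

// Returns either a value or an error code, like ImportType above.
llvm::ErrorOr<int> parsePositive(int x) {
  if (x < 0)
    return std::make_error_code(std::errc::invalid_argument);
  return x; // implicit success wrapping
}

void demo() {
  llvm::ErrorOr<int> r = parsePositive(-1);
  if (std::error_code ec = r.getError()) {
    // Error path: propagate ec; *r must not be read here.
    (void)ec;
  }
}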
@@ -827,7 +825,8 @@ class FrontendGenImpl {
       const onnx::NodeProto &node, std::vector<Value> inputs,
       int expectedNumOperands, int expectedNumResults,
       const std::vector<NamedAttribute> &attributes, std::string &errorMessage,
-      std::vector<Type> givenOutputTypes = std::vector<Type>()) {
+      std::vector<Type> givenOutputTypes = std::vector<Type>(),
+      bool isCustomOp = false) {
     bool variadicIn = expectedNumOperands == -1;
     bool variadicOut = expectedNumResults == -1;
 
@@ -854,20 +853,24 @@
       if (node.output()[i].empty()) {
         outputTypes.emplace_back(builder_.getNoneType());
       } else {
-        auto onnxModelType = ConvertOnnxType(node.output(i), errorMessage);
-        if (onnxModelType) {
-          const auto ec = onnxModelType->getError();
-          if (!ec) {
-            outputTypes.emplace_back(*onnxModelType.value());
-            continue;
-          }
-          if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
-            errorMessage += "Failed to get type for '" + node.output(i) + "\n";
-            return ec;
+        if (options_.useOnnxModelTypes ||
+            (isCustomOp && options_.useOnnxModelTypesForCustomOps)) {
+          auto onnxModelType = ConvertOnnxType(node.output(i), errorMessage);
+          if (onnxModelType) {
+            const auto ec = onnxModelType->getError();
+            if (!ec) {
+              outputTypes.emplace_back(*onnxModelType.value());
+              continue;
+            }
+            if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
+              errorMessage +=
+                  "Failed to get type for '" + node.output(i) + "\n";
+              return ec;
+            }
+            llvm::errs() << "Warning: "
+                         << "Failed to get type type for '" << node.output(i)
+                         << "', falling back to onnx-mlir based mapping.\n";
           }
-          llvm::errs() << "Warning: "
-                       << "Failed to get type type for '" << node.output(i)
-                       << "', falling back to onnx-mlir based mapping.\n";
         }
         unsigned int j = i;
         // Variadic output is a single ODS result.
@@ -931,6 +934,8 @@
         }
       }
     }
+    // Note: ResultTypeInferenceOpInterface only infers the type of the result,
+    // not the shape
     if (auto opWithTypeInference =
            mlir::dyn_cast<ResultTypeInferenceOpInterface>(op.getOperation())) {
      auto outTypes = opWithTypeInference.resultTypeInference();
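As the added comment says, ResultTypeInferenceOpInterface supplies element types but not shapes. A sketch of what an implementation typically looks like, for a hypothetical op that is not part of this patch (the resultTypeInference() name and std::vector<mlir::Type> return match the call site above; treat the rest as an assumption):

// Hypothetical op: result element type follows the first operand,
// shape is left unranked, consistent with the comment above.
std::vector<mlir::Type> MyOp::resultTypeInference() {
  auto inTy = mlir::cast<mlir::ShapedType>(getOperand(0).getType());
  return {mlir::UnrankedTensorType::get(inTy.getElementType())};
}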
@@ -1456,7 +1461,8 @@
           onnx::OpSchemaRegistry::Instance(),
           /*options=*/{}, in_model_functions_);
     } catch (const std::exception &e) {
-      llvm::errs() << "Warning: Caught exception running onnx shape inference: "
+      llvm::errs() << "Warning: Caught exception running onnx shape inference "
+                      "to populate graph.value_info: "
                   << e.what() << "\n";
     }
 
@@ -1584,8 +1590,8 @@
 
     // ToFix: The type inference may go wrong if the element type of the output
     // of CustomOp is not the same as the first input.
-    return buildOutputAndOperation<ONNXCustomOp>(
-        node, inputs, nIn, nOut, attributes, errorMessage, givenOutputTypes);
+    return buildOutputAndOperation<ONNXCustomOp>(node, inputs, nIn, nOut,
+        attributes, errorMessage, givenOutputTypes, /*isCustomOp*/ true);
   }
 
   [[nodiscard]] std::error_code ImportNode(
@@ -1806,7 +1812,7 @@
       originVersion < CURRENT_ONNX_OPSET) {
     onnx::ModelProto convertModel =
         onnx::version_conversion::ConvertVersion(model, CURRENT_ONNX_OPSET);
-    if (options.useOnnxModelTypes) {
+    if (options.runOnnxShapeInference) {
       try {
         onnx::shape_inference::InferShapes(convertModel);
       } catch (const std::exception &e) {
@@ -1818,7 +1824,7 @@
     return ImportFrontendModel(
         convertModel, context, module, errorMessage, options);
   } else {
-    if (options.useOnnxModelTypes) {
+    if (options.runOnnxShapeInference) {
      try {
        onnx::shape_inference::InferShapes(model);
      } catch (const std::exception &e) {

src/Builder/FrontendDialectTransformer.hpp

Lines changed: 3 additions & 1 deletion
@@ -43,7 +43,9 @@ struct ImportOptions {
   bool verboseOutput = false;
   // Use types/shapes in the input-model for translation (for intermediate
   // variables)
-  bool useOnnxModelTypes = false;
+  bool useOnnxModelTypes = true;
+  bool runOnnxShapeInference = true;
+  bool useOnnxModelTypesForCustomOps = true;
   bool invokeOnnxVersionConverter = false;
   bool allowSorting = true;
   bool useOutputNameAsLocation = false;
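Note the default flip: useOnnxModelTypes now defaults to true in ImportOptions, matching the CLI default (llvm::cl::init(true) in the context lines below). A minimal sketch of how an embedder might combine the new fields (field names from this diff; the onnx_mlir namespace is assumed):

onnx_mlir::ImportOptions opts;
opts.useOnnxModelTypes = false;            // ignore model-provided types...
opts.useOnnxModelTypesForCustomOps = true; // ...except for custom ops
opts.runOnnxShapeInference = true;         // still populate graph.value_info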

src/Compiler/CompilerOptions.cpp

Lines changed: 16 additions & 0 deletions
@@ -64,6 +64,8 @@ bool preserveBitcode; // onnx-mlir only
 bool preserveLLVMIR; // onnx-mlir only
 bool preserveMLIR; // onnx-mlir only
 bool useOnnxModelTypes; // onnx-mlir only
+bool runOnnxShapeInference; // onnx-mlir only
+bool useOnnxModelTypesForCustomOps; // onnx-mlir only
 int repeatOnnxTransform; // onnx-mlir only
 std::string shapeInformation; // onnx-mlir only
 std::string dimParams; // onnx-mlir only
@@ -389,6 +391,20 @@ static llvm::cl::opt<bool, true> useOnnxModelTypesOpt("useOnnxModelTypes",
     llvm::cl::location(useOnnxModelTypes), llvm::cl::init(true),
     llvm::cl::cat(OnnxMlirOptions));
 
+static llvm::cl::opt<bool, true> runOnnxShapeInferenceOpt(
+    "runOnnxShapeInference",
+    llvm::cl::desc("Run ONNX shape inference when importing a model. This is "
+                   "independent of the shape inference in ONNX-MLIR"),
+    llvm::cl::location(runOnnxShapeInference), llvm::cl::init(true),
+    llvm::cl::cat(OnnxMlirOptions));
+
+static llvm::cl::opt<bool, true> useOnnxModelTypesForCustomOpsOpt(
+    "useOnnxModelTypesForCustomOps",
+    llvm::cl::desc("Use types and shapes from ONNX model for custom ops, even "
+                   "if `useOnnxModelTypes` is disabled."),
+    llvm::cl::location(useOnnxModelTypesForCustomOps), llvm::cl::init(true),
+    llvm::cl::cat(OnnxMlirOptions));
+
 static llvm::cl::opt<bool, true> useOutputNameAsLocationOpt(
     "useOutputNameAsLocation", llvm::cl::desc("Use output name as location."),
     llvm::cl::location(useOutputNameAsLocation), llvm::cl::init(false),
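Both new options follow LLVM's externally stored cl::opt pattern already used by useOnnxModelTypes: the bool lives in a global (declared extern in CompilerOptions.hpp below) and the option object writes through llvm::cl::location. A self-contained sketch of the same mechanism, as a generic example rather than project code:

#include "llvm/Support/CommandLine.h"

bool myFlag; // storage outside the option object, readable from other TUs
static llvm::cl::opt<bool, /*ExternalStorage=*/true> myFlagOpt("my-flag",
    llvm::cl::desc("Example of an externally stored boolean option."),
    llvm::cl::location(myFlag), llvm::cl::init(true));

int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv);
  return myFlag ? 0 : 1; // reads the external storage, not the opt object
}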

src/Compiler/CompilerOptions.hpp

Lines changed: 2 additions & 0 deletions
@@ -99,6 +99,8 @@ extern bool preserveLLVMIR; // onnx-mlir only
 extern bool preserveMLIR; // onnx-mlir only
 extern bool doNotEmitFullMLIRCode; // onnx-mlir only
 extern bool useOnnxModelTypes; // onnx-mlir only
+extern bool runOnnxShapeInference; // onnx-mlir only
+extern bool useOnnxModelTypesForCustomOps; // onnx-mlir only
 extern int repeatOnnxTransform; // onnx-mlir only
 extern std::string shapeInformation; // onnx-mlir only
 extern std::string dimParams; // onnx-mlir only

src/Compiler/CompilerUtils.cpp

Lines changed: 2 additions & 0 deletions
@@ -721,6 +721,8 @@ std::string dirName(StringRef inputFilename) {
   ImportOptions options;
   options.verboseOutput = VerboseOutput;
   options.useOnnxModelTypes = useOnnxModelTypes;
+  options.runOnnxShapeInference = runOnnxShapeInference;
+  options.useOnnxModelTypesForCustomOps = useOnnxModelTypesForCustomOps;
   options.useOutputNameAsLocation = useOutputNameAsLocation;
   options.allowMissingOutputTypes = allowMissingOutputTypes;
   options.invokeOnnxVersionConverter = invokeOnnxVersionConverter;

test/mlir/onnx/parse/add_missing_output_types.json

Lines changed: 3 additions & 3 deletions
@@ -1,21 +1,21 @@
 // Json generated with utils/testing/add_missing_output_types.py
 
-// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --allowMissingOutputTypes=false --printIR %s 2> failed.log; cat failed.log | FileCheck --check-prefix=FAILURE %s
+// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --useOnnxModelTypesForCustomOps=false --allowMissingOutputTypes=false --printIR %s 2> failed.log; cat failed.log | FileCheck --check-prefix=FAILURE %s
 
 // FAILURE: Could not successfully parse ONNX file
 // FAILURE: ONNX type with id: 0 is not a valid type
 // FAILURE: Failed to import output type for
 // FAILURE: Failed to import main graph, could not get its function type
 
-// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --allowMissingOutputTypes=true --printIR %s | FileCheck --check-prefix=INFERRED %s
+// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --useOnnxModelTypesForCustomOps=false --allowMissingOutputTypes=true --printIR %s | FileCheck --check-prefix=INFERRED %s
 // INFERRED-LABEL: func.func @main_graph
 // INFERRED-SAME: ([[PARAM_0_:%.+]]: tensor<3x1xf32> {onnx.name = "input_a"}, [[PARAM_1_:%.+]]: tensor<1x3xf32> {onnx.name = "input_b"}) -> (tensor<*xf32> {onnx.name = "output_c"}, tensor<3x3xf32> {onnx.name = "output_d"}) {
 // INFERRED-DAG: [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]], [[PARAM_1_]]) {domain_name = "test", function_name = "test.Add", onnx_node_name = "add_node_custom"} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<*xf32>
 // INFERRED-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) {onnx_node_name = "add_node"} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
 // INFERRED: return [[VAR_0_]], [[VAR_1_]] : tensor<*xf32>, tensor<3x3xf32>
 // INFERRED: }
 
-// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=true --allowMissingOutputTypes=true --printIR %s | FileCheck --check-prefix=MODEL-TYPE %s
+// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=true --useOnnxModelTypesForCustomOps=true --allowMissingOutputTypes=true --printIR %s | FileCheck --check-prefix=MODEL-TYPE %s
 // MODEL-TYPE-LABEL: func.func @main_graph
 // MODEL-TYPE-SAME: ([[PARAM_0_:%.+]]: tensor<3x1xf32> {onnx.name = "input_a"}, [[PARAM_1_:%.+]]: tensor<1x3xf32> {onnx.name = "input_b"}) -> (tensor<*xf32> {onnx.name = "output_c"}, tensor<3x3xf32> {onnx.name = "output_d"}) {
 // MODEL-TYPE-DAG: [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]], [[PARAM_1_]]) {domain_name = "test", function_name = "test.Add", onnx_node_name = "add_node_custom"} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<*xf32>
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --useOnnxModelTypesForCustomOps=true --printIR %s | FileCheck %s
+// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --useOnnxModelTypesForCustomOps=false --printIR %s | FileCheck %s --check-prefix=NO-TYPES-FROM-MODEL
+
+// Json generated with utils/testing/custom_shape_from_model.py
+
+// CHECK-LABEL: func.func @main_graph
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x5xf32> {onnx.name = "X"}) -> (tensor<5x6xf32> {onnx.name = "Z"}) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]]) {domain_name = "test", function_name = "MyCustomOp", onnx_node_name = "onnx.Custom_0"} : (tensor<4x5xf32>) -> tensor<5x6xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Relu"([[VAR_0_]]) {onnx_node_name = "onnx.Relu_1"} : (tensor<5x6xf32>) -> tensor<5x6xf32>
+// CHECK: return [[VAR_1_]] : tensor<5x6xf32>
+// CHECK: }
+// NO-TYPES-FROM-MODEL-LABEL: func.func @main_graph
+// NO-TYPES-FROM-MODEL-SAME: ([[PARAM_0_:%.+]]: tensor<4x5xf32> {onnx.name = "X"}) -> (tensor<5x6xf32> {onnx.name = "Z"}) {
+// NO-TYPES-FROM-MODEL: [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]]) {domain_name = "test", function_name = "MyCustomOp", onnx_node_name = "onnx.Custom_0"} : (tensor<4x5xf32>) -> tensor<*xf32>
+// NO-TYPES-FROM-MODEL: [[VAR_1_:%.+]] = "onnx.Relu"([[VAR_0_]]) {onnx_node_name = "onnx.Relu_1"} : (tensor<*xf32>) -> tensor<5x6xf32>
+// NO-TYPES-FROM-MODEL: return [[VAR_1_]] : tensor<5x6xf32>
+// NO-TYPES-FROM-MODEL: }
+{
+  "irVersion": "10",
+  "producerName": "onnx-mlir",
+  "graph": {
+    "node": [
+      {
+        "input": [
+          "X"
+        ],
+        "output": [
+          "Y"
+        ],
+        "opType": "MyCustomOp",
+        "domain": "test"
+      },
+      {
+        "input": [
+          "Y"
+        ],
+        "output": [
+          "Z"
+        ],
+        "opType": "Relu"
+      }
+    ],
+    "name": "test-custom",
+    "input": [
+      {
+        "name": "X",
+        "type": {
+          "tensorType": {
+            "elemType": 1,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "4"
+                },
+                {
+                  "dimValue": "5"
+                }
+              ]
+            }
+          }
+        }
+      }
+    ],
+    "output": [
+      {
+        "name": "Z",
+        "type": {
+          "tensorType": {
+            "elemType": 1,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "5"
+                },
+                {
+                  "dimValue": "6"
+                }
+              ]
+            }
+          }
+        }
+      }
+    ],
+    "valueInfo": [
+      {
+        "name": "Y",
+        "type": {
+          "tensorType": {
+            "elemType": 1,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "5"
+                },
+                {
+                  "dimValue": "6"
+                }
+              ]
+            }
+          }
+        }
+      }
+    ]
+  },
+  "opsetImport": [
+    {
+      "domain": "",
+      "version": "22"
+    },
+    {
+      "domain": "test",
+      "version": "1"
+    }
+  ]
+}
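The valueInfo entry for "Y" above is what the importer's type map consults when useOnnxModelTypesForCustomOps is on, which is why the custom op's result becomes tensor<5x6xf32> in the first RUN line and stays tensor<*xf32> in the second. For reference, a minimal protobuf-level sketch of reading that entry with the standard generated onnx.proto accessors (independent of onnx-mlir):

#include "onnx/onnx_pb.h"

// Reads the first value_info entry ("Y"): elem_type 1 is FLOAT,
// dims are 5 and 6, i.e. tensor<5x6xf32> after import.
void inspectValueInfo(const onnx::ModelProto &model) {
  const onnx::ValueInfoProto &vi = model.graph().value_info(0);
  const auto &tt = vi.type().tensor_type();
  int elemType = tt.elem_type();              // 1 == FLOAT
  long long d0 = tt.shape().dim(0).dim_value(); // 5
  long long d1 = tt.shape().dim(1).dim_value(); // 6
  (void)elemType; (void)d0; (void)d1;
}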
