
Commit 2eca0d3

Make parsing of models more robust if useOnnxModelTypes is enabled.

This is done by catching exceptions during shape inference (as they happen, for example, if the model uses custom ops) and by falling back to an onnx-mlir based type mapping for some kinds of invalid models. Pass ImportOptions by const reference instead of by value to avoid unnecessary copying.

Signed-off-by: Rickert, Jonas <[email protected]>
1 parent faf797f commit 2eca0d3
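
The core of the change is to stop treating ONNX shape inference as infallible. Below is a minimal, self-contained sketch of the pattern the commit applies (not the commit's code verbatim; the helper name inferShapesBestEffort is hypothetical):

#include <exception>

#include "llvm/Support/raw_ostream.h"
#include "onnx/shape_inference/implementation.h"

// Shape inference is best effort here: a throwing run (e.g. on models that
// use custom ops the checker cannot reason about) is reported as a warning
// instead of aborting the whole import.
static void inferShapesBestEffort(onnx::ModelProto &model) {
  try {
    onnx::shape_inference::InferShapes(model);
  } catch (const std::exception &e) {
    llvm::errs() << "Warning: Caught exception running onnx shape inference: "
                 << e.what() << "\n";
  }
}

The importer then proceeds with whatever types the model already carries, which is exactly what the fallback type mapping below relies on.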

File tree

3 files changed: +80 -48 lines changed

  src/Builder/FrontendDialectTransformer.cpp
  src/Builder/FrontendDialectTransformer.hpp
  test/mlir/onnx/parse/add_missing_output_types.json

src/Builder/FrontendDialectTransformer.cpp

Lines changed: 64 additions & 43 deletions

@@ -159,20 +159,19 @@ using SymbolToOnnxTypeMapping = SymbolMapping<onnx::TypeProto>;
 
 class FrontendGenImpl {
 public:
-  explicit FrontendGenImpl(MLIRContext &context)
-      : context_(context), builder_(&context) {
+  explicit FrontendGenImpl(MLIRContext &context, const ImportOptions &options)
+      : options_(options), context_(context), builder_(&context) {
     module_ = ModuleOp::create(UnknownLoc::get(&context));
     InitHandlerMap();
   }
 
-  ErrorOr<ModuleOp> ImportONNXModel(const onnx::ModelProto &model,
-      ImportOptions options, std::string &errorMessage) {
-    options_ = options;
+  ErrorOr<ModuleOp> ImportONNXModel(
+      const onnx::ModelProto &model, std::string &errorMessage) {
    modelInputShaper_.setShapeInformation(options_.shapeInformation);
    opset_map_ = GetOpsetImportsFromProto(model); // Which opsets to use.
    in_model_functions_ = GetModelLocalFunctions(model);
-    ErrorOr<mlir::func::FuncOp> importGraphResult = importGraph(
-        model.graph(), options_.allowMissingOutputTypes, errorMessage);
+    ErrorOr<mlir::func::FuncOp> importGraphResult =
+        importGraph(model.graph(), errorMessage);
    if (auto ec = importGraphResult.getError()) {
      return ec;
    }
@@ -193,7 +192,7 @@ class FrontendGenImpl {
   }
 
 private:
-  ImportOptions options_;
+  const ImportOptions &options_;
   MLIRContext &context_;
   ModuleOp module_;
   OpBuilder builder_;
@@ -541,7 +540,7 @@
   */
  ErrorOr<FunctionType> importGraph(const onnx::GraphProto &graph,
      Region &region, Operation *op, bool useReturn,
-      bool allowMissingOutputTypes, std::string &errorMessage) {
+      std::string &errorMessage) {
    frontend_symbols_.pushScope(graph.name());
    onnx_type_map.pushScope(graph.name());
    Block *entryBlock = &region.back();
@@ -648,8 +647,8 @@
    // Import the output tensors
    for (const auto &output : graph.output()) {
      std::string dimParams = "";
-      const auto ec = ImportOutputTensor(output, retTys, retVals, errorMessage,
-          allowMissingOutputTypes, &dimParams);
+      const auto ec =
+          ImportOutputTensor(output, retTys, retVals, errorMessage, &dimParams);
      if (ec) {
        errorMessage +=
            "Failed to import output tensor '" + output.name() + "'.\n";
@@ -854,13 +853,22 @@
      // Optional outputs using empty string.
      if (node.output()[i].empty()) {
        outputTypes.emplace_back(builder_.getNoneType());
-      } else if (auto onnxModelType =
-                     ConvertOnnxType(node.output(i), errorMessage)) {
-        if (auto ec = onnxModelType->getError()) {
-          return ec;
-        }
-        outputTypes.emplace_back(*onnxModelType.value());
      } else {
+        auto onnxModelType = ConvertOnnxType(node.output(i), errorMessage);
+        if (onnxModelType) {
+          const auto ec = onnxModelType->getError();
+          if (!ec) {
+            outputTypes.emplace_back(*onnxModelType.value());
+            continue;
+          }
+          if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
+            errorMessage += "Failed to get type for '" + node.output(i) + "\n";
+            return ec;
+          }
+          llvm::errs() << "Warning: "
+                       << "Failed to get type for '" << node.output(i)
+                       << "', falling back to onnx-mlir based mapping.\n";
+        }
        unsigned int j = i;
        // Variadic output is a single ODS result.
        if (variadicOut)
@@ -910,8 +918,8 @@
      region.push_back(new Block);
      OpBuilder::InsertionGuard guard(builder_);
      builder_.setInsertionPointToStart(&region.back());
-      const ErrorOr<FunctionType> importGraphResult = importGraph(attr.g(),
-          region, op, false, options_.allowMissingOutputTypes, errorMessage);
+      const ErrorOr<FunctionType> importGraphResult =
+          importGraph(attr.g(), region, op, false, errorMessage);
      if (auto ec = importGraphResult.getError()) {
        return ec;
      }
@@ -1443,9 +1451,14 @@
        GetOpsetImportsFromProto(functionProto);
 
    // Populates graph.value_info().
-    onnx::shape_inference::InferShapes(&graph, function_opset_map,
-        onnx::OpSchemaRegistry::Instance(),
-        /*options=*/{}, in_model_functions_);
+    try {
+      onnx::shape_inference::InferShapes(&graph, function_opset_map,
+          onnx::OpSchemaRegistry::Instance(),
+          /*options=*/{}, in_model_functions_);
+    } catch (const std::exception &e) {
+      llvm::errs() << "Warning: Caught exception running onnx shape inference: "
+                   << e.what() << "\n";
+    }
 
    // Save caller context, while generating function body.
    ModelLocalFunctionsMap callerModelFunctions;
@@ -1629,14 +1642,14 @@
      const onnx::ValueInfoProto &output,
      llvm::SmallVectorImpl<Type> &ret_types,
      llvm::SmallVectorImpl<Value> &ret_vals, std::string &errorMessage,
-      bool allowMissingType, std::string *dim_params = nullptr) {
+      std::string *dim_params = nullptr) {
    const Value *valPtr = frontend_symbols_.GetByOnnxName(output.name());
    Value val = *valPtr;
 
    ErrorOr<Type> parsedOutputType =
        ImportType(output.type(), errorMessage, dim_params);
    if (auto ec = parsedOutputType.getError()) {
-      if (!allowMissingType || ec != InvalidOnnxFormat) {
+      if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
        errorMessage +=
            "Failed to import output type for '" + output.name() + "\n";
        return ec;
@@ -1716,15 +1729,10 @@
  /*!
   * Import ONNX main computation graph.
   * @param graph onnx graph proto.
-   * @param allowMissingOutputTypes If true, type inference will be used to
-   * infer missing output types. This is done by copying the, potential
-   * inferred, output type of the node connected to the output. According to
-   * ONNX, all outputs MUST have types. Therefore this option has to be
-   * considered as a stretch best effort.
   * @return A function corresponding to the imported computation graph.
   */
-  ErrorOr<func::FuncOp> importGraph(const onnx::GraphProto &graph,
-      bool allowMissingOutputTypes, std::string &errorMessage) {
+  ErrorOr<func::FuncOp> importGraph(
+      const onnx::GraphProto &graph, std::string &errorMessage) {
    const std::string &name = "main_graph";
    auto mainFunc = func::FuncOp::create(UnknownLoc(), name,
        /*type=*/builder_.getFunctionType({}, {}), /*attrs=*/{});
@@ -1735,8 +1743,7 @@
 
    ErrorOr<FunctionType> importedFuncType =
        importGraph(graph, /*region=*/mainFunc.getBody(),
-            /*op=*/mainFunc.getOperation(), /*useReturn=*/true,
-            /*allowMissingOutputTypes=*/allowMissingOutputTypes, errorMessage);
+            /*op=*/mainFunc.getOperation(), /*useReturn=*/true, errorMessage);
    if (auto ec = importedFuncType.getError()) {
      errorMessage +=
          "Failed to import main graph, could not get its function type\n";
@@ -1763,7 +1770,7 @@
 
 [[nodiscard]] std::error_code ImportFrontendModelInternal(
     onnx::ModelProto &model, MLIRContext &context,
-    OwningOpRef<ModuleOp> &module, ImportOptions options,
+    OwningOpRef<ModuleOp> &module, const ImportOptions &options,
     std::string &errorMessage) {
   int originVersion = CURRENT_ONNX_OPSET;
   // Get the version of the model
@@ -1799,21 +1806,35 @@
      originVersion < CURRENT_ONNX_OPSET) {
    onnx::ModelProto convertModel =
        onnx::version_conversion::ConvertVersion(model, CURRENT_ONNX_OPSET);
-    if (options.useOnnxModelTypes)
-      onnx::shape_inference::InferShapes(convertModel);
+    if (options.useOnnxModelTypes) {
+      try {
+        onnx::shape_inference::InferShapes(convertModel);
+      } catch (const std::exception &e) {
+        llvm::errs()
+            << "Warning: Caught exception running onnx shape inference: "
+            << e.what() << "\n";
+      }
+    }
    return ImportFrontendModel(
        convertModel, context, module, errorMessage, options);
  } else {
-    if (options.useOnnxModelTypes)
-      onnx::shape_inference::InferShapes(model);
+    if (options.useOnnxModelTypes) {
+      try {
+        onnx::shape_inference::InferShapes(model);
+      } catch (const std::exception &e) {
+        llvm::errs()
+            << "Warning: Caught exception running onnx shape inference: "
+            << e.what() << "\n";
+      }
+    }
    return ImportFrontendModel(model, context, module, errorMessage, options);
  }
  return CompilerSuccess;
}
 
 [[nodiscard]] std::error_code ImportFrontendModelArray(const void *onnxBuffer,
     int size, MLIRContext &context, OwningOpRef<ModuleOp> &module,
-    std::string &errorMessage, ImportOptions options) {
+    std::string &errorMessage, const ImportOptions &options) {
  onnx::ModelProto model;
 
  bool parse_success = model.ParseFromArray(onnxBuffer, size);
@@ -1855,7 +1876,7 @@ namespace {
 // Return 0 on success, error otherwise.
 [[nodiscard]] std::error_code ImportFrontendModelFile(StringRef model_fname,
     MLIRContext &context, OwningOpRef<ModuleOp> &module,
-    std::string &errorMessage, ImportOptions options) {
+    std::string &errorMessage, const ImportOptions &options) {
  onnx::ModelProto model;
  if (model_fname.ends_with(".onnxtext")) {
    std::string text;
@@ -1912,11 +1933,11 @@ namespace {
 
 [[nodiscard]] std::error_code ImportFrontendModel(const onnx::ModelProto &model,
     MLIRContext &context, OwningOpRef<ModuleOp> &module,
-    std::string &errorMessage, ImportOptions options) {
+    std::string &errorMessage, const ImportOptions &options) {
 
-  detail::FrontendGenImpl myONNXGen(context);
+  detail::FrontendGenImpl myONNXGen(context, options);
  ErrorOr<ModuleOp> importedModule =
-      myONNXGen.ImportONNXModel(model, options, errorMessage);
+      myONNXGen.ImportONNXModel(model, errorMessage);
  if (auto ec = importedModule.getError()) {
    return ec;
  }
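
A side note on the refactoring above: FrontendGenImpl now holds const ImportOptions &options_ as a reference member, which is only safe because the importer never outlives the options passed into ImportFrontendModel. A self-contained sketch of that lifetime contract, with stand-in names rather than the real classes:

// Stand-in for the real ImportOptions declared in the header below.
struct ImportOptions {
  bool useOnnxModelTypes = false;
};

// Stand-in importer: the reference member borrows the caller's options.
class Importer {
public:
  explicit Importer(const ImportOptions &options) : options_(options) {}
  bool usesModelTypes() const { return options_.useOnnxModelTypes; }

private:
  const ImportOptions &options_; // borrowed, not owned
};

void importOnce() {
  ImportOptions options;      // the caller owns the options...
  Importer importer(options); // ...the importer only borrows them
  (void)importer.usesModelTypes();
} // both are destroyed here; 'importer' must never escape this scope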

src/Builder/FrontendDialectTransformer.hpp

Lines changed: 8 additions & 5 deletions

@@ -13,7 +13,6 @@
 #ifndef ONNX_MLIR_FRONTEND_TRANSFORMER_H
 #define ONNX_MLIR_FRONTEND_TRANSFORMER_H
 
-#include <set>
 #include <string>
 
 #include "onnx/onnx_pb.h"
@@ -49,7 +48,11 @@ struct ImportOptions {
   bool allowSorting = true;
   bool useOutputNameAsLocation = false;
 
-  // Allow missing output types and use type inference to determine them.
+  // If true, type inference will be used to
+  // infer missing output types. This is done by copying the, potentially
+  // inferred, output type of the node connected to the output. According to
+  // ONNX, all outputs MUST have types. Therefore this option has to be
+  // considered a best effort.
   bool allowMissingOutputTypes = false;
 
   // Custom shape information for the graph inputs.
@@ -90,7 +93,7 @@ struct ImportOptions {
 [[nodiscard]] std::error_code ImportFrontendModelArray(const void *onnxBuffer,
     int bufferSize, mlir::MLIRContext &context,
     mlir::OwningOpRef<mlir::ModuleOp> &module, std::string &errorMessage,
-    ImportOptions options = ImportOptions());
+    const ImportOptions &options = ImportOptions());
 
 /*!
  * Import an ONNX model file into the ONNX Dialect.
@@ -100,7 +103,7 @@ struct ImportOptions {
 [[nodiscard]] std::error_code ImportFrontendModelFile(
     llvm::StringRef model_fname, mlir::MLIRContext &context,
     mlir::OwningOpRef<mlir::ModuleOp> &module, std::string &errorMessage,
-    ImportOptions options = ImportOptions());
+    const ImportOptions &options = ImportOptions());
 
 /*!
  * Import an ONNX model proto into the ONNX Dialect.
@@ -109,7 +112,7 @@ struct ImportOptions {
  */
 [[nodiscard]] std::error_code ImportFrontendModel(const onnx::ModelProto &model,
     mlir::MLIRContext &context, mlir::OwningOpRef<mlir::ModuleOp> &module,
-    std::string &errorMessage, ImportOptions options = ImportOptions());
+    std::string &errorMessage, const ImportOptions &options = ImportOptions());
 
 /*!
  * TODO: Import models into other extension dialects that cover the
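
For reference, a hypothetical caller of the updated API, sketched against the signatures declared above. The include path and the onnx_mlir namespace are assumptions, as neither appears in this diff:

#include <string>

#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/MLIRContext.h"
#include "src/Builder/FrontendDialectTransformer.hpp" // assumed path

mlir::OwningOpRef<mlir::ModuleOp> importLeniently(
    mlir::MLIRContext &context, llvm::StringRef fileName) {
  onnx_mlir::ImportOptions options;
  options.useOnnxModelTypes = true;       // trust types stored in the model
  options.allowMissingOutputTypes = true; // best-effort fallback (see above)

  mlir::OwningOpRef<mlir::ModuleOp> module;
  std::string errorMessage;
  // ImportOptions is now taken by const reference, so 'options' is borrowed
  // for the duration of the call rather than copied.
  if (onnx_mlir::ImportFrontendModelFile(
          fileName, context, module, errorMessage, options)) {
    llvm::errs() << errorMessage;
    return {};
  }
  return module;
}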

test/mlir/onnx/parse/add_missing_output_types.json

Lines changed: 8 additions & 0 deletions

@@ -15,6 +15,14 @@
 // INFERRED: return [[VAR_0_]], [[VAR_1_]] : tensor<*xf32>, tensor<3x3xf32>
 // INFERRED: }
 
+// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=true --allowMissingOutputTypes=true --printIR %s | FileCheck --check-prefix=MODEL-TYPE %s
+// MODEL-TYPE-LABEL: func.func @main_graph
+// MODEL-TYPE-SAME: ([[PARAM_0_:%.+]]: tensor<3x1xf32> {onnx.name = "input_a"}, [[PARAM_1_:%.+]]: tensor<1x3xf32> {onnx.name = "input_b"}) -> (tensor<*xf32> {onnx.name = "output_c"}, tensor<3x3xf32> {onnx.name = "output_d"}) {
+// MODEL-TYPE-DAG: [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]], [[PARAM_1_]]) {domain_name = "test", function_name = "test.Add", onnx_node_name = "add_node_custom"} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<*xf32>
+// MODEL-TYPE-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) {onnx_node_name = "add_node"} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
+// MODEL-TYPE: return [[VAR_0_]], [[VAR_1_]] : tensor<*xf32>, tensor<3x3xf32>
+// MODEL-TYPE: }
+
 {
   "irVersion": "10",
   "producerName": "onnx-example",
