Commit d32fa87

Merge pull request #404 from Xilinx/jrickert.more_robust_onnx_model_types
Make parsing of models more robust if useOnnxModelTypes is enabled.
2 parents faf797f + 2eca0d3 commit d32fa87
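
The change has three parts: calls into onnx shape inference are wrapped in try/catch so that exceptions thrown on malformed models surface as warnings instead of aborting the import; when a type recorded in the model cannot be parsed and `allowMissingOutputTypes` is set, node outputs fall back to onnx-mlir's own type mapping; and `ImportOptions` is now passed by const reference and stored once at `FrontendGenImpl` construction rather than copied through every call.

A minimal standalone sketch of the guard pattern the diff applies; the helper name `tryInferShapes` is our illustration, not onnx-mlir API:

```cpp
#include <exception>

#include "llvm/Support/raw_ostream.h"
#include "onnx/shape_inference/implementation.h"

// onnx's InferShapes() throws on models it cannot handle; the importer now
// demotes such failures to a warning and continues with onnx-mlir's own
// type inference instead of crashing.
static void tryInferShapes(onnx::ModelProto &model) {
  try {
    onnx::shape_inference::InferShapes(model);
  } catch (const std::exception &e) {
    llvm::errs() << "Warning: Caught exception running onnx shape inference: "
                 << e.what() << "\n";
  }
}
```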

File tree

3 files changed: 80 additions, 48 deletions

src/Builder/FrontendDialectTransformer.cpp

Lines changed: 64 additions & 43 deletions
```diff
@@ -159,20 +159,19 @@ using SymbolToOnnxTypeMapping = SymbolMapping<onnx::TypeProto>;
 
 class FrontendGenImpl {
 public:
-  explicit FrontendGenImpl(MLIRContext &context)
-      : context_(context), builder_(&context) {
+  explicit FrontendGenImpl(MLIRContext &context, const ImportOptions &options)
+      : options_(options), context_(context), builder_(&context) {
     module_ = ModuleOp::create(UnknownLoc::get(&context));
     InitHandlerMap();
   }
 
-  ErrorOr<ModuleOp> ImportONNXModel(const onnx::ModelProto &model,
-      ImportOptions options, std::string &errorMessage) {
-    options_ = options;
+  ErrorOr<ModuleOp> ImportONNXModel(
+      const onnx::ModelProto &model, std::string &errorMessage) {
     modelInputShaper_.setShapeInformation(options_.shapeInformation);
     opset_map_ = GetOpsetImportsFromProto(model); // Which opsets to use.
     in_model_functions_ = GetModelLocalFunctions(model);
-    ErrorOr<mlir::func::FuncOp> importGraphResult = importGraph(
-        model.graph(), options_.allowMissingOutputTypes, errorMessage);
+    ErrorOr<mlir::func::FuncOp> importGraphResult =
+        importGraph(model.graph(), errorMessage);
     if (auto ec = importGraphResult.getError()) {
       return ec;
     }
@@ -193,7 +192,7 @@ class FrontendGenImpl {
   }
 
 private:
-  ImportOptions options_;
+  const ImportOptions &options_;
   MLIRContext &context_;
   ModuleOp module_;
   OpBuilder builder_;
@@ -541,7 +540,7 @@
   */
  ErrorOr<FunctionType> importGraph(const onnx::GraphProto &graph,
      Region &region, Operation *op, bool useReturn,
-      bool allowMissingOutputTypes, std::string &errorMessage) {
+      std::string &errorMessage) {
    frontend_symbols_.pushScope(graph.name());
    onnx_type_map.pushScope(graph.name());
    Block *entryBlock = &region.back();
@@ -648,8 +647,8 @@
    // Import the output tensors
    for (const auto &output : graph.output()) {
      std::string dimParams = "";
-      const auto ec = ImportOutputTensor(output, retTys, retVals, errorMessage,
-          allowMissingOutputTypes, &dimParams);
+      const auto ec =
+          ImportOutputTensor(output, retTys, retVals, errorMessage, &dimParams);
      if (ec) {
        errorMessage +=
            "Failed to import output tensor '" + output.name() + "'.\n";
@@ -854,13 +853,22 @@
      // Optional outputs using empty string.
      if (node.output()[i].empty()) {
        outputTypes.emplace_back(builder_.getNoneType());
-      } else if (auto onnxModelType =
-                     ConvertOnnxType(node.output(i), errorMessage)) {
-        if (auto ec = onnxModelType->getError()) {
-          return ec;
-        }
-        outputTypes.emplace_back(*onnxModelType.value());
      } else {
+        auto onnxModelType = ConvertOnnxType(node.output(i), errorMessage);
+        if (onnxModelType) {
+          const auto ec = onnxModelType->getError();
+          if (!ec) {
+            outputTypes.emplace_back(*onnxModelType.value());
+            continue;
+          }
+          if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
+            errorMessage += "Failed to get type for '" + node.output(i) + "'\n";
+            return ec;
+          }
+          llvm::errs() << "Warning: "
+                       << "Failed to get type for '" << node.output(i)
+                       << "', falling back to onnx-mlir based mapping.\n";
+        }
        unsigned int j = i;
        // Variadic output is a single ODS result.
        if (variadicOut)
@@ -910,8 +918,8 @@
    region.push_back(new Block);
    OpBuilder::InsertionGuard guard(builder_);
    builder_.setInsertionPointToStart(&region.back());
-    const ErrorOr<FunctionType> importGraphResult = importGraph(attr.g(),
-        region, op, false, options_.allowMissingOutputTypes, errorMessage);
+    const ErrorOr<FunctionType> importGraphResult =
+        importGraph(attr.g(), region, op, false, errorMessage);
    if (auto ec = importGraphResult.getError()) {
      return ec;
    }
@@ -1443,9 +1451,14 @@
        GetOpsetImportsFromProto(functionProto);
 
    // Populates graph.value_info().
-    onnx::shape_inference::InferShapes(&graph, function_opset_map,
-        onnx::OpSchemaRegistry::Instance(),
-        /*options=*/{}, in_model_functions_);
+    try {
+      onnx::shape_inference::InferShapes(&graph, function_opset_map,
+          onnx::OpSchemaRegistry::Instance(),
+          /*options=*/{}, in_model_functions_);
+    } catch (const std::exception &e) {
+      llvm::errs() << "Warning: Caught exception running onnx shape inference: "
+                   << e.what() << "\n";
+    }
 
    // Save caller context, while generating function body.
    ModelLocalFunctionsMap callerModelFunctions;
@@ -1629,14 +1642,14 @@
      const onnx::ValueInfoProto &output,
      llvm::SmallVectorImpl<Type> &ret_types,
      llvm::SmallVectorImpl<Value> &ret_vals, std::string &errorMessage,
-      bool allowMissingType, std::string *dim_params = nullptr) {
+      std::string *dim_params = nullptr) {
    const Value *valPtr = frontend_symbols_.GetByOnnxName(output.name());
    Value val = *valPtr;
 
    ErrorOr<Type> parsedOutputType =
        ImportType(output.type(), errorMessage, dim_params);
    if (auto ec = parsedOutputType.getError()) {
-      if (!allowMissingType || ec != InvalidOnnxFormat) {
+      if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
        errorMessage +=
            "Failed to import output type for '" + output.name() + "\n";
        return ec;
@@ -1716,15 +1729,10 @@
  /*!
   * Import ONNX main computation graph.
   * @param graph onnx graph proto.
-   * @param allowMissingOutputTypes If true, type inference will be used to
-   * infer missing output types. This is done by copying the, potential
-   * inferred, output type of the node connected to the output. According to
-   * ONNX, all outputs MUST have types. Therefore this option has to be
-   * considered as a stretch best effort.
   * @return A function corresponding to the imported computation graph.
   */
-  ErrorOr<func::FuncOp> importGraph(const onnx::GraphProto &graph,
-      bool allowMissingOutputTypes, std::string &errorMessage) {
+  ErrorOr<func::FuncOp> importGraph(
+      const onnx::GraphProto &graph, std::string &errorMessage) {
    const std::string &name = "main_graph";
    auto mainFunc = func::FuncOp::create(UnknownLoc(), name,
        /*type=*/builder_.getFunctionType({}, {}), /*attrs=*/{});
@@ -1735,8 +1743,7 @@
 
    ErrorOr<FunctionType> importedFuncType =
        importGraph(graph, /*region=*/mainFunc.getBody(),
-            /*op=*/mainFunc.getOperation(), /*useReturn=*/true,
-            /*allowMissingOutputTypes=*/allowMissingOutputTypes, errorMessage);
+            /*op=*/mainFunc.getOperation(), /*useReturn=*/true, errorMessage);
    if (auto ec = importedFuncType.getError()) {
      errorMessage +=
          "Failed to import main graph, could not get its function type\n";
@@ -1763,7 +1770,7 @@
 
 [[nodiscard]] std::error_code ImportFrontendModelInternal(
     onnx::ModelProto &model, MLIRContext &context,
-    OwningOpRef<ModuleOp> &module, ImportOptions options,
+    OwningOpRef<ModuleOp> &module, const ImportOptions &options,
     std::string &errorMessage) {
   int originVersion = CURRENT_ONNX_OPSET;
   // Get the version of the model
@@ -1799,21 +1806,35 @@
       originVersion < CURRENT_ONNX_OPSET) {
     onnx::ModelProto convertModel =
         onnx::version_conversion::ConvertVersion(model, CURRENT_ONNX_OPSET);
-    if (options.useOnnxModelTypes)
-      onnx::shape_inference::InferShapes(convertModel);
+    if (options.useOnnxModelTypes) {
+      try {
+        onnx::shape_inference::InferShapes(convertModel);
+      } catch (const std::exception &e) {
+        llvm::errs()
+            << "Warning: Caught exception running onnx shape inference: "
+            << e.what() << "\n";
+      }
+    }
     return ImportFrontendModel(
         convertModel, context, module, errorMessage, options);
   } else {
-    if (options.useOnnxModelTypes)
-      onnx::shape_inference::InferShapes(model);
+    if (options.useOnnxModelTypes) {
+      try {
+        onnx::shape_inference::InferShapes(model);
+      } catch (const std::exception &e) {
+        llvm::errs()
+            << "Warning: Caught exception running onnx shape inference: "
+            << e.what() << "\n";
+      }
+    }
     return ImportFrontendModel(model, context, module, errorMessage, options);
   }
   return CompilerSuccess;
 }
 
 [[nodiscard]] std::error_code ImportFrontendModelArray(const void *onnxBuffer,
     int size, MLIRContext &context, OwningOpRef<ModuleOp> &module,
-    std::string &errorMessage, ImportOptions options) {
+    std::string &errorMessage, const ImportOptions &options) {
   onnx::ModelProto model;
 
   bool parse_success = model.ParseFromArray(onnxBuffer, size);
@@ -1855,7 +1876,7 @@ namespace {
 // Return 0 on success, error otherwise.
 [[nodiscard]] std::error_code ImportFrontendModelFile(StringRef model_fname,
     MLIRContext &context, OwningOpRef<ModuleOp> &module,
-    std::string &errorMessage, ImportOptions options) {
+    std::string &errorMessage, const ImportOptions &options) {
   onnx::ModelProto model;
   if (model_fname.ends_with(".onnxtext")) {
     std::string text;
@@ -1912,11 +1933,11 @@ namespace {
 
 [[nodiscard]] std::error_code ImportFrontendModel(const onnx::ModelProto &model,
     MLIRContext &context, OwningOpRef<ModuleOp> &module,
-    std::string &errorMessage, ImportOptions options) {
+    std::string &errorMessage, const ImportOptions &options) {
 
-  detail::FrontendGenImpl myONNXGen(context);
+  detail::FrontendGenImpl myONNXGen(context, options);
   ErrorOr<ModuleOp> importedModule =
-      myONNXGen.ImportONNXModel(model, options, errorMessage);
+      myONNXGen.ImportONNXModel(model, errorMessage);
   if (auto ec = importedModule.getError()) {
     return ec;
   }
```
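
A side effect of storing `options_` as `const ImportOptions &` rather than a by-value copy is that the options object must now outlive the `FrontendGenImpl` instance. A minimal sketch of the new construction pattern, simplified from `ImportFrontendModel` above (it assumes `context`, `model`, and `errorMessage` are already set up by the caller):

```cpp
// The options are captured by reference for the importer's whole lifetime,
// so they must not be a temporary that dies before the import finishes.
ImportOptions options;
options.useOnnxModelTypes = true;
options.allowMissingOutputTypes = true;

detail::FrontendGenImpl myONNXGen(context, options);
ErrorOr<ModuleOp> importedModule =
    myONNXGen.ImportONNXModel(model, errorMessage);
if (auto ec = importedModule.getError())
  return ec;
```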

src/Builder/FrontendDialectTransformer.hpp

Lines changed: 8 additions & 5 deletions
```diff
@@ -13,7 +13,6 @@
 #ifndef ONNX_MLIR_FRONTEND_TRANSFORMER_H
 #define ONNX_MLIR_FRONTEND_TRANSFORMER_H
 
-#include <set>
 #include <string>
 
 #include "onnx/onnx_pb.h"
@@ -49,7 +48,11 @@ struct ImportOptions {
   bool allowSorting = true;
   bool useOutputNameAsLocation = false;
 
-  // Allow missing output types and use type inference to determine them.
+  // If true, type inference will be used to infer missing output types.
+  // This is done by copying the, potentially inferred, output type of the
+  // node connected to the output. According to ONNX, all outputs MUST
+  // have types. Therefore this option has to be considered a
+  // best-effort stretch.
   bool allowMissingOutputTypes = false;
 
   // Custom shape information for the graph inputs.
@@ -90,7 +93,7 @@ struct ImportOptions {
 [[nodiscard]] std::error_code ImportFrontendModelArray(const void *onnxBuffer,
     int bufferSize, mlir::MLIRContext &context,
     mlir::OwningOpRef<mlir::ModuleOp> &module, std::string &errorMessage,
-    ImportOptions options = ImportOptions());
+    const ImportOptions &options = ImportOptions());
 
 /*!
  * Import an ONNX model file into the ONNX Dialect.
@@ -100,7 +103,7 @@ struct ImportOptions {
 [[nodiscard]] std::error_code ImportFrontendModelFile(
     llvm::StringRef model_fname, mlir::MLIRContext &context,
     mlir::OwningOpRef<mlir::ModuleOp> &module, std::string &errorMessage,
-    ImportOptions options = ImportOptions());
+    const ImportOptions &options = ImportOptions());
 
 /*!
  * Import an ONNX model proto into the ONNX Dialect.
@@ -109,7 +112,7 @@
  */
 [[nodiscard]] std::error_code ImportFrontendModel(const onnx::ModelProto &model,
     mlir::MLIRContext &context, mlir::OwningOpRef<mlir::ModuleOp> &module,
-    std::string &errorMessage, ImportOptions options = ImportOptions());
+    std::string &errorMessage, const ImportOptions &options = ImportOptions());
 
 /*!
  * TODO: Import models into other extension dialects that cover the
```
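
For callers of these public entry points nothing changes syntactically: a default-constructed temporary still binds to the new `const ImportOptions &` parameter and lives for the duration of the call. A minimal usage sketch, assuming the `onnx_mlir` namespace and include path as laid out in this repo, and a context with the required dialects registered (`model.onnx` is a placeholder path):

```cpp
#include "src/Builder/FrontendDialectTransformer.hpp"

#include "llvm/Support/raw_ostream.h"

mlir::MLIRContext context; // assumed to have the needed dialects registered
mlir::OwningOpRef<mlir::ModuleOp> module;
std::string errorMessage;

onnx_mlir::ImportOptions options;
options.useOnnxModelTypes = true;       // prefer types recorded in the model
options.allowMissingOutputTypes = true; // warn and fall back when absent

if (std::error_code ec = onnx_mlir::ImportFrontendModelFile(
        "model.onnx", context, module, errorMessage, options)) {
  llvm::errs() << errorMessage;
}
```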

test/mlir/onnx/parse/add_missing_output_types.json

Lines changed: 8 additions & 0 deletions
```diff
@@ -15,6 +15,14 @@
 // INFERRED: return [[VAR_0_]], [[VAR_1_]] : tensor<*xf32>, tensor<3x3xf32>
 // INFERRED: }
 
+// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=true --allowMissingOutputTypes=true --printIR %s | FileCheck --check-prefix=MODEL-TYPE %s
+// MODEL-TYPE-LABEL: func.func @main_graph
+// MODEL-TYPE-SAME: ([[PARAM_0_:%.+]]: tensor<3x1xf32> {onnx.name = "input_a"}, [[PARAM_1_:%.+]]: tensor<1x3xf32> {onnx.name = "input_b"}) -> (tensor<*xf32> {onnx.name = "output_c"}, tensor<3x3xf32> {onnx.name = "output_d"}) {
+// MODEL-TYPE-DAG: [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]], [[PARAM_1_]]) {domain_name = "test", function_name = "test.Add", onnx_node_name = "add_node_custom"} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<*xf32>
+// MODEL-TYPE-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) {onnx_node_name = "add_node"} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
+// MODEL-TYPE: return [[VAR_0_]], [[VAR_1_]] : tensor<*xf32>, tensor<3x3xf32>
+// MODEL-TYPE: }
+
 {
   "irVersion": "10",
   "producerName": "onnx-example",
```
