@@ -159,20 +159,19 @@ using SymbolToOnnxTypeMapping = SymbolMapping<onnx::TypeProto>;
159159
160160class FrontendGenImpl {
161161public:
162- explicit FrontendGenImpl (MLIRContext &context)
163- : context_(context), builder_(&context) {
162+ explicit FrontendGenImpl (MLIRContext &context, const ImportOptions &options )
163+ : options_(options), context_(context), builder_(&context) {
164164 module_ = ModuleOp::create (UnknownLoc::get (&context));
165165 InitHandlerMap ();
166166 }
167167
168- ErrorOr<ModuleOp> ImportONNXModel (const onnx::ModelProto &model,
169- ImportOptions options, std::string &errorMessage) {
170- options_ = options;
168+ ErrorOr<ModuleOp> ImportONNXModel (
169+ const onnx::ModelProto &model, std::string &errorMessage) {
171170 modelInputShaper_.setShapeInformation (options_.shapeInformation );
172171 opset_map_ = GetOpsetImportsFromProto (model); // Which opsets to use.
173172 in_model_functions_ = GetModelLocalFunctions (model);
174- ErrorOr<mlir::func::FuncOp> importGraphResult = importGraph (
175- model.graph (), options_. allowMissingOutputTypes , errorMessage);
173+ ErrorOr<mlir::func::FuncOp> importGraphResult =
174+ importGraph ( model.graph (), errorMessage);
176175 if (auto ec = importGraphResult.getError ()) {
177176 return ec;
178177 }
@@ -193,7 +192,7 @@ class FrontendGenImpl {
193192 }
194193
195194private:
196- ImportOptions options_;
195+ const ImportOptions & options_;
197196 MLIRContext &context_;
198197 ModuleOp module_;
199198 OpBuilder builder_;
@@ -541,7 +540,7 @@ class FrontendGenImpl {
541540 */
542541 ErrorOr<FunctionType> importGraph (const onnx::GraphProto &graph,
543542 Region &region, Operation *op, bool useReturn,
544- bool allowMissingOutputTypes, std::string &errorMessage) {
543+ std::string &errorMessage) {
545544 frontend_symbols_.pushScope (graph.name ());
546545 onnx_type_map.pushScope (graph.name ());
547546 Block *entryBlock = &region.back ();
@@ -648,8 +647,8 @@ class FrontendGenImpl {
648647 // Import the output tensors
649648 for (const auto &output : graph.output ()) {
650649 std::string dimParams = " " ;
651- const auto ec = ImportOutputTensor (output, retTys, retVals, errorMessage,
652- allowMissingOutputTypes , &dimParams);
650+ const auto ec =
651+ ImportOutputTensor (output, retTys, retVals, errorMessage , &dimParams);
653652 if (ec) {
654653 errorMessage +=
655654 " Failed to import output tensor '" + output.name () + " '.\n " ;
@@ -854,13 +853,22 @@ class FrontendGenImpl {
854853 // Optional outputs using empty string.
855854 if (node.output ()[i].empty ()) {
856855 outputTypes.emplace_back (builder_.getNoneType ());
857- } else if (auto onnxModelType =
858- ConvertOnnxType (node.output (i), errorMessage)) {
859- if (auto ec = onnxModelType->getError ()) {
860- return ec;
861- }
862- outputTypes.emplace_back (*onnxModelType.value ());
863856 } else {
857+ auto onnxModelType = ConvertOnnxType (node.output (i), errorMessage);
858+ if (onnxModelType) {
859+ const auto ec = onnxModelType->getError ();
860+ if (!ec) {
861+ outputTypes.emplace_back (*onnxModelType.value ());
862+ continue ;
863+ }
864+ if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
865+ errorMessage += " Failed to get type for '" + node.output (i) + " \n " ;
866+ return ec;
867+ }
868+ llvm::errs () << " Warning: "
869+ << " Failed to get type for '" << node.output (i)
870+ << " ', falling back to onnx-mlir based mapping.\n " ;
871+ }
864872 unsigned int j = i;
865873 // Variadic output is a single ODS result.
866874 if (variadicOut)
@@ -910,8 +918,8 @@ class FrontendGenImpl {
910918 region.push_back (new Block);
911919 OpBuilder::InsertionGuard guard (builder_);
912920 builder_.setInsertionPointToStart (&region.back ());
913- const ErrorOr<FunctionType> importGraphResult = importGraph (attr. g (),
914- region, op, false , options_. allowMissingOutputTypes , errorMessage);
921+ const ErrorOr<FunctionType> importGraphResult =
922+ importGraph (attr. g (), region, op, false , errorMessage);
915923 if (auto ec = importGraphResult.getError ()) {
916924 return ec;
917925 }
@@ -1443,9 +1451,14 @@ class FrontendGenImpl {
14431451 GetOpsetImportsFromProto (functionProto);
14441452
14451453 // Populates graph.value_info().
1446- onnx::shape_inference::InferShapes (&graph, function_opset_map,
1447- onnx::OpSchemaRegistry::Instance (),
1448- /* options=*/ {}, in_model_functions_);
1454+ try {
1455+ onnx::shape_inference::InferShapes (&graph, function_opset_map,
1456+ onnx::OpSchemaRegistry::Instance (),
1457+ /* options=*/ {}, in_model_functions_);
1458+ } catch (const std::exception &e) {
1459+ llvm::errs () << " Warning: Caught exception running onnx shape inference: "
1460+ << e.what () << " \n " ;
1461+ }
14491462
14501463 // Save caller context, while generating function body.
14511464 ModelLocalFunctionsMap callerModelFunctions;
@@ -1629,14 +1642,14 @@ class FrontendGenImpl {
16291642 const onnx::ValueInfoProto &output,
16301643 llvm::SmallVectorImpl<Type> &ret_types,
16311644 llvm::SmallVectorImpl<Value> &ret_vals, std::string &errorMessage,
1632- bool allowMissingType, std::string *dim_params = nullptr ) {
1645+ std::string *dim_params = nullptr ) {
16331646 const Value *valPtr = frontend_symbols_.GetByOnnxName (output.name ());
16341647 Value val = *valPtr;
16351648
16361649 ErrorOr<Type> parsedOutputType =
16371650 ImportType (output.type (), errorMessage, dim_params);
16381651 if (auto ec = parsedOutputType.getError ()) {
1639- if (!allowMissingType || ec != InvalidOnnxFormat) {
1652+ if (!options_. allowMissingOutputTypes || ec != InvalidOnnxFormat) {
16401653 errorMessage +=
16411654 " Failed to import output type for '" + output.name () + " \n " ;
16421655 return ec;
@@ -1716,15 +1729,10 @@ class FrontendGenImpl {
17161729 /* !
17171730 * Import ONNX main computation graph.
17181731 * @param graph onnx graph proto.
1719- * @param allowMissingOutputTypes If true, type inference will be used to
1720- * infer missing output types. This is done by copying the, potential
1721- * inferred, output type of the node connected to the output. According to
1722- * ONNX, all outputs MUST have types. Therefore this option has to be
1723- * considered as a stretch best effort.
17241732 * @return A function corresponding to the imported computation graph.
17251733 */
1726- ErrorOr<func::FuncOp> importGraph (const onnx::GraphProto &graph,
1727- bool allowMissingOutputTypes , std::string &errorMessage) {
1734+ ErrorOr<func::FuncOp> importGraph (
1735+ const onnx::GraphProto &graph , std::string &errorMessage) {
17281736 const std::string &name = " main_graph" ;
17291737 auto mainFunc = func::FuncOp::create (UnknownLoc (), name,
17301738 /* type=*/ builder_.getFunctionType ({}, {}), /* attrs=*/ {});
@@ -1735,8 +1743,7 @@ class FrontendGenImpl {
17351743
17361744 ErrorOr<FunctionType> importedFuncType =
17371745 importGraph (graph, /* region=*/ mainFunc.getBody (),
1738- /* op=*/ mainFunc.getOperation (), /* useReturn=*/ true ,
1739- /* allowMissingOutputTypes=*/ allowMissingOutputTypes, errorMessage);
1746+ /* op=*/ mainFunc.getOperation (), /* useReturn=*/ true , errorMessage);
17401747 if (auto ec = importedFuncType.getError ()) {
17411748 errorMessage +=
17421749 " Failed to import main graph, could not get its function type\n " ;
@@ -1763,7 +1770,7 @@ class FrontendGenImpl {
17631770
17641771[[nodiscard]] std::error_code ImportFrontendModelInternal (
17651772 onnx::ModelProto &model, MLIRContext &context,
1766- OwningOpRef<ModuleOp> &module , ImportOptions options,
1773+ OwningOpRef<ModuleOp> &module , const ImportOptions & options,
17671774 std::string &errorMessage) {
17681775 int originVersion = CURRENT_ONNX_OPSET;
17691776 // Get the version of the model
@@ -1799,21 +1806,35 @@ class FrontendGenImpl {
17991806 originVersion < CURRENT_ONNX_OPSET) {
18001807 onnx::ModelProto convertModel =
18011808 onnx::version_conversion::ConvertVersion (model, CURRENT_ONNX_OPSET);
1802- if (options.useOnnxModelTypes )
1803- onnx::shape_inference::InferShapes (convertModel);
1809+ if (options.useOnnxModelTypes ) {
1810+ try {
1811+ onnx::shape_inference::InferShapes (convertModel);
1812+ } catch (const std::exception &e) {
1813+ llvm::errs ()
1814+ << " Warning: Caught exception running onnx shape inference: "
1815+ << e.what () << " \n " ;
1816+ }
1817+ }
18041818 return ImportFrontendModel (
18051819 convertModel, context, module , errorMessage, options);
18061820 } else {
1807- if (options.useOnnxModelTypes )
1808- onnx::shape_inference::InferShapes (model);
1821+ if (options.useOnnxModelTypes ) {
1822+ try {
1823+ onnx::shape_inference::InferShapes (model);
1824+ } catch (const std::exception &e) {
1825+ llvm::errs ()
1826+ << " Warning: Caught exception running onnx shape inference: "
1827+ << e.what () << " \n " ;
1828+ }
1829+ }
18091830 return ImportFrontendModel (model, context, module , errorMessage, options);
18101831 }
18111832 return CompilerSuccess;
18121833}
18131834
18141835[[nodiscard]] std::error_code ImportFrontendModelArray (const void *onnxBuffer,
18151836 int size, MLIRContext &context, OwningOpRef<ModuleOp> &module ,
1816- std::string &errorMessage, ImportOptions options) {
1837+ std::string &errorMessage, const ImportOptions & options) {
18171838 onnx::ModelProto model;
18181839
18191840 bool parse_success = model.ParseFromArray (onnxBuffer, size);
@@ -1855,7 +1876,7 @@ namespace {
18551876// Return 0 on success, error otherwise.
18561877[[nodiscard]] std::error_code ImportFrontendModelFile (StringRef model_fname,
18571878 MLIRContext &context, OwningOpRef<ModuleOp> &module ,
1858- std::string &errorMessage, ImportOptions options) {
1879+ std::string &errorMessage, const ImportOptions & options) {
18591880 onnx::ModelProto model;
18601881 if (model_fname.ends_with (" .onnxtext" )) {
18611882 std::string text;
@@ -1912,11 +1933,11 @@ namespace {
19121933
19131934[[nodiscard]] std::error_code ImportFrontendModel (const onnx::ModelProto &model,
19141935 MLIRContext &context, OwningOpRef<ModuleOp> &module ,
1915- std::string &errorMessage, ImportOptions options) {
1936+ std::string &errorMessage, const ImportOptions & options) {
19161937
1917- detail::FrontendGenImpl myONNXGen (context);
1938+ detail::FrontendGenImpl myONNXGen (context, options );
19181939 ErrorOr<ModuleOp> importedModule =
1919- myONNXGen.ImportONNXModel (model, options, errorMessage);
1940+ myONNXGen.ImportONNXModel (model, errorMessage);
19201941 if (auto ec = importedModule.getError ()) {
19211942 return ec;
19221943 }
0 commit comments