@@ -414,15 +414,13 @@ class FrontendGenImpl {
414414
415415 std::optional<ErrorOr<Type>> ConvertOnnxType (
416416 const std::string &onnx_name, std::string &errorMessage) {
417- if (options_.useOnnxModelTypes ) {
418- if (const onnx::TypeProto *onnxTypePtr =
419- onnx_type_map.GetByOnnxName (onnx_name)) {
420- ErrorOr<Type> importedType = ImportType (*onnxTypePtr, errorMessage);
421- if (auto ec = importedType.getError ()) {
422- return ec;
423- }
424- return *importedType;
417+ if (const onnx::TypeProto *onnxTypePtr =
418+ onnx_type_map.GetByOnnxName (onnx_name)) {
419+ ErrorOr<Type> importedType = ImportType (*onnxTypePtr, errorMessage);
420+ if (auto ec = importedType.getError ()) {
421+ return ec;
425422 }
423+ return *importedType;
426424 }
427425 return std::nullopt ;
428426 }
@@ -827,7 +825,8 @@ class FrontendGenImpl {
827825 const onnx::NodeProto &node, std::vector<Value> inputs,
828826 int expectedNumOperands, int expectedNumResults,
829827 const std::vector<NamedAttribute> &attributes, std::string &errorMessage,
830- std::vector<Type> givenOutputTypes = std::vector<Type>()) {
828+ std::vector<Type> givenOutputTypes = std::vector<Type>(),
829+ bool isCustomOp = false) {
831830 bool variadicIn = expectedNumOperands == -1 ;
832831 bool variadicOut = expectedNumResults == -1 ;
833832
@@ -854,20 +853,24 @@ class FrontendGenImpl {
854853 if (node.output ()[i].empty ()) {
855854 outputTypes.emplace_back (builder_.getNoneType ());
856855 } else {
857- auto onnxModelType = ConvertOnnxType (node.output (i), errorMessage);
858- if (onnxModelType) {
859- const auto ec = onnxModelType->getError ();
860- if (!ec) {
861- outputTypes.emplace_back (*onnxModelType.value ());
862- continue ;
863- }
864- if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
865- errorMessage += " Failed to get type for '" + node.output (i) + " \n " ;
866- return ec;
856+ if (options_.useOnnxModelTypes ||
857+ (isCustomOp && options_.useOnnxModelTypesForCustomOps )) {
858+ auto onnxModelType = ConvertOnnxType (node.output (i), errorMessage);
859+ if (onnxModelType) {
860+ const auto ec = onnxModelType->getError ();
861+ if (!ec) {
862+ outputTypes.emplace_back (*onnxModelType.value ());
863+ continue ;
864+ }
865+ if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
866+ errorMessage +=
867+ " Failed to get type for '" + node.output (i) + " \n " ;
868+ return ec;
869+ }
870+ llvm::errs () << " Warning: "
871+ << " Failed to get type type for '" << node.output (i)
872+ << " ', falling back to onnx-mlir based mapping.\n " ;
867873 }
868- llvm::errs () << " Warning: "
869- << " Failed to get type type for '" << node.output (i)
870- << " ', falling back to onnx-mlir based mapping.\n " ;
871874 }
872875 unsigned int j = i;
873876 // Variadic output is a single ODS result.
@@ -931,6 +934,8 @@ class FrontendGenImpl {
931934 }
932935 }
933936 }
937+ // Note: ResultTypeInferenceOpInterface only infers the type of the result,
938+ // not the shape
934939 if (auto opWithTypeInference =
935940 mlir::dyn_cast<ResultTypeInferenceOpInterface>(op.getOperation ())) {
936941 auto outTypes = opWithTypeInference.resultTypeInference ();
@@ -1456,7 +1461,8 @@ class FrontendGenImpl {
14561461 onnx::OpSchemaRegistry::Instance (),
14571462 /* options=*/ {}, in_model_functions_);
14581463 } catch (const std::exception &e) {
1459- llvm::errs () << " Warning: Caught exception running onnx shape inference: "
1464+ llvm::errs () << " Warning: Caught exception running onnx shape inference "
1465+ " to populate graph.value_info: "
14601466 << e.what () << " \n " ;
14611467 }
14621468
@@ -1584,8 +1590,8 @@ class FrontendGenImpl {
15841590
15851591 // ToFix: The type inference may go wrong if the element type of the output
15861592 // of CustomOp is not the same as the first input.
1587- return buildOutputAndOperation<ONNXCustomOp>(
1588- node, inputs, nIn, nOut, attributes, errorMessage, givenOutputTypes);
1593+ return buildOutputAndOperation<ONNXCustomOp>(node, inputs, nIn, nOut,
1594+ attributes, errorMessage, givenOutputTypes, /* isCustomOp */ true );
15891595 }
15901596
15911597 [[nodiscard]] std::error_code ImportNode (
@@ -1806,7 +1812,7 @@ class FrontendGenImpl {
18061812 originVersion < CURRENT_ONNX_OPSET) {
18071813 onnx::ModelProto convertModel =
18081814 onnx::version_conversion::ConvertVersion (model, CURRENT_ONNX_OPSET);
1809- if (options.useOnnxModelTypes ) {
1815+ if (options.runOnnxShapeInference ) {
18101816 try {
18111817 onnx::shape_inference::InferShapes (convertModel);
18121818 } catch (const std::exception &e) {
@@ -1818,7 +1824,7 @@ class FrontendGenImpl {
18181824 return ImportFrontendModel (
18191825 convertModel, context, module , errorMessage, options);
18201826 } else {
1821- if (options.useOnnxModelTypes ) {
1827+ if (options.runOnnxShapeInference ) {
18221828 try {
18231829 onnx::shape_inference::InferShapes (model);
18241830 } catch (const std::exception &e) {
0 commit comments