@@ -848,23 +848,33 @@ class FrontendGenImpl {
848848 // Use the type map or types in input model to determine the data type of
849849 // output.
850850 std::vector<int > outputMap = T::getTypeMap ();
851+ const bool shouldTakeShapeFromModelForCustomOp =
852+ isCustomOp &&
853+ (options_.useOnnxModelTypesForCustomOps || givenOutputTypes.empty ());
851854 for (unsigned int i = 0 ; i < (unsigned int )node.output ().size (); i++) {
852855 // Optional outputs using empty string.
853856 if (node.output ()[i].empty ()) {
854857 outputTypes.emplace_back (builder_.getNoneType ());
855858 } else {
856- if (options_.useOnnxModelTypes ||
857- (isCustomOp && options_.useOnnxModelTypesForCustomOps )) {
859+ if (options_.useOnnxModelTypes || shouldTakeShapeFromModelForCustomOp) {
858860 auto onnxModelType = ConvertOnnxType (node.output (i), errorMessage);
859861 if (onnxModelType) {
860862 const auto ec = onnxModelType->getError ();
861863 if (!ec) {
862- outputTypes.emplace_back (*onnxModelType.value ());
864+ Type outputType = *onnxModelType.value ();
865+ if (!options_.useOnnxModelTypesForCustomOps &&
866+ !options_.useOnnxModelTypes ) {
867+ if (auto shapedType = mlir::dyn_cast<ShapedType>(outputType)) {
868+ Type elementType = shapedType.getElementType ();
869+ outputType = UnrankedTensorType::get (elementType);
870+ }
871+ }
872+ outputTypes.emplace_back (outputType);
863873 continue ;
864874 }
865875 if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
866876 errorMessage +=
867- " Failed to get type for '" + node.output (i) + " \n " ;
877+ " Failed to get type for '" + node.output (i) + " ' \n " ;
868878 return ec;
869879 }
870880 llvm::errs () << " Warning: "
@@ -878,10 +888,15 @@ class FrontendGenImpl {
878888 j = 0 ;
879889 }
880890 if (!givenOutputTypes.empty ()) {
891+ assert (givenOutputTypes.size () > i &&
892+ " givenOutputTypes size is less than number of outputs" );
881893 outputTypes.emplace_back (
882894 UnrankedTensorType::get (givenOutputTypes[i]));
883895 } else if (j < outputMap.size () && outputMap[j] >= MAX_NUM_TYPES) {
884896 // Mapping gives a connection with an input.
897+ assert (
898+ outputMap[j] - MAX_NUM_TYPES < static_cast <int >(inputs.size ()) &&
899+ " output type mapping to input is out of range" );
885900 Type inputType = inputs[outputMap[j] - MAX_NUM_TYPES].getType ();
886901 if (mlir::isa<TensorType>(inputType)) {
887902 Type elementType =
0 commit comments