@@ -848,23 +848,33 @@ class FrontendGenImpl {
848848 // Use the type map or types in input model to determine the data type of
849849 // output.
850850 std::vector<int > outputMap = T::getTypeMap ();
851+ const bool shouldTakeShapeFromModelForCustomOp =
852+ isCustomOp &&
853+ (options_.useOnnxModelTypesForCustomOps || givenOutputTypes.empty ());
851854 for (unsigned int i = 0 ; i < (unsigned int )node.output ().size (); i++) {
852855 // Optional outputs using empty string.
853856 if (node.output ()[i].empty ()) {
854857 outputTypes.emplace_back (builder_.getNoneType ());
855858 } else {
856- if (options_.useOnnxModelTypes ||
857- (isCustomOp && options_.useOnnxModelTypesForCustomOps )) {
859+ if (options_.useOnnxModelTypes || shouldTakeShapeFromModelForCustomOp) {
858860 auto onnxModelType = ConvertOnnxType (node.output (i), errorMessage);
859861 if (onnxModelType) {
860862 const auto ec = onnxModelType->getError ();
861863 if (!ec) {
862- outputTypes.emplace_back (*onnxModelType.value ());
864+ Type outputType = *onnxModelType.value ();
865+ if (!options_.useOnnxModelTypesForCustomOps &&
866+ !options_.useOnnxModelTypes ) {
867+ if (auto shapedType = mlir::dyn_cast<ShapedType>(outputType)) {
868+ Type elementType = shapedType.getElementType ();
869+ outputType = UnrankedTensorType::get (elementType);
870+ }
871+ }
872+ outputTypes.emplace_back (outputType);
863873 continue ;
864874 }
865875 if (!options_.allowMissingOutputTypes || ec != InvalidOnnxFormat) {
866876 errorMessage +=
867- " Failed to get type for '" + node.output (i) + " \n " ;
877+ " Failed to get type for '" + node.output (i) + " ' \n " ;
868878 return ec;
869879 }
870880 llvm::errs () << " Warning: "
@@ -874,13 +884,19 @@ class FrontendGenImpl {
874884 }
875885 unsigned int j = i;
876886 // Variadic output is a single ODS result.
877- if (variadicOut)
887+ if (variadicOut) {
878888 j = 0 ;
889+ }
879890 if (!givenOutputTypes.empty ()) {
891+ assert (givenOutputTypes.size () > i &&
892+ " givenOutputTypes size is less than number of outputs" );
880893 outputTypes.emplace_back (
881894 UnrankedTensorType::get (givenOutputTypes[i]));
882895 } else if (j < outputMap.size () && outputMap[j] >= MAX_NUM_TYPES) {
883896 // Mapping gives a connection with an input.
897+ assert (
898+ outputMap[j] - MAX_NUM_TYPES < static_cast <int >(inputs.size ()) &&
899+ " output type mapping to input is out of range" );
884900 Type inputType = inputs[outputMap[j] - MAX_NUM_TYPES].getType ();
885901 if (mlir::isa<TensorType>(inputType)) {
886902 Type elementType =
@@ -1570,10 +1586,9 @@ class FrontendGenImpl {
15701586 auto domainAttr = builder_.getNamedAttr (
15711587 " domain_name" , builder_.getStringAttr (node.domain ()));
15721588 attributes.push_back (domainAttr);
1573- int nIn = 0 ;
1574- int nOut = 0 ;
1589+ const int nIn = ONNXCustomOp::getNumberOfOperands () ;
1590+ const int nOut = ONNXCustomOp::getNumberOfResults () ;
15751591 getNodeInputs (node, inputs);
1576- nOut = node.output ().size ();
15771592 std::vector<Type> givenOutputTypes;
15781593
15791594 // We lack a way of specifying import behavior for custom domains. For now
0 commit comments