@@ -210,6 +210,27 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
210210 }
211211}
212212
213+ // ===----------------------------------------------------------------------===//
214+ // TOSA shape inference helper
215+ // ===----------------------------------------------------------------------===//
216+ bool mlir::tosa::collectShapeValue (Operation *op,
217+ llvm::SmallVector<int64_t > &newShape) {
218+ if (!op) {
219+ return false ;
220+ }
221+ if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
222+ Attribute constOpAttr = constOp->getAttr (" value" );
223+ DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
224+ for (int i = 0 ; i < elementsAttr.size (); i++) {
225+ int64_t val = elementsAttr.getValues <int64_t >()[i];
226+ newShape.push_back (val);
227+ }
228+ return true ;
229+ }
230+ // for undefined op, return false.
231+ return false ;
232+ }
233+
213234// ===----------------------------------------------------------------------===//
214235// TOSA Operator Verifiers.
215236// ===----------------------------------------------------------------------===//
@@ -823,51 +844,42 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
823844 PadOp::Adaptor adaptor,
824845 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
825846 ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
826- ShapeAdaptor paddingShape (adaptor.getPadding ().getType ());
847+ auto paddingRank =
848+ cast<tosa::shapeType>(adaptor.getPadding ().getType ()).getRank ();
827849 SmallVector<int64_t > outputShape;
828850
829- // If both inputs have unknown shape, we cannot determine the shape of the
830- // output.
831- if (!inputShape.hasRank () && !paddingShape.hasRank ()) {
832- inferredReturnShapes.push_back (ShapedTypeComponents ());
833- return success ();
834- }
835-
836- // If the input rank is unknown we can info the output rank using the
837- // padding shape's first dim.
851+ // If the input rank is unknown, we can infer the output rank using the
852+ // padding shape's rank divided by 2.
838853 if (!inputShape.hasRank ()) {
839- if (paddingShape.isDynamicDim (0 )) {
840- inferredReturnShapes.push_back (ShapedTypeComponents ());
841- return success ();
842- }
843-
844- outputShape.resize (paddingShape.getDimSize (0 ) / 2 , ShapedType::kDynamic );
854+ outputShape.resize (paddingRank / 2 , ShapedType::kDynamic );
845855 inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
846856 return success ();
847857 }
848858
849- DenseIntElementsAttr paddings ;
859+ SmallVector< int64_t > paddingValues ;
850860 // If the paddings value is not a constant, all dimensions must be dynamic.
851- if (!matchPattern (adaptor.getPadding (), m_Constant (&paddings))) {
861+ if (!tosa::collectShapeValue (adaptor.getPadding ().getDefiningOp (),
862+ paddingValues)) {
852863 outputShape.resize (inputShape.getRank (), ShapedType::kDynamic );
853864 inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
854865 return success ();
855866 }
856867
857- SmallVector<int64_t > paddingValues;
858- for (auto val : paddings) {
859- paddingValues.push_back (val.getSExtValue ());
860- }
861-
862868 outputShape.reserve (inputShape.getRank ());
863869 for (int i = 0 , s = inputShape.getRank (); i < s; i++) {
864870 if (inputShape.isDynamicDim (i)) {
865871 outputShape.push_back (ShapedType::kDynamic );
866872 continue ;
867873 }
874+ auto padFront = paddingValues[i * 2 ];
875+ auto padBack = paddingValues[i * 2 + 1 ];
876+ if (padFront < 0 || padBack < 0 ) {
877+ // if either padding for dim i is -1, output dim is unknown
878+ outputShape.push_back (ShapedType::kDynamic );
879+ continue ;
880+ }
868881
869- outputShape.push_back (inputShape.getDimSize (i) + paddingValues[i * 2 ] +
870- paddingValues[i * 2 + 1 ]);
882+ outputShape.push_back (inputShape.getDimSize (i) + padFront + padBack);
871883 }
872884
873885 inferredReturnShapes.push_back (ShapedTypeComponents (outputShape));
@@ -877,17 +889,15 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
877889LogicalResult tosa::PadOp::verify () {
878890 RankedTensorType inputType = getInput1 ().getType ();
879891 RankedTensorType outputType = getOutput ().getType ();
880- RankedTensorType paddingType = getPadding ().getType ();
892+ auto paddingRank = cast<tosa::shapeType>( getPadding ().getType ()). getRank ();
881893
882894 if (inputType.getRank () != outputType.getRank ())
883895 return emitOpError () << " expect same input and output tensor rank." ;
884896
885- if (!paddingType.isDynamicDim (0 ) &&
886- paddingType.getDimSize (0 ) != inputType.getRank () * 2 )
897+ if (paddingRank != inputType.getRank () * 2 )
887898 return emitOpError () << " expected padding tensor dim 0 to have size "
888899 << inputType.getRank () * 2
889- << " (2*rank(shape1)) but got size "
890- << paddingType.getDimSize (0 );
900+ << " (2*rank(shape1)) but got size " << paddingRank;
891901
892902 return success ();
893903}
0 commit comments