@@ -890,7 +890,7 @@ LogicalResult tosa::SliceOp::verify() {
890890 if (!inputType || !outputType)
891891 return success ();
892892
893- if (inputType.getRank () != outputType.getRank ()) {
893+ if (inputType.getRank () != outputType.getRank ()) {
894894 return emitOpError () << " rank of input (" << inputType.getRank ()
895895 << " ) and output (" << outputType.getRank ()
896896 << " ) must match" ;
@@ -1087,34 +1087,35 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
10871087 return emitOpError () << " cannot reshape " << inputElementsNum
10881088 << " elements into " << outputElementsNum;
10891089 }
1090+
1091+ if ((int64_t )getNewShape ().size () != outputType.getRank ()) {
1092+ return emitOpError ()
1093+ << " rank of newShape (" << getNewShape ().size ()
1094+ << " ) and output (" << outputType.getRank () << " ) must match" ;
1095+ }
1096+
1097+ for (int64_t dim = 0 ; dim < outputType.getRank (); ++dim) {
1098+ if (getNewShape ()[dim] != -1 &&
1099+ getNewShape ()[dim] != outputType.getShape ()[dim]) {
1100+ return emitOpError ()
1101+ << " newShape attribute (" << getNewShape ()[dim]
1102+ << " ) does not match output type ("
1103+ << outputType.getShape ()[dim] << " ) in dimension " << dim;
1104+ }
1105+ }
10901106 }
10911107
1108+ // AMD: Switched checks with > to >= to allow zero dimensions
10921109 int64_t newShapeElementsNum = std::accumulate (
10931110 getNewShape ().begin (), getNewShape ().end (), 1LL ,
1094- [](int64_t acc, int64_t dim) { return (dim > 0 ) ? acc * dim : acc; });
1111+ [](int64_t acc, int64_t dim) { return (dim >= 0 ) ? acc * dim : acc; });
10951112 bool isStaticNewShape =
1096- llvm::all_of (getNewShape (), [](int64_t s) { return s > 0 ; });
1113+ llvm::all_of (getNewShape (), [](int64_t s) { return s >= 0 ; });
10971114 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
10981115 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
10991116 return emitOpError () << " cannot reshape " << inputElementsNum
11001117 << " elements into " << newShapeElementsNum;
11011118 }
1102-
1103- if ((int64_t )getNewShape ().size () != outputType.getRank ()) {
1104- return emitOpError () << " rank of newShape (" << getNewShape ().size ()
1105- << " ) and output (" << outputType.getRank ()
1106- << " ) must match" ;
1107- }
1108-
1109- for (int64_t dim = 0 ; dim < outputType.getRank (); ++dim) {
1110- if (getNewShape ()[dim] != -1 &&
1111- getNewShape ()[dim] != outputType.getShape ()[dim]) {
1112- return emitOpError ()
1113- << " newShape attribute (" << getNewShape ()[dim]
1114- << " ) does not match output type (" << outputType.getShape ()[dim]
1115- << " ) in dimension " << dim;
1116- }
1117- }
11181119 }
11191120
11201121 int missingDims = llvm::count (getNewShape (), -1 );
0 commit comments