@@ -85,16 +85,21 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
8585template <typename ReshapeOpTy, typename InverseReshapeOpTy>
8686static OpFoldResult foldReshapeOp (ReshapeOpTy reshapeOp,
8787 ArrayRef<Attribute> operands) {
88- // Fold producer-consumer reshape ops that where the operand type of the
88+
89+ if (reshapeOp.getSrcType () == reshapeOp.getType ())
90+ return reshapeOp.getSrc ();
91+
92+ // Fold producer-consumer reshape ops where the operand type of the
8993 // producer is same as the return type of the consumer.
9094 auto reshapeSrcOp =
9195 reshapeOp.getSrc ().template getDefiningOp <InverseReshapeOpTy>();
9296 if (reshapeSrcOp && reshapeSrcOp.getSrcType () == reshapeOp.getResultType ())
9397 return reshapeSrcOp.getSrc ();
98+
9499 // Reshape of a constant can be replaced with a new constant.
95- if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front ())) {
100+ if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front ()))
96101 return elements.reshape (cast<ShapedType>(reshapeOp.getResult ().getType ()));
97- }
102+
98103 return nullptr ;
99104}
100105
@@ -103,41 +108,36 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
103108template <typename Op, typename T>
104109static LogicalResult verifyReshapeLikeTypes (Op op, T expandedType,
105110 T collapsedType, bool isExpansion) {
111+
106112 unsigned expandedRank = expandedType.getRank ();
107113 unsigned collapsedRank = collapsedType.getRank ();
108114 if (expandedRank < collapsedRank)
109- return op.emitOpError (" expected the type " )
110- << expandedType
111- << " to have higher rank than the type = " << collapsedType;
112- if (expandedRank == 0 )
113- return op.emitOpError (" expected non-zero memref ranks" );
114- if (expandedRank == collapsedRank)
115- return op.emitOpError (" expected to collapse or expand dims" );
116-
117- if (collapsedRank == 0 ) {
118- // If collapsed rank is 0, then expanded type must be static shaped and of
119- // sizes 1.
120- if (llvm::any_of (expandedType.getShape (),
121- [](int64_t dim) -> bool { return dim != 1 ; }))
122- return op.emitOpError (" invalid to reshape tensor/memref with non-unit "
123- " extent dimensions to zero-rank tensor/memref" );
124- return success ();
125- }
115+ return op.emitOpError (" expected the expanded type, " )
116+ << expandedType << " to have a higher (or same) rank "
117+ << " than the collapsed type, " << collapsedType << ' .' ;
118+
126119 if (collapsedRank != op.getReassociation ().size ())
127- return op.emitOpError (" expected rank of the collapsed type(" )
128- << collapsedRank << " ) to be the number of reassociation maps("
129- << op.getReassociation ().size () << " )" ;
120+ return op.emitOpError (" expected collapsed rank (" )
121+ << collapsedRank << " ) to equal the number of reassociation maps ("
122+ << op.getReassociation ().size () << " )." ;
123+
130124 auto maps = op.getReassociationMaps ();
131125 for (auto it : llvm::enumerate (maps))
132126 if (it.value ().getNumDims () != expandedRank)
133127 return op.emitOpError (" expected reassociation map #" )
134- << it.index () << " of same rank as expanded memref("
135- << expandedRank << " ), but got " << it.value ().getNumDims ();
128+ << it.index () << " to have size equal to the expanded rank ("
129+ << expandedRank << " ), but it is " << it.value ().getNumDims ()
130+ << ' .' ;
131+
136132 int invalidIdx = 0 ;
137133 if (!isReassociationValid (maps, &invalidIdx))
138134 return op.emitOpError (" expected reassociation map #" )
139- << invalidIdx << " to be valid and contiguous" ;
140- return verifyReshapeLikeShapes (op, collapsedType, expandedType, isExpansion);
135+ << invalidIdx << " to be valid and contiguous." ;
136+
137+ return reshapeLikeShapesAreCompatible (
138+ [&](const Twine &msg) { return op->emitOpError (msg); },
139+ collapsedType.getShape (), expandedType.getShape (),
140+ op.getReassociationIndices (), isExpansion);
141141}
142142
143143// / Verify that shapes of the reshaped types using following rules
@@ -153,16 +153,6 @@ LogicalResult reshapeLikeShapesAreCompatible(
153153 ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
154154 ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
155155
156- template <typename OpTy>
157- static LogicalResult verifyReshapeLikeShapes (OpTy op, ShapedType collapsedType,
158- ShapedType expandedType,
159- bool isExpandingReshape) {
160- return reshapeLikeShapesAreCompatible (
161- [&](const Twine &msg) { return op->emitOpError (msg); },
162- collapsedType.getShape (), expandedType.getShape (),
163- op.getReassociationIndices (), isExpandingReshape);
164- }
165-
166156// / Returns true iff the type is a MemRefType and has a non-identity layout.
167157bool hasNonIdentityLayout (Type type);
168158
0 commit comments