@@ -5534,10 +5534,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
55345534
55355535// / Returns true if each element of 'a' is equal to the product of a contiguous
55365536// / sequence of the elements of 'b'. Returns false otherwise.
5537- static bool isValidShapeCast (ArrayRef<int64_t > a, ArrayRef<int64_t > b) {
5537+ static bool isValidExpandingShapeCast (ArrayRef<int64_t > a, ArrayRef<int64_t > b) {
55385538 unsigned rankA = a.size ();
55395539 unsigned rankB = b.size ();
5540- assert (rankA < rankB);
5540+ assert (rankA <= rankB);
55415541
55425542 auto isOne = [](int64_t v) { return v == 1 ; };
55435543
@@ -5573,34 +5573,36 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
55735573 VectorType resultVectorType) {
55745574 // Check that element type is the same.
55755575 if (sourceVectorType.getElementType () != resultVectorType.getElementType ())
5576- return op->emitOpError (" source/result vectors must have same element type" );
5577- auto sourceShape = sourceVectorType.getShape ();
5578- auto resultShape = resultVectorType.getShape ();
5576+ return op->emitOpError (" has different source and result element types" );
5577+ ArrayRef<int64_t > lowRankShape = sourceVectorType.getShape ();
5578+ ArrayRef<int64_t > highRankShape = resultVectorType.getShape ();
5579+ if (lowRankShape.size () > highRankShape.size ())
5580+ std::swap (lowRankShape, highRankShape);
55795581
55805582 // Check that product of source dim sizes matches product of result dim sizes.
5581- int64_t sourceDimProduct = std::accumulate (
5582- sourceShape.begin (), sourceShape.end (), 1LL , std::multiplies<int64_t >{});
5583- int64_t resultDimProduct = std::accumulate (
5584- resultShape.begin (), resultShape.end (), 1LL , std::multiplies<int64_t >{});
5585- if (sourceDimProduct != resultDimProduct)
5586- return op->emitOpError (" source/result number of elements must match" );
5587-
5588- // Check that expanding/contracting rank cases.
5589- unsigned sourceRank = sourceVectorType.getRank ();
5590- unsigned resultRank = resultVectorType.getRank ();
5591- if (sourceRank < resultRank) {
5592- if (!isValidShapeCast (sourceShape, resultShape))
5593- return op->emitOpError (" invalid shape cast" );
5594- } else if (sourceRank > resultRank) {
5595- if (!isValidShapeCast (resultShape, sourceShape))
5596- return op->emitOpError (" invalid shape cast" );
5583+ int64_t nLowRankElms =
5584+ std::accumulate (lowRankShape.begin (), lowRankShape.end (), 1LL ,
5585+ std::multiplies<int64_t >{});
5586+ int64_t nHighRankElms =
5587+ std::accumulate (highRankShape.begin (), highRankShape.end (), 1LL ,
5588+ std::multiplies<int64_t >{});
5589+
5590+ if (nLowRankElms != nHighRankElms) {
5591+ return op->emitOpError (
5592+ " has a different number of source and result elements" );
5593+ }
5594+
5595+ if (!isValidExpandingShapeCast (lowRankShape, highRankShape)) {
5596+ return op->emitOpError (
5597+ " is invalid (does not uniformly collapse or expand)" );
55975598 }
55985599
55995600 // Check that (non-)scalability is preserved
56005601 int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims ();
56015602 int64_t resultNScalableDims = resultVectorType.getNumScalableDims ();
56025603 if (sourceNScalableDims != resultNScalableDims)
5603- return op->emitOpError (" different number of scalable dims at source (" )
5604+ return op->emitOpError (
5605+ " has a different number of scalable dims at source (" )
56045606 << sourceNScalableDims << " ) and result (" << resultNScalableDims
56055607 << " )" ;
56065608 sourceVectorType.getNumDynamicDims ();
@@ -5634,17 +5636,18 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56345636
56355637 // Only allows valid transitive folding (expand/collapse dimensions).
56365638 VectorType srcType = otherOp.getSource ().getType ();
5639+
56375640 if (resultType == srcType)
56385641 return otherOp.getSource ();
5639- if (srcType. getRank () < resultType. getRank ()) {
5640- if (! isValidShapeCast ( srcType.getShape (), resultType. getShape ()))
5641- return {} ;
5642- } else if (srcType. getRank () > resultType. getRank ()) {
5643- if (! isValidShapeCast (resultType. getShape (), srcType. getShape ()))
5644- return {};
5645- } else {
5642+
5643+ ArrayRef< int64_t > lowRankShape = srcType.getShape ();
5644+ ArrayRef< int64_t > highRankShape = resultType. getShape () ;
5645+ if (lowRankShape. size () > highRankShape. size ())
5646+ std::swap (lowRankShape, highRankShape);
5647+
5648+ if (! isValidExpandingShapeCast (lowRankShape, highRankShape))
56465649 return {};
5647- }
5650+
56485651 setOperand (otherOp.getSource ());
56495652 return getResult ();
56505653 }
0 commit comments