@@ -5505,48 +5505,18 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
55055505 setResultRanges (getResult (), argRanges.front ());
55065506}
55075507
5508- // / Returns true if each element of 'a' is equal to the product of a contiguous
5509- // / sequence of the elements of 'b'. Returns false otherwise.
5510- static bool isValidShapeCast (ArrayRef<int64_t > a, ArrayRef<int64_t > b) {
5511- unsigned rankA = a.size ();
5512- unsigned rankB = b.size ();
5513- assert (rankA < rankB);
5514-
5515- auto isOne = [](int64_t v) { return v == 1 ; };
5516-
5517- // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5518- // casted to a 0-d vector.
5519- if (rankA == 0 && llvm::all_of (b, isOne))
5520- return true ;
5521-
5522- unsigned i = 0 ;
5523- unsigned j = 0 ;
5524- while (i < rankA && j < rankB) {
5525- int64_t dimA = a[i];
5526- int64_t dimB = 1 ;
5527- while (dimB < dimA && j < rankB)
5528- dimB *= b[j++];
5529- if (dimA != dimB)
5530- break ;
5531- ++i;
5532-
5533- // Handle the case when trailing dimensions are of size 1.
5534- // Include them into the contiguous sequence.
5535- if (i < rankA && llvm::all_of (a.slice (i), isOne))
5536- i = rankA;
5537- if (j < rankB && llvm::all_of (b.slice (j), isOne))
5538- j = rankB;
5539- }
5508+ LogicalResult ShapeCastOp::verify () {
5509+ auto sourceVectorType =
5510+ llvm::dyn_cast_or_null<VectorType>(getSource ().getType ());
5511+ auto resultVectorType =
5512+ llvm::dyn_cast_or_null<VectorType>(getResult ().getType ());
55405513
5541- return i == rankA && j == rankB ;
5542- }
5514+ if (!sourceVectorType) return failure () ;
5515+ if (!resultVectorType) return failure ();
55435516
5544- static LogicalResult verifyVectorShapeCast (Operation *op,
5545- VectorType sourceVectorType,
5546- VectorType resultVectorType) {
55475517 // Check that element type is the same.
55485518 if (sourceVectorType.getElementType () != resultVectorType.getElementType ())
5549- return op-> emitOpError (" source/result vectors must have same element type" );
5519+ return emitOpError (" source/result vectors must have same element type" );
55505520 auto sourceShape = sourceVectorType.getShape ();
55515521 auto resultShape = resultVectorType.getShape ();
55525522
@@ -5556,44 +5526,20 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
55565526 int64_t resultDimProduct = std::accumulate (
55575527 resultShape.begin (), resultShape.end (), 1LL , std::multiplies<int64_t >{});
55585528 if (sourceDimProduct != resultDimProduct)
5559- return op->emitOpError (" source/result number of elements must match" );
5560-
5561- // Check that expanding/contracting rank cases.
5562- unsigned sourceRank = sourceVectorType.getRank ();
5563- unsigned resultRank = resultVectorType.getRank ();
5564- if (sourceRank < resultRank) {
5565- if (!isValidShapeCast (sourceShape, resultShape))
5566- return op->emitOpError (" invalid shape cast" );
5567- } else if (sourceRank > resultRank) {
5568- if (!isValidShapeCast (resultShape, sourceShape))
5569- return op->emitOpError (" invalid shape cast" );
5570- }
5529+ return emitOpError (" source/result number of elements must match" );
55715530
55725531 // Check that (non-)scalability is preserved
55735532 int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims ();
55745533 int64_t resultNScalableDims = resultVectorType.getNumScalableDims ();
55755534 if (sourceNScalableDims != resultNScalableDims)
5576- return op-> emitOpError (" different number of scalable dims at source (" )
5535+ return emitOpError (" different number of scalable dims at source (" )
55775536 << sourceNScalableDims << " ) and result (" << resultNScalableDims
55785537 << " )" ;
55795538 sourceVectorType.getNumDynamicDims ();
55805539
55815540 return success ();
55825541}
55835542
5584- LogicalResult ShapeCastOp::verify () {
5585- auto sourceVectorType =
5586- llvm::dyn_cast_or_null<VectorType>(getSource ().getType ());
5587- auto resultVectorType =
5588- llvm::dyn_cast_or_null<VectorType>(getResult ().getType ());
5589-
5590- // Check if source/result are of vector type.
5591- if (sourceVectorType && resultVectorType)
5592- return verifyVectorShapeCast (*this , sourceVectorType, resultVectorType);
5593-
5594- return success ();
5595- }
5596-
55975543OpFoldResult ShapeCastOp::fold (FoldAdaptor adaptor) {
55985544
55995545 // No-op shape cast.
@@ -5609,15 +5555,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56095555 VectorType srcType = otherOp.getSource ().getType ();
56105556 if (resultType == srcType)
56115557 return otherOp.getSource ();
5612- if (srcType.getRank () < resultType.getRank ()) {
5613- if (!isValidShapeCast (srcType.getShape (), resultType.getShape ()))
5614- return {};
5615- } else if (srcType.getRank () > resultType.getRank ()) {
5616- if (!isValidShapeCast (resultType.getShape (), srcType.getShape ()))
5617- return {};
5618- } else {
5619- return {};
5620- }
56215558 setOperand (otherOp.getSource ());
56225559 return getResult ();
56235560 }
0 commit comments