@@ -5546,124 +5546,56 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
55465546 setResultRanges (getResult (), argRanges.front ());
55475547}
55485548
5549- // / Returns true if each element of 'a' is equal to the product of a contiguous
5550- // / sequence of the elements of 'b'. Returns false otherwise.
5551- static bool isValidShapeCast (ArrayRef<int64_t > a, ArrayRef<int64_t > b) {
5552- unsigned rankA = a.size ();
5553- unsigned rankB = b.size ();
5554- assert (rankA < rankB);
5555-
5556- auto isOne = [](int64_t v) { return v == 1 ; };
5557-
5558- // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5559- // casted to a 0-d vector.
5560- if (rankA == 0 && llvm::all_of (b, isOne))
5561- return true ;
5549+ LogicalResult ShapeCastOp::verify () {
55625550
5563- unsigned i = 0 ;
5564- unsigned j = 0 ;
5565- while (i < rankA && j < rankB) {
5566- int64_t dimA = a[i];
5567- int64_t dimB = 1 ;
5568- while (dimB < dimA && j < rankB)
5569- dimB *= b[j++];
5570- if (dimA != dimB)
5571- break ;
5572- ++i;
5551+ VectorType sourceType = getSourceVectorType ();
5552+ VectorType resultType = getResultVectorType ();
55735553
5574- // Handle the case when trailing dimensions are of size 1.
5575- // Include them into the contiguous sequence.
5576- if (i < rankA && llvm::all_of (a.slice (i), isOne))
5577- i = rankA;
5578- if (j < rankB && llvm::all_of (b.slice (j), isOne))
5579- j = rankB;
5580- }
5554+ // Check that element type is preserved
5555+ if (sourceType.getElementType () != resultType.getElementType ())
5556+ return emitOpError (" has different source and result element types" );
55815557
5582- return i == rankA && j == rankB;
5583- }
5584-
5585- static LogicalResult verifyVectorShapeCast (Operation *op,
5586- VectorType sourceVectorType,
5587- VectorType resultVectorType) {
5588- // Check that element type is the same.
5589- if (sourceVectorType.getElementType () != resultVectorType.getElementType ())
5590- return op->emitOpError (" source/result vectors must have same element type" );
5591- auto sourceShape = sourceVectorType.getShape ();
5592- auto resultShape = resultVectorType.getShape ();
5593-
5594- // Check that product of source dim sizes matches product of result dim sizes.
5595- int64_t sourceDimProduct = std::accumulate (
5596- sourceShape.begin (), sourceShape.end (), 1LL , std::multiplies<int64_t >{});
5597- int64_t resultDimProduct = std::accumulate (
5598- resultShape.begin (), resultShape.end (), 1LL , std::multiplies<int64_t >{});
5599- if (sourceDimProduct != resultDimProduct)
5600- return op->emitOpError (" source/result number of elements must match" );
5601-
5602- // Check that expanding/contracting rank cases.
5603- unsigned sourceRank = sourceVectorType.getRank ();
5604- unsigned resultRank = resultVectorType.getRank ();
5605- if (sourceRank < resultRank) {
5606- if (!isValidShapeCast (sourceShape, resultShape))
5607- return op->emitOpError (" invalid shape cast" );
5608- } else if (sourceRank > resultRank) {
5609- if (!isValidShapeCast (resultShape, sourceShape))
5610- return op->emitOpError (" invalid shape cast" );
5558+ // Check that number of elements is preserved
5559+ int64_t sourceNElms = sourceType.getNumElements ();
5560+ int64_t resultNElms = resultType.getNumElements ();
5561+ if (sourceNElms != resultNElms) {
5562+ return emitOpError () << " has different number of elements at source ("
5563+ << sourceNElms << " ) and result (" << resultNElms
5564+ << " )" ;
56115565 }
56125566
56135567 // Check that (non-)scalability is preserved
5614- int64_t sourceNScalableDims = sourceVectorType .getNumScalableDims ();
5615- int64_t resultNScalableDims = resultVectorType .getNumScalableDims ();
5568+ int64_t sourceNScalableDims = sourceType .getNumScalableDims ();
5569+ int64_t resultNScalableDims = resultType .getNumScalableDims ();
56165570 if (sourceNScalableDims != resultNScalableDims)
5617- return op->emitOpError (" different number of scalable dims at source (" )
5618- << sourceNScalableDims << " ) and result (" << resultNScalableDims
5619- << " )" ;
5620- sourceVectorType.getNumDynamicDims ();
5621-
5622- return success ();
5623- }
5624-
5625- LogicalResult ShapeCastOp::verify () {
5626- auto sourceVectorType =
5627- llvm::dyn_cast_or_null<VectorType>(getSource ().getType ());
5628- auto resultVectorType =
5629- llvm::dyn_cast_or_null<VectorType>(getResult ().getType ());
5630-
5631- // Check if source/result are of vector type.
5632- if (sourceVectorType && resultVectorType)
5633- return verifyVectorShapeCast (*this , sourceVectorType, resultVectorType);
5571+ return emitOpError () << " has different number of scalable dims at source ("
5572+ << sourceNScalableDims << " ) and result ("
5573+ << resultNScalableDims << " )" ;
56345574
56355575 return success ();
56365576}
56375577
56385578OpFoldResult ShapeCastOp::fold (FoldAdaptor adaptor) {
56395579
5580+ VectorType resultType = getType ();
5581+
56405582 // No-op shape cast.
5641- if (getSource ().getType () == getType () )
5583+ if (getSource ().getType () == resultType )
56425584 return getSource ();
56435585
5644- VectorType resultType = getType ();
5645-
5646- // Canceling shape casts .
5586+ // Y = shape_cast(shape_cast(X)))
5587+ // -> X, if X and Y have same type
5588+ // -> shape_cast(X) otherwise .
56475589 if (auto otherOp = getSource ().getDefiningOp <ShapeCastOp>()) {
5648-
5649- // Only allows valid transitive folding (expand/collapse dimensions).
56505590 VectorType srcType = otherOp.getSource ().getType ();
56515591 if (resultType == srcType)
56525592 return otherOp.getSource ();
5653- if (srcType.getRank () < resultType.getRank ()) {
5654- if (!isValidShapeCast (srcType.getShape (), resultType.getShape ()))
5655- return {};
5656- } else if (srcType.getRank () > resultType.getRank ()) {
5657- if (!isValidShapeCast (resultType.getShape (), srcType.getShape ()))
5658- return {};
5659- } else {
5660- return {};
5661- }
56625593 setOperand (otherOp.getSource ());
56635594 return getResult ();
56645595 }
56655596
5666- // Cancelling broadcast and shape cast ops.
5597+ // Y = shape_cast(broadcast(X))
5598+ // -> X, if X and Y have same type
56675599 if (auto bcastOp = getSource ().getDefiningOp <BroadcastOp>()) {
56685600 if (bcastOp.getSourceType () == resultType)
56695601 return bcastOp.getSource ();
0 commit comments