@@ -5506,65 +5506,66 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
55065506}
55075507
55085508LogicalResult ShapeCastOp::verify () {
5509- auto sourceVectorType =
5510- llvm::dyn_cast_or_null<VectorType>(getSource ().getType ());
5511- auto resultVectorType =
5512- llvm::dyn_cast_or_null<VectorType>(getResult ().getType ());
55135509
5514- if (!sourceVectorType)
5515- return failure ();
5516- if (!resultVectorType)
5517- return failure ();
5510+ VectorType sourceType = getSourceVectorType ();
5511+ VectorType resultType = getResultVectorType ();
55185512
5519- // Check that element type is the same.
5520- if (sourceVectorType.getElementType () != resultVectorType.getElementType ())
5521- return emitOpError (" source/result vectors must have same element type" );
5522- auto sourceShape = sourceVectorType.getShape ();
5523- auto resultShape = resultVectorType.getShape ();
5513+ // Check that element type is preserved
5514+ if (sourceType.getElementType () != resultType.getElementType ())
5515+ return emitOpError (" has different source and result element types" );
55245516
5525- // Check that product of source dim sizes matches product of result dim sizes.
5526- int64_t sourceDimProduct = std::accumulate (
5527- sourceShape.begin (), sourceShape.end (), 1LL , std::multiplies<int64_t >{});
5528- int64_t resultDimProduct = std::accumulate (
5529- resultShape.begin (), resultShape.end (), 1LL , std::multiplies<int64_t >{});
5530- if (sourceDimProduct != resultDimProduct)
5531- return emitOpError (" source/result number of elements must match" );
5517+ // Check that number of elements is preserved
5518+ int64_t sourceNElms = sourceType.getNumElements ();
5519+ int64_t resultNElms = resultType.getNumElements ();
5520+ if (sourceNElms != resultNElms) {
5521+ return emitOpError () << " has different number of elements at source ("
5522+ << sourceNElms << " ) and result (" << resultNElms
5523+ << " )" ;
5524+ }
55325525
55335526 // Check that (non-)scalability is preserved
5534- int64_t sourceNScalableDims = sourceVectorType .getNumScalableDims ();
5535- int64_t resultNScalableDims = resultVectorType .getNumScalableDims ();
5527+ int64_t sourceNScalableDims = sourceType .getNumScalableDims ();
5528+ int64_t resultNScalableDims = resultType .getNumScalableDims ();
55365529 if (sourceNScalableDims != resultNScalableDims)
5537- return emitOpError (" different number of scalable dims at source (" )
5538- << sourceNScalableDims << " ) and result (" << resultNScalableDims
5539- << " )" ;
5540- sourceVectorType.getNumDynamicDims ();
5530+ return emitOpError () << " has different number of scalable dims at source ("
5531+ << sourceNScalableDims << " ) and result ("
5532+ << resultNScalableDims << " )" ;
55415533
55425534 return success ();
55435535}
55445536
55455537OpFoldResult ShapeCastOp::fold (FoldAdaptor adaptor) {
55465538
5539+ VectorType resultType = getType ();
5540+
55475541 // No-op shape cast.
5548- if (getSource ().getType () == getType () )
5542+ if (getSource ().getType () == resultType )
55495543 return getSource ();
55505544
5551- VectorType resultType = getType ();
5552-
5553- // Canceling shape casts .
5545+ // Y = shape_cast(shape_cast(X))
5546+ // -> X, if X and Y have same type
5547+ // -> shape_cast(X) otherwise .
55545548 if (auto otherOp = getSource ().getDefiningOp <ShapeCastOp>()) {
5555-
5556- // Only allows valid transitive folding (expand/collapse dimensions).
55575549 VectorType srcType = otherOp.getSource ().getType ();
55585550 if (resultType == srcType)
55595551 return otherOp.getSource ();
55605552 setOperand (otherOp.getSource ());
55615553 return getResult ();
55625554 }
55635555
5564- // Cancelling broadcast and shape cast ops.
5556+ // Y = shape_cast(broadcast(X))
5557+ // -> X, if X and Y have same type, else
5558+ // -> shape_cast(X) if X is a vector and the broadcast preserves
5559+ // number of elements.
55655560 if (auto bcastOp = getSource ().getDefiningOp <BroadcastOp>()) {
55665561 if (bcastOp.getSourceType () == resultType)
55675562 return bcastOp.getSource ();
5563+ if (auto bcastSrcType = dyn_cast<VectorType>(bcastOp.getSourceType ())) {
5564+ if (bcastSrcType.getNumElements () == resultType.getNumElements ()) {
5565+ setOperand (bcastOp.getSource ());
5566+ return getResult ();
5567+ }
5568+ }
55685569 }
55695570
55705571 // shape_cast(constant) -> constant
0 commit comments