@@ -1707,59 +1707,71 @@ static bool hasZeroDimVectors(Operation *op) {
17071707 llvm::any_of (op->getResultTypes (), hasZeroDimVectorType);
17081708}
17091709
1710+ // / All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1711+ // / 1s, are considered 'broadcastlike'.
1712+ static bool isBroadcastLike (Operation *op) {
1713+ if (isa<BroadcastOp, SplatOp>(op))
1714+ return true ;
1715+
1716+ auto shapeCast = dyn_cast<ShapeCastOp>(op);
1717+ if (!shapeCast)
1718+ return false ;
1719+
1720+ // Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
1721+ // Condition 1: dst rank is at least src rank.
1722+ // Condition 2: src shape is a suffix of dst shape.
1723+ VectorType srcType = shapeCast.getSourceVectorType ();
1724+ ArrayRef<int64_t > srcShape = srcType.getShape ();
1725+ uint64_t srcRank = srcType.getRank ();
1726+ ArrayRef<int64_t > dstShape = shapeCast.getType ().getShape ();
1727+ return dstShape.size () >= srcRank && dstShape.take_back (srcRank) == srcShape;
1728+ }
1729+
17101730// / Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
17111731static Value foldExtractFromBroadcast (ExtractOp extractOp) {
1712- Operation *defOp = extractOp.getVector ().getDefiningOp ();
1713- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1732+
1733+ Operation *broadcastLikeOp = extractOp.getVector ().getDefiningOp ();
1734+ if (!broadcastLikeOp || !isBroadcastLike (broadcastLikeOp))
17141735 return Value ();
17151736
1716- Value source = defOp->getOperand (0 );
1717- if (extractOp.getType () == source.getType ())
1718- return source;
1719- auto getRank = [](Type type) {
1720- return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank ()
1721- : 0 ;
1722- };
1737+ Value src = broadcastLikeOp->getOperand (0 );
1738+
1739+ // Replace extract(broadcast(X)) with X when the result type equals X's type.
1740+ if (extractOp.getType () == src.getType ())
1741+ return src;
17231742
1724- // If splat or broadcast from a scalar, just return the source scalar.
1725- unsigned broadcastSrcRank = getRank (source.getType ());
1726- if (broadcastSrcRank == 0 && source.getType () == extractOp.getType ())
1727- return source;
1743+ // Get required types and ranks in the chain
1744+ // src -> broadcastDst -> dst (a rank of 0 is used for non-vector types).
1745+ auto srcType = llvm::dyn_cast<VectorType>(src.getType ());
1746+ auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1747+ unsigned srcRank = srcType ? srcType.getRank () : 0 ;
1748+ unsigned broadcastDstRank = extractOp.getSourceVectorType ().getRank ();
1749+ unsigned dstRank = dstType ? dstType.getRank () : 0 ;
17281750
1729- unsigned extractResultRank = getRank (extractOp. getType ());
1730- if (extractResultRank > broadcastSrcRank )
1751+ // The broadcast cannot be bypassed if the overall rank increases.
1752+ if (dstRank > srcRank )
17311753 return Value ();
1732- // Check that the dimension of the result haven't been broadcasted.
1733- auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1734- auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType ());
1735- if (extractVecType && broadcastVecType &&
1736- extractVecType.getShape () !=
1737- broadcastVecType.getShape ().take_back (extractResultRank))
1754+
1755+ assert (srcType && " src must be a vector type because of previous checks" );
1756+
1757+ ArrayRef<int64_t > srcShape = srcType.getShape ();
1758+ if (dstType && dstType.getShape () != srcShape.take_back (dstRank))
17381759 return Value ();
17391760
1740- auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1741- int64_t broadcastDstRank = broadcastOp.getResultVectorType ().getRank ();
1761+ // Replace extract(broadcast(X)) with extract(X).
1762+ // First, determine the new extraction position: unit dims of src are read at index 0, other dims reuse the old index shifted by deltaBroadcast.
1763+ unsigned deltaOverall = srcRank - dstRank;
1764+ unsigned deltaBroadcast = broadcastDstRank - srcRank;
17421765
1743- // Detect all the positions that come from "dim-1" broadcasting.
1744- // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
1745- // extract position to `0` when extracting from the source operand.
1746- llvm::SetVector<int64_t > broadcastedUnitDims =
1747- broadcastOp.computeBroadcastedUnitDims ();
1748- SmallVector<OpFoldResult> extractPos (extractOp.getMixedPosition ());
1749- OpBuilder b (extractOp.getContext ());
1750- int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1751- for (int64_t i = broadcastRankDiff, e = extractPos.size (); i < e; ++i)
1752- if (broadcastedUnitDims.contains (i))
1753- extractPos[i] = b.getIndexAttr (0 );
1754- // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1755- // matching extract position when extracting from the source operand.
1756- int64_t rankDiff = broadcastSrcRank - extractResultRank;
1757- extractPos.erase (extractPos.begin (),
1758- std::next (extractPos.begin (), extractPos.size () - rankDiff));
1759- // OpBuilder is only used as a helper to build an I64ArrayAttr.
1760- auto [staticPos, dynPos] = decomposeMixedValues (extractPos);
1766+ SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition ();
1767+ SmallVector<OpFoldResult> newPositions (deltaOverall);
1768+ IntegerAttr zero = OpBuilder (extractOp.getContext ()).getIndexAttr (0 );
1769+ for (auto [i, size] : llvm::enumerate (srcShape.take_front (deltaOverall))) {
1770+ newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1771+ }
1772+ auto [staticPos, dynPos] = decomposeMixedValues (newPositions);
17611773 extractOp->setOperands (
1762- llvm::to_vector (llvm::concat<Value>(ValueRange (source ), dynPos)));
1774+ llvm::to_vector (llvm::concat<Value>(ValueRange (src ), dynPos)));
17631775 extractOp.setStaticPosition (staticPos);
17641776 return extractOp.getResult ();
17651777}
@@ -2204,32 +2216,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
22042216
22052217 LogicalResult matchAndRewrite (ExtractOp extractOp,
22062218 PatternRewriter &rewriter) const override {
2207- Operation *defOp = extractOp.getVector ().getDefiningOp ();
2208- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2209- return failure ();
22102219
2211- Value source = defOp->getOperand (0 );
2212- if (extractOp.getType () == source.getType ())
2220+ Operation *broadcastLikeOp = extractOp.getVector ().getDefiningOp ();
2221+ VectorType outType = dyn_cast<VectorType>(extractOp.getType ()); // null if the extract result is not a vector
2222+ if (!broadcastLikeOp || !isBroadcastLike (broadcastLikeOp) || !outType)
22132223 return failure ();
2214- auto getRank = [](Type type) {
2215- return llvm::isa<VectorType>(type)
2216- ? llvm::cast<VectorType>(type).getRank ()
2217- : 0 ;
2218- };
2219- unsigned broadcastSrcRank = getRank (source.getType ());
2220- unsigned extractResultRank = getRank (extractOp.getType ());
2221- // We only consider the case where the rank of the source is less than or
2222- // equal to the rank of the extract dst. The other cases are handled in the
2223- // folding patterns.
2224- if (extractResultRank < broadcastSrcRank)
2225- return failure ();
2226- // For scalar result, the input can only be a rank-0 vector, which will
2227- // be handled by the folder.
2228- if (extractResultRank == 0 )
2224+
2225+ Value source = broadcastLikeOp->getOperand (0 );
2226+ if (isBroadcastableTo (source.getType (), outType) !=
2227+ BroadcastableToResult::Success) // bail out unless isBroadcastableTo reports Success
22292228 return failure ();
22302229
2231- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
2232- extractOp, extractOp.getType (), source);
2230+ rewriter.replaceOpWithNewOp <BroadcastOp>(extractOp, outType, source); // extract(broadcastlike(X)) -> broadcast(X)
22332231 return success ();
22342232 }
22352233};
0 commit comments