@@ -1707,59 +1707,99 @@ static bool hasZeroDimVectors(Operation *op) {
17071707 llvm::any_of (op->getResultTypes (), hasZeroDimVectorType);
17081708}
17091709
1710- // / Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1710+ // / All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1711+ // / 1s, are considered to be 'broadcastlike'.
1712+ static bool isBroadcastLike (Operation *op) {
1713+ if (isa<BroadcastOp, SplatOp>(op))
1714+ return true ;
1715+
1716+ auto shapeCast = dyn_cast<ShapeCastOp>(op);
1717+ if (!shapeCast)
1718+ return false ;
1719+
1720+ // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1721+ // Checking that the destination shape has a prefix of 1s is not sufficient,
1722+ // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1723+ // is that the source shape is a suffix of the destination shape.
1724+ VectorType srcType = shapeCast.getSourceVectorType ();
1725+ ArrayRef<int64_t > srcShape = srcType.getShape ();
1726+ uint64_t srcRank = srcType.getRank ();
1727+ ArrayRef<int64_t > dstShape = shapeCast.getType ().getShape ();
1728+ return dstShape.size () >= srcRank && dstShape.take_back (srcRank) == srcShape;
1729+ }
1730+
1731+ // / Fold extract(broadcast(X)) to either extract(X) or just X.
1732+ // /
1733+ // / Example:
1734+ // /
1735+ // / broadcast extract [1][2]
1736+ // / (3, 4) --------> (2, 3, 4) ----------------> (4)
1737+ // /
1738+ // / becomes
1739+ // / extract [1]
1740+ // / (3,4) -------------------------------------> (4)
1741+ // /
1742+ // /
1743+ // / The variable names used in this implementation correspond to the above
1744+ // / shapes as,
1745+ // /
1746+ // / - (3, 4) is `input` shape.
1747+ // / - (2, 3, 4) is `broadcast` shape.
1748+ // / - (4) is `extract` shape.
1749+ // /
1750+ // / This folding is possible when the suffix of `input` shape is the same as
1751+ // / `extract` shape.
17111752static Value foldExtractFromBroadcast (ExtractOp extractOp) {
1753+
17121754 Operation *defOp = extractOp.getVector ().getDefiningOp ();
1713- if (!defOp || !isa<vector::BroadcastOp, SplatOp> (defOp))
1755+ if (!defOp || !isBroadcastLike (defOp))
17141756 return Value ();
17151757
1716- Value source = defOp->getOperand (0 );
1717- if (extractOp.getType () == source.getType ())
1718- return source;
1719- auto getRank = [](Type type) {
1720- return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank ()
1721- : 0 ;
1722- };
1758+ Value input = defOp->getOperand (0 );
17231759
1724- // If splat or broadcast from a scalar, just return the source scalar.
1725- unsigned broadcastSrcRank = getRank (source.getType ());
1726- if (broadcastSrcRank == 0 && source.getType () == extractOp.getType ())
1727- return source;
1760+ // Replace extract(broadcast(X)) with X
1761+ if (extractOp.getType () == input.getType ())
1762+ return input;
17281763
1729- unsigned extractResultRank = getRank (extractOp.getType ());
1730- if (extractResultRank > broadcastSrcRank)
1731- return Value ();
1732- // Check that the dimension of the result haven't been broadcasted.
1733- auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1734- auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType ());
1735- if (extractVecType && broadcastVecType &&
1736- extractVecType.getShape () !=
1737- broadcastVecType.getShape ().take_back (extractResultRank))
1764+ // Get required types and ranks in the chain
1765+ // input -> broadcast -> extract
1766+ // (scalars are treated as rank-0).
1767+ auto inputType = llvm::dyn_cast<VectorType>(input.getType ());
1768+ auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1769+ unsigned inputRank = inputType ? inputType.getRank () : 0 ;
1770+ unsigned broadcastRank = extractOp.getSourceVectorType ().getRank ();
1771+ unsigned extractRank = extractType ? extractType.getRank () : 0 ;
1772+
1773+ // Cannot do without the broadcast if overall the rank increases.
1774+ if (extractRank > inputRank)
17381775 return Value ();
17391776
1740- auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1741- int64_t broadcastDstRank = broadcastOp.getResultVectorType ().getRank ();
1777+ // The above condition guarantees that input is a vector.
1778+ assert (inputType && " input must be a vector type because of previous checks" );
1779+ ArrayRef<int64_t > inputShape = inputType.getShape ();
17421780
1743- // Detect all the positions that come from "dim-1" broadcasting.
1744- // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
1745- // extract position to `0` when extracting from the source operand.
1746- llvm::SetVector<int64_t > broadcastedUnitDims =
1747- broadcastOp.computeBroadcastedUnitDims ();
1748- SmallVector<OpFoldResult> extractPos (extractOp.getMixedPosition ());
1749- OpBuilder b (extractOp.getContext ());
1750- int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1751- for (int64_t i = broadcastRankDiff, e = extractPos.size (); i < e; ++i)
1752- if (broadcastedUnitDims.contains (i))
1753- extractPos[i] = b.getIndexAttr (0 );
1754- // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1755- // matching extract position when extracting from the source operand.
1756- int64_t rankDiff = broadcastSrcRank - extractResultRank;
1757- extractPos.erase (extractPos.begin (),
1758- std::next (extractPos.begin (), extractPos.size () - rankDiff));
1759- // OpBuilder is only used as a helper to build an I64ArrayAttr.
1760- auto [staticPos, dynPos] = decomposeMixedValues (extractPos);
1781+ // In the case where there is a broadcast dimension in the suffix, it is not
1782+ // possible to replace extract(broadcast(X)) with extract(X). Example:
1783+ //
1784+ // broadcast extract
1785+ // (1) --------> (3,4) ------> (4)
1786+ if (extractType &&
1787+ extractType.getShape () != inputShape.take_back (extractRank))
1788+ return Value ();
1789+
1790+ // Replace extract(broadcast(X)) with extract(X).
1791+ // First, determine the new extraction position.
1792+ unsigned deltaOverall = inputRank - extractRank;
1793+ unsigned deltaBroadcast = broadcastRank - inputRank;
1794+ SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition ();
1795+ SmallVector<OpFoldResult> newPositions (deltaOverall);
1796+ IntegerAttr zero = OpBuilder (extractOp.getContext ()).getIndexAttr (0 );
1797+ for (auto [i, size] : llvm::enumerate (inputShape.take_front (deltaOverall))) {
1798+ newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1799+ }
1800+ auto [staticPos, dynPos] = decomposeMixedValues (newPositions);
17611801 extractOp->setOperands (
1762- llvm::to_vector (llvm::concat<Value>(ValueRange (source ), dynPos)));
1802+ llvm::to_vector (llvm::concat<Value>(ValueRange (input ), dynPos)));
17631803 extractOp.setStaticPosition (staticPos);
17641804 return extractOp.getResult ();
17651805}
@@ -2204,32 +2244,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
22042244
22052245 LogicalResult matchAndRewrite (ExtractOp extractOp,
22062246 PatternRewriter &rewriter) const override {
2247+
22072248 Operation *defOp = extractOp.getVector ().getDefiningOp ();
2208- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2249+ VectorType outType = dyn_cast<VectorType>(extractOp.getType ());
2250+ if (!defOp || !isBroadcastLike (defOp) || !outType)
22092251 return failure ();
22102252
22112253 Value source = defOp->getOperand (0 );
2212- if (extractOp.getType () == source.getType ())
2213- return failure ();
2214- auto getRank = [](Type type) {
2215- return llvm::isa<VectorType>(type)
2216- ? llvm::cast<VectorType>(type).getRank ()
2217- : 0 ;
2218- };
2219- unsigned broadcastSrcRank = getRank (source.getType ());
2220- unsigned extractResultRank = getRank (extractOp.getType ());
2221- // We only consider the case where the rank of the source is less than or
2222- // equal to the rank of the extract dst. The other cases are handled in the
2223- // folding patterns.
2224- if (extractResultRank < broadcastSrcRank)
2225- return failure ();
2226- // For scalar result, the input can only be a rank-0 vector, which will
2227- // be handled by the folder.
2228- if (extractResultRank == 0 )
2254+ if (isBroadcastableTo (source.getType (), outType) !=
2255+ BroadcastableToResult::Success)
22292256 return failure ();
22302257
2231- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
2232- extractOp, extractOp.getType (), source);
2258+ rewriter.replaceOpWithNewOp <BroadcastOp>(extractOp, outType, source);
22332259 return success ();
22342260 }
22352261};
0 commit comments