@@ -1707,8 +1707,8 @@ static bool hasZeroDimVectors(Operation *op) {
17071707 llvm::any_of (op->getResultTypes (), hasZeroDimVectorType);
17081708}
17091709
1710- // / All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends
1711- // / 1s, are considered 'broadcastlike'.
1710+ // / All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1711+ // / 1s, are considered to be 'broadcastlike'.
17121712static bool isBroadcastLike (Operation *op) {
17131713 if (isa<BroadcastOp, SplatOp>(op))
17141714 return true ;
@@ -1717,61 +1717,97 @@ static bool isBroadcastLike(Operation *op) {
17171717 if (!shapeCast)
17181718 return false ;
17191719
1720- // Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
1720+ // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
17211721 // Condition 1: dst has hight rank.
17221722 // Condition 2: src shape is a suffix of dst shape.
1723+ //
1724+ // Note that checking that dst shape has a prefix of 1s is not sufficient,
1725+ // for example (2,3) -> (1,3,2) is not broadcast-like.
17231726 VectorType srcType = shapeCast.getSourceVectorType ();
17241727 ArrayRef<int64_t > srcShape = srcType.getShape ();
17251728 uint64_t srcRank = srcType.getRank ();
17261729 ArrayRef<int64_t > dstShape = shapeCast.getType ().getShape ();
17271730 return dstShape.size () >= srcRank && dstShape.take_back (srcRank) == srcShape;
17281731}
17291732
1730- // / Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1733+ // / Fold extract(broadcast(X)) to either extract(X) or just X.
1734+ // /
1735+ // / Example:
1736+ // /
1737+ // / broadcast extract
1738+ // / (3, 4) --------> (2, 3, 4) ------> (4)
1739+ // /
1740+ // / becomes
1741+ // / extract
1742+ // / (3,4) ---------------------------> (4)
1743+ // /
1744+ // /
1745+ // / The variable names used in this implementation use names which correspond to
1746+ // / the above shapes as,
1747+ // /
1748+ // / - (3, 4) is `input` shape.
1749+ // / - (2, 3, 4) is `broadcast` shape.
1750+ // / - (4) is `extract` shape.
1751+ // /
1752+ // / This folding is possible when the suffix of `input` shape is the same as
1753+ // / `extract` shape.
17311754static Value foldExtractFromBroadcast (ExtractOp extractOp) {
17321755
1733- Operation *broadcastLikeOp = extractOp.getVector ().getDefiningOp ();
1734- if (!broadcastLikeOp || !isBroadcastLike (broadcastLikeOp ))
1756+ Operation *defOp = extractOp.getVector ().getDefiningOp ();
1757+ if (!defOp || !isBroadcastLike (defOp ))
17351758 return Value ();
17361759
1737- Value src = broadcastLikeOp ->getOperand (0 );
1760+ Value input = defOp ->getOperand (0 );
17381761
17391762 // Replace extract(broadcast(X)) with X
1740- if (extractOp.getType () == src .getType ())
1741- return src ;
1763+ if (extractOp.getType () == input .getType ())
1764+ return input ;
17421765
17431766 // Get required types and ranks in the chain
1744- // src -> broadcastDst -> dst
1745- auto srcType = llvm::dyn_cast<VectorType>(src .getType ());
1746- auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1747- unsigned srcRank = srcType ? srcType .getRank () : 0 ;
1748- unsigned broadcastDstRank = extractOp.getSourceVectorType ().getRank ();
1749- unsigned dstRank = dstType ? dstType .getRank () : 0 ;
1767+ // input -> broadcast -> extract
1768+ auto inputType = llvm::dyn_cast<VectorType>(input .getType ());
1769+ auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1770+ unsigned inputRank = inputType ? inputType .getRank () : 0 ;
1771+ unsigned broadcastRank = extractOp.getSourceVectorType ().getRank ();
1772+ unsigned extractRank = extractType ? extractType .getRank () : 0 ;
17501773
17511774 // Cannot do without the broadcast if overall the rank increases.
1752- if (dstRank > srcRank )
1775+ if (extractRank > inputRank )
17531776 return Value ();
17541777
1755- assert (srcType && " src must be a vector type because of previous checks" );
1756-
1757- ArrayRef<int64_t > srcShape = srcType.getShape ();
1758- if (dstType && dstType.getShape () != srcShape.take_back (dstRank))
1778+ // Proof by contradiction that, at this point, input is a vector.
1779+ // Suppose input is a scalar.
1780+ // ==> inputRank is 0.
1781+ // ==> extractRank is 0 (because extractRank <= inputRank).
1782+ // ==> extract is scalar (because rank-0 extraction is always scalar).
1783+ // ==> input and extract are scalar, so same type.
1784+ // ==> returned early (check same type).
1785+ // Contradiction!
1786+ assert (inputType && " input must be a vector type because of previous checks" );
1787+ ArrayRef<int64_t > inputShape = inputType.getShape ();
1788+
1789+ // In the case where there is a broadcast dimension in the suffix, it is not
1790+ // possible to replace extract(broadcast(X)) with extract(X). Example:
1791+ //
1792+ // broadcast extract
1793+ // (1) --------> (3,4) ------> (4)
1794+ if (extractType &&
1795+ extractType.getShape () != inputShape.take_back (extractRank))
17591796 return Value ();
17601797
17611798 // Replace extract(broadcast(X)) with extract(X).
17621799 // First, determine the new extraction position.
1763- unsigned deltaOverall = srcRank - dstRank;
1764- unsigned deltaBroadcast = broadcastDstRank - srcRank;
1765-
1800+ unsigned deltaOverall = inputRank - extractRank;
1801+ unsigned deltaBroadcast = broadcastRank - inputRank;
17661802 SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition ();
17671803 SmallVector<OpFoldResult> newPositions (deltaOverall);
17681804 IntegerAttr zero = OpBuilder (extractOp.getContext ()).getIndexAttr (0 );
1769- for (auto [i, size] : llvm::enumerate (srcShape .take_front (deltaOverall))) {
1805+ for (auto [i, size] : llvm::enumerate (inputShape .take_front (deltaOverall))) {
17701806 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
17711807 }
17721808 auto [staticPos, dynPos] = decomposeMixedValues (newPositions);
17731809 extractOp->setOperands (
1774- llvm::to_vector (llvm::concat<Value>(ValueRange (src ), dynPos)));
1810+ llvm::to_vector (llvm::concat<Value>(ValueRange (input ), dynPos)));
17751811 extractOp.setStaticPosition (staticPos);
17761812 return extractOp.getResult ();
17771813}
@@ -2217,12 +2253,12 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
22172253 LogicalResult matchAndRewrite (ExtractOp extractOp,
22182254 PatternRewriter &rewriter) const override {
22192255
2220- Operation *broadcastLikeOp = extractOp.getVector ().getDefiningOp ();
2256+ Operation *defOp = extractOp.getVector ().getDefiningOp ();
22212257 VectorType outType = dyn_cast<VectorType>(extractOp.getType ());
2222- if (!broadcastLikeOp || !isBroadcastLike (broadcastLikeOp ) || !outType)
2258+ if (!defOp || !isBroadcastLike (defOp ) || !outType)
22232259 return failure ();
22242260
2225- Value source = broadcastLikeOp ->getOperand (0 );
2261+ Value source = defOp ->getOperand (0 );
22262262 if (isBroadcastableTo (source.getType (), outType) !=
22272263 BroadcastableToResult::Success)
22282264 return failure ();
0 commit comments