@@ -1688,16 +1688,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
16881688 broadcastVecType.getShape ().take_back (extractResultRank))
16891689 return Value ();
16901690
1691- // The dim-1 broadcast -> ExtractOp folder requires in-place operation
1692- // modifications. For dynamic position, this means we have to change the
1693- // number of operands. This cannot be done in place since it changes the
1694- // operation storage. For dynamic dimensions, the dim-1 broadcasting should
1695- // be implemented as a canonicalization pattern.
1696- // TODO: Implement canonicalization pattern for dim-1 broadcasting +
1697- // extractop.
1698- if (extractOp.hasDynamicPosition ())
1699- return Value ();
1700-
17011691 auto broadcastOp = cast<vector::BroadcastOp>(defOp);
17021692 int64_t broadcastDstRank = broadcastOp.getResultVectorType ().getRank ();
17031693
@@ -1706,20 +1696,22 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
17061696 // extract position to `0` when extracting from the source operand.
17071697 llvm::SetVector<int64_t > broadcastedUnitDims =
17081698 broadcastOp.computeBroadcastedUnitDims ();
1709- SmallVector<int64_t > extractPos (extractOp.getStaticPosition ());
1699+ SmallVector<OpFoldResult> extractPos (extractOp.getMixedPosition ());
1700+ OpBuilder b (extractOp.getContext ());
17101701 int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
17111702 for (int64_t i = broadcastRankDiff, e = extractPos.size (); i < e; ++i)
17121703 if (broadcastedUnitDims.contains (i))
1713- extractPos[i] = 0 ;
1704+ extractPos[i] = b. getIndexAttr ( 0 ) ;
17141705 // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
17151706 // matching extract position when extracting from the source operand.
17161707 int64_t rankDiff = broadcastSrcRank - extractResultRank;
17171708 extractPos.erase (extractPos.begin (),
17181709 std::next (extractPos.begin (), extractPos.size () - rankDiff));
17191710 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1720- OpBuilder b (extractOp.getContext ());
1721- extractOp.setOperand (0 , source);
1722- extractOp.setStaticPosition (extractPos);
1711+ auto [staticPos, dynPos] = decomposeMixedValues (extractPos);
1712+ extractOp->setOperands (
1713+ llvm::to_vector (llvm::concat<Value>(ValueRange (source), dynPos)));
1714+ extractOp.setStaticPosition (staticPos);
17231715 return extractOp.getResult ();
17241716}
17251717
0 commit comments