@@ -1506,20 +1506,104 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
15061506 return applyPermutation (destShape, linalg::getPackInverseDestPerm (packOp));
15071507}
15081508
1509+ /// Determines whether the mask for a corresponding `vector.transfer_write` op
1510+ /// is trivially foldable (i.e., guaranteed to be all true).
1511+ ///
1512+ /// Requirements:
1513+ ///   * All involved shapes (destination, mask) are static.
1514+ ///   * All write indices are constant.
1515+ ///   * All mask sizes are constant.
1516+ ///
1517+ /// Once verified, the method checks that for each destination dimension `d`:
1518+ ///   (1) maskShape[d] <= destDimSize[rankDiff + d]
1519+ ///   (2) writeIndex[d] + maskSize[d] <= destDimSize[rankDiff + d]
1520+ ///
1521+ /// rankDiff = rank(dest) - rank(mask).
1522+ ///
1523+ /// This method takes a conservative view: it may return false even if the mask
1524+ /// is technically foldable.
1525+ ///
1526+ /// EXAMPLE 1 (trivially foldable):
1527+ ///   %c0 = arith.constant 0 : index
1528+ ///   vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1529+ ///     {in_bounds = [true, true]}
1530+ ///     : vector<5x1xi32>, tensor<5x1xi32>
1531+ ///
1532+ /// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape):
1533+ ///   %c0 = arith.constant 0 : index
1534+ ///   vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1535+ ///     {in_bounds = [true, true]}
1536+ ///     : vector<8x1xi32>, tensor<5x1xi32>
1537+ ///
1538+ /// TODO: Re-use in createReadOrMaskedRead
1539+ static bool isMaskTriviallyFoldable (SmallVector<OpFoldResult> &maskSizes,
1540+ SmallVector<Value> &writeIdxs,
1541+ ArrayRef<int64_t > destShape,
1542+ ArrayRef<int64_t > maskShape) {
1543+ // Masking is unavoidable in the case of dynamic tensors.
1544+ if (ShapedType::isDynamicShape (destShape))
1545+ return false ;
1546+
1547+ // Collect all constant mask sizes.
1548+ SmallVector<int64_t , 4 > cstMaskSizes;
1549+ for (auto [i, dimSize] : llvm::enumerate (maskSizes)) {
1550+ if (auto intSize = getConstantIntValue (dimSize)) {
1551+ cstMaskSizes.push_back (*intSize);
1552+ }
1553+ }
1554+
1555+ // If any of the mask sizes is non-constant, bail out.
1556+ if (cstMaskSizes.size () != maskShape.size ())
1557+ return false ;
1558+
1559+ // Collect all constant write indices.
1560+ SmallVector<int64_t , 4 > cstWriteIdxs;
1561+ for (auto [i, idx] : llvm::enumerate (writeIdxs)) {
1562+ APSInt intVal;
1563+ if (matchPattern (idx, m_ConstantInt (&intVal))) {
1564+ cstWriteIdxs.push_back (intVal.getSExtValue ());
1565+ }
1566+ }
1567+
1568+ // If any of the write indices is non-constant, bail out.
1569+ if (cstWriteIdxs.size () != destShape.size ())
1570+ return false ;
1571+
1572+ // Go over all destination dims and check (1) and (2). Take into account that:
1573+ // * The number of mask sizes will match the rank of the vector to store.
1574+ // This could be lower than the rank of the destination tensor.
1575+ // * Mask sizes could be larger than the corresponding mask shape (hence
1576+ // `clamp`).
1577+ // TODO: The 2nd item should be rejected by the verifier.
1578+ int64_t rankDiff = destShape.size () - cstMaskSizes.size ();
1579+ for (auto [i, idx] : llvm::enumerate (cstMaskSizes)) {
1580+ if (/* (1)*/ maskShape[i] > destShape[rankDiff + i] ||
1581+ /* (2)*/ destShape[rankDiff + i] <
1582+ (std::clamp (cstMaskSizes[i], int64_t (0 ), maskShape[i]) +
1583+ cstWriteIdxs[i]))
1584+ return false ;
1585+ }
1586+
1587+ return true ;
1588+ }
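To make the per-dimension check above concrete, here is a small standalone sketch (plain C++, independent of the MLIR helpers; `dimsFoldable` is an illustrative name, not part of the patch) that mirrors conditions (1) and (2) and evaluates them for the two examples from the comment:

#include <algorithm>
#include <cstdint>
#include <vector>

// Mirrors the per-dimension bounds check: the mask folds away only if, for
// every mask dimension, the mask vector fits in the destination dim (1) and
// the write offset plus the (clamped) mask size stays in bounds (2).
bool dimsFoldable(const std::vector<int64_t> &maskShape,
                  const std::vector<int64_t> &maskSizes,
                  const std::vector<int64_t> &writeIdxs,
                  const std::vector<int64_t> &destShape) {
  int64_t rankDiff = destShape.size() - maskSizes.size();
  for (size_t i = 0; i < maskSizes.size(); ++i) {
    if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
        /*(2)*/ destShape[rankDiff + i] <
                    std::clamp(maskSizes[i], int64_t(0), maskShape[i]) +
                        writeIdxs[i])
      return false;
  }
  return true;
}

// EXAMPLE 1: vector<5x1> written at [0, 0] into tensor<5x1>; the mask sizes
// are the destination sizes {5, 1} -> foldable.
//   dimsFoldable({5, 1}, {5, 1}, {0, 0}, {5, 1}) == true
// EXAMPLE 2: vector<8x1> written at [0, 0] into tensor<5x1> -> not foldable,
// because 8 > 5 violates condition (1).
//   dimsFoldable({8, 1}, {5, 1}, {0, 0}, {5, 1}) == false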
1589+
15091590 /// Creates an optionally masked TransferWriteOp
15101591 ///
15111592 /// Generates the following operation:
15121593 ///   %res = vector.transfer_write %vectorToStore into %dest
15131594 ///
1514- /// If the leading N dimensions of the destination tensor do not match
1595+ /// If the leading N dimensions of the vector to store do not match
15151596 /// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
15161597 /// masking is applied to ensure correctness:
15171598 ///
1518- ///   %mask = vector.create_mask(%destShape)
1599+ ///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
15191600 ///   %res = vector.mask %mask {
15201601 ///     vector.transfer_write %vectorToStore into %dest
15211602 ///   }
15221603 ///
1604+ /// The mask shape is identical to `vectorToStore` (with the element type ==
1605+ /// i1), and the mask values are based on the shape of the `dest` tensor.
1606+ ///
15231607 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
15241608 /// is used instead of masking:
15251609 ///
@@ -1528,75 +1612,99 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
15281612 ///   %res = vector.transfer_write %input into %dest
15291613 ///     {in_bounds = in_bounds_flags}
15301614 ///
1531- /// NOTE: All write offsets are set to 0.
1532- /// TODO: Allow specyfying write offsets.
1533- /// NOTE: When N < rank(input), the missing vector sizes are effectively
1534- /// extracted from the trailing sizes of `destSizes`. This means those sizes
1535- /// must be static.
1536- /// TODO: Support cases where an arbitrary dim is dynamic - this will require
1537- /// specifying all the vector sizes.
1615+ /// `writeIndices` specifies the offsets to use. If empty, all indices are set
1616+ /// to 0.
1617+ ///
1618+ /// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
1619+ /// `vectorToStore`.
1620+ /// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
1621+ /// already provided in `vectorToStore`.
15381622static Operation *
15391623createWriteOrMaskedWrite (OpBuilder &builder, Location loc, Value vectorToStore,
15401624 Value dest,
15411625 ArrayRef<int64_t > inputVecSizesForLeadingDims,
1626+ SmallVector<Value> writeIndices = {},
15421627 bool useInBoundsInsteadOfMasking = false ) {
15431628
15441629 ShapedType destType = cast<ShapedType>(dest.getType ());
1545- assert (cast<VectorType>(vectorToStore.getType ()).getRank () ==
1546- static_cast <int64_t >(destType.getRank ()) &&
1547- " Rank mismatch!" );
1548- (void )destType;
1630+ int64_t destRank = destType.getRank ();
1631+ auto destShape = destType.getShape ();
15491632
1550- int64_t rank = cast<ShapedType>(dest.getType ()).getRank ();
1551- auto destShape = cast<ShapedType>(dest.getType ()).getShape ();
1633+ VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType ());
1634+ int64_t vecToStoreRank = vecToStoreType.getRank ();
1635+ auto vecToStoreShape = vecToStoreType.getShape ();
15521636
15531637 // Compute the in_bounds attribute
1554- SmallVector<bool > inBoundsVal (rank , true );
1638+ SmallVector<bool > inBoundsVal (vecToStoreRank , true );
15551639 if (useInBoundsInsteadOfMasking) {
15561640 // In this case, assume that all the required vector sizes have been
15571641 // provided.
15581642 assert (inputVecSizesForLeadingDims.size () ==
1559- static_cast <size_t >(destType .getRank ()) &&
1643+ static_cast <size_t >(vecToStoreType .getRank ()) &&
15601644 " Insufficient number of input vector sizes!" );
15611645 // Update the inBounds attribute.
1562- for (unsigned i = 0 ; i < rank ; i++)
1646+ for (unsigned i = 0 ; i < destRank ; i++)
15631647 inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
15641648 !ShapedType::isDynamic (destShape[i]);
15651649 }
15661650
1651+ // If missing, initialize the write indices to 0.
1652+ assert((writeIndices.empty() ||
1653+         writeIndices.size() == static_cast<size_t>(destRank)) &&
1654+        "Invalid number of write indices!");
1655+ if (writeIndices.empty ()) {
1656+ auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1657+ writeIndices = SmallVector<Value>(destRank, zero);
1658+ }
1659+
15671660 // Generate the xfer_write Op
1568- auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1569- Operation *write = builder.create <vector::TransferWriteOp>(
1570- loc,
1571- /* vector=*/ vectorToStore,
1572- /* source=*/ dest,
1573- /* indices=*/ SmallVector<Value>(rank, zero),
1574- /* inBounds=*/ inBoundsVal);
1575- assert (llvm::none_of (
1576- destShape.drop_front (inputVecSizesForLeadingDims.size ()),
1577- [](int64_t size) { return size == ShapedType::kDynamic ; }) &&
1578- " Only dims aligned with inputVecSizesForLeadingDims may be dynamic" );
1661+ Operation *write =
1662+ builder.create <vector::TransferWriteOp>(loc,
1663+ /* vector=*/ vectorToStore,
1664+ /* source=*/ dest,
1665+ /* indices=*/ writeIndices,
1666+ /* inBounds=*/ inBoundsVal);
15791667
15801668 // If masking is disabled, exit.
15811669 if (useInBoundsInsteadOfMasking)
15821670 return write;
15831671
1672+ assert (llvm::none_of (
1673+ destShape.drop_front (inputVecSizesForLeadingDims.size ()),
1674+ [](int64_t size) { return size == ShapedType::kDynamic ; }) &&
1675+ " Only dims aligned with inputVecSizesForLeadingDims may be dynamic" );
1676+
15841677 // Check if masking is needed.
15851678 bool needMaskForWrite =
15861679 !llvm::equal (inputVecSizesForLeadingDims,
1587- destShape.take_front (inputVecSizesForLeadingDims.size ()));
1680+ destShape.take_front (destRank - vecToStoreRank +
1681+ inputVecSizesForLeadingDims.size ()));
15881682
15891683 // If masking is needed, generate the mask and mask the operation.
15901684 if (needMaskForWrite) {
1685+ // Get the mask shape + type. Missing mask dimensions are taken from
1686+ // `vectorToStore`.
15911687 SmallVector<int64_t > writeMaskShape;
15921688 writeMaskShape.append (inputVecSizesForLeadingDims.begin (),
15931689 inputVecSizesForLeadingDims.end ());
1594- writeMaskShape.append (destShape.begin () +
1595- inputVecSizesForLeadingDims.size (),
1596- destShape.end ());
1690+ if (vecToStoreRank >
1691+ static_cast <int64_t >(inputVecSizesForLeadingDims.size ()))
1692+ writeMaskShape.append (vecToStoreShape.begin () +
1693+ inputVecSizesForLeadingDims.size (),
1694+ vecToStoreShape.end ());
15971695 auto writeMaskType = VectorType::get (writeMaskShape, builder.getI1Type ());
1598- Value maskForWrite = builder.create <vector::CreateMaskOp>(
1599- loc, writeMaskType, tensor::getMixedSizes (builder, loc, dest));
1696+
1697+ SmallVector<OpFoldResult> destSizes =
1698+ tensor::getMixedSizes (builder, loc, dest);
1699+ SmallVector<OpFoldResult> maskSizes (destSizes.end () - writeMaskShape.size (),
1700+ destSizes.end ());
1701+
1702+ if (isMaskTriviallyFoldable (maskSizes, writeIndices, destShape,
1703+ writeMaskShape))
1704+ return write;
1705+
1706+ Value maskForWrite = builder.createOrFold <vector::CreateMaskOp>(
1707+ loc, writeMaskType, maskSizes);
16001708 write = mlir::vector::maskOperation (builder, write, maskForWrite);
16011709 }
16021710
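As a rough illustration of how the write-mask shape above is assembled, here is a minimal standalone sketch (plain C++; `buildWriteMaskShape` is an illustrative name and not part of the patch): the leading dims come from the caller-provided vector sizes, while any trailing dims are taken from the vector being stored.

#include <cstdint>
#include <vector>

// Leading mask dims come from the caller-provided vector sizes; any remaining
// trailing dims are copied from the shape of the vector being stored.
std::vector<int64_t>
buildWriteMaskShape(const std::vector<int64_t> &leadingVecSizes,
                    const std::vector<int64_t> &vecToStoreShape) {
  std::vector<int64_t> maskShape(leadingVecSizes.begin(),
                                 leadingVecSizes.end());
  if (vecToStoreShape.size() > leadingVecSizes.size())
    maskShape.insert(maskShape.end(),
                     vecToStoreShape.begin() + leadingVecSizes.size(),
                     vecToStoreShape.end());
  return maskShape;
}

// e.g. leadingVecSizes = {4, 8} and vecToStoreShape = {4, 8, 16}
//   -> mask shape = {4, 8, 16}, i.e. identical to the vector to store.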
@@ -1700,10 +1808,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
17001808 Value dest = rewriter.create <tensor::EmptyOp>(
17011809 loc, reifiedReturnShapes[0 ],
17021810 transposeOp.getResult ().getType ().getElementType ());
1703- Operation *write =
1704- createWriteOrMaskedWrite ( rewriter, loc, transposeOp.getResult (), dest,
1705- /* inputVecSizesForLeadingDims=*/ inputVectorSizes,
1706- /* useInBoundsInsteadOfMasking=*/ false );
1811+ Operation *write = createWriteOrMaskedWrite (
1812+ rewriter, loc, transposeOp.getResult (), dest,
1813+ /* inputVecSizesForLeadingDims=*/ inputVectorSizes, /* writeIndices=*/ {},
1814+ /* useInBoundsInsteadOfMasking=*/ false );
17071815 newResults.push_back (write->getResult (0 ));
17081816 return success ();
17091817}
@@ -1839,10 +1947,10 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18391947 Value dest = rewriter.create <tensor::EmptyOp>(
18401948 loc, reifiedRetShapes[0 ],
18411949 shapeCastOp.getResult ().getType ().getElementType ());
1842- Operation *write =
1843- createWriteOrMaskedWrite ( rewriter, loc, shapeCastOp.getResult (), dest,
1844- /* inputVecSizesForLeadingDims=*/ writeVectorSizes,
1845- useInBoundsInsteadOfMasking);
1950+ Operation *write = createWriteOrMaskedWrite (
1951+ rewriter, loc, shapeCastOp.getResult (), dest,
1952+ /* inputVecSizesForLeadingDims=*/ writeVectorSizes,
1953+ /* writeIndices=*/ {}, useInBoundsInsteadOfMasking);
18461954 newResults.push_back (write->getResult (0 ));
18471955 return success ();
18481956}
@@ -1874,10 +1982,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
18741982 // Create Xfer write Op
18751983 Value dest = rewriter.create <tensor::EmptyOp>(
18761984 loc, reifiedReturnShapes[0 ], padOp.getResultType ().getElementType ());
1877- Operation *write =
1878- createWriteOrMaskedWrite ( rewriter, loc, maskedRead, dest,
1879- /* inputVecSizesForLeadingDims=*/ inputVectorSizes,
1880- /* useInBoundsInsteadOfMasking=*/ false );
1985+ Operation *write = createWriteOrMaskedWrite (
1986+ rewriter, loc, maskedRead, dest,
1987+ /* inputVecSizesForLeadingDims=*/ inputVectorSizes, /* writeIndices=*/ {},
1988+ /* useInBoundsInsteadOfMasking=*/ false );
18811989 newResults.push_back (write->getResult (0 ));
18821990 return success ();
18831991}
@@ -2922,53 +3030,19 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
29223030 auto vecType = VectorType::get (vecShape, sourceType.getElementType ());
29233031
29243032 // 3. Generate TransferReadOp + TransferWriteOp
2925- ReifiedRankedShapedTypeDims reifiedSrcSizes;
2926- Value maskOp;
2927-
2928- // If vector sizes are user provided, make sure to mask. First, generate the
2929- // mask.
2930- if (!inputVectorSizes.empty ()) {
2931- auto *srcDefOp = source.getDefiningOp ();
2932- if (!srcDefOp) {
2933- LDBG (" Unable to get the defining Op of " << sliceOp);
2934- return failure ();
2935- }
2936-
2937- LogicalResult status =
2938- cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes (
2939- rewriter, reifiedSrcSizes);
2940- if (status.failed ()) {
2941- LDBG (" Unable to reify result shapes of " << srcDefOp);
2942- return failure ();
2943- }
2944-
2945- // Create the mask
2946- auto readMaskType = VectorType::get (inputVectorSizes, rewriter.getI1Type ());
2947- maskOp = rewriter.create <vector::CreateMaskOp>(
2948- sliceOp.getLoc (), readMaskType, reifiedSrcSizes[0 ]);
2949- }
3033+ auto loc = sliceOp.getLoc ();
29503034
3035+ // Create read
29513036 SmallVector<Value> readIndices (
2952- vecType.getRank (),
2953- rewriter.create <arith::ConstantIndexOp>(sliceOp.getLoc (), 0 ));
2954- Operation *read = rewriter.create <vector::TransferReadOp>(
2955- sliceOp.getLoc (), vecType, source, readIndices, padValue,
2956- ArrayRef<bool >{readInBounds});
2957-
2958- if (maskOp) {
2959- read = mlir::vector::maskOperation (rewriter, read, maskOp);
2960- }
2961-
2962- auto writeIndices = getValueOrCreateConstantIndexOp (
2963- rewriter, sliceOp.getLoc (), sliceOp.getMixedOffsets ());
2964-
2965- Operation *write = rewriter.create <vector::TransferWriteOp>(
2966- sliceOp.getLoc (), read->getResult (0 ), sliceOp.getDest (), writeIndices,
2967- ArrayRef<bool >{writeInBounds});
2968-
2969- if (maskOp) {
2970- write = mlir::vector::maskOperation (rewriter, write, maskOp);
2971- }
3037+ vecType.getRank (), rewriter.create <arith::ConstantIndexOp>(loc, 0 ));
3038+ Value read = mlir::vector::createReadOrMaskedRead (
3039+ rewriter, loc, source, vecType.getShape (), padValue);
3040+
3041+ // Create write
3042+ auto writeIndices =
3043+ getValueOrCreateConstantIndexOp (rewriter, loc, sliceOp.getMixedOffsets ());
3044+ Operation *write = createWriteOrMaskedWrite (
3045+ rewriter, loc, read, sliceOp.getDest (), vecType.getShape (), writeIndices);
29723046
29733047 // 4. Finalize
29743048 newResults.push_back (write->getResult (0 ));
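With this refactored path, constant slice offsets allow the write mask to fold away entirely. For instance, assuming an insert of vector<3x3xf32> at offsets [2, 2] into a statically shaped tensor<8x8xf32>, both conditions of isMaskTriviallyFoldable hold in each dimension (3 <= 8 and 2 + 3 <= 8), so createWriteOrMaskedWrite emits a plain vector.transfer_write with no vector.create_mask.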