@@ -1506,20 +1506,120 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }
 
+/// Determines whether a mask for xfer_write is trivially "all true".
+///
+/// Given all the inputs required to generate a mask (mask sizes and shapes),
+/// and an xfer_write operation (write indices and the destination tensor
+/// shape), determines whether the corresponding mask would be trivially
+/// foldable (i.e., trivially "all true").
+///
+/// Use this method to avoid generating spurious masks and relying on
+/// vectorization post-processing to remove them.
+///
+/// Pre-conditions for a mask to be trivially foldable:
+///   * All involved shapes (mask + destination tensor) are static.
+///   * All write indices are constant.
+///   * All mask sizes are constant (including `arith.constant`).
+///
+/// If the pre-conditions are met, the method checks, for each destination
+/// dimension `d`, that:
+///   (1) maskShape[d] <= destDimSize[rankDiff + d]
+///   (2) writeIndex[d] + maskSize[d] <= destDimSize[rankDiff + d]
+/// where rankDiff = rank(dest) - rank(mask).
+///
+/// This method takes a conservative view: it may return false even if the
+/// mask is technically foldable.
+///
+/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the
+/// shape of the dest tensor):
+///    %c0 = arith.constant 0 : index
+///    %mask = vector.create_mask 5, 1
+///    vector.mask %mask {
+///      vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
+///        {in_bounds = [true, true]}
+///        : vector<5x1xi32>, tensor<5x1xi32>
+///    }
+///
+/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor
+/// shape, mask is required to avoid an out-of-bounds write):
+///    %c0 = arith.constant 0 : index
+///    %mask = vector.create_mask 5, 1
+///    vector.mask %mask {
+///      vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
+///        {in_bounds = [true, true]}
+///        : vector<8x1xi32>, tensor<5x1xi32>
+///    }
+///
+/// TODO: Re-use in createReadOrMaskedRead.
+static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
+                                    SmallVector<Value> &writeIdxs,
+                                    ArrayRef<int64_t> destShape,
+                                    ArrayRef<int64_t> maskShape) {
+  // Masking is unavoidable in the case of dynamic tensors.
+  if (ShapedType::isDynamicShape(destShape))
+    return false;
+
+  // Collect all constant mask sizes.
+  SmallVector<int64_t, 4> cstMaskSizes;
+  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
+    if (auto intSize = getConstantIntValue(dimSize)) {
+      cstMaskSizes.push_back(*intSize);
+    }
+  }
+
+  // If any of the mask sizes is non-constant, bail out.
+  if (cstMaskSizes.size() != maskShape.size())
+    return false;
+
+  // Collect all constant write indices.
+  SmallVector<int64_t, 4> cstWriteIdxs;
+  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
+    APSInt intVal;
+    if (matchPattern(idx, m_ConstantInt(&intVal))) {
+      cstWriteIdxs.push_back(intVal.getSExtValue());
+    }
+  }
+
+  // If any of the write indices is non-constant, bail out.
+  if (cstWriteIdxs.size() != destShape.size())
+    return false;
+
+  // Go over all destination dims and check (1) and (2). Take into account
+  // that:
+  //  * The number of mask sizes will match the rank of the vector to store.
+  //    This could be lower than the rank of the destination tensor.
+  //  * Mask sizes could be larger than the corresponding mask shape (hence
+  //    `clamp`).
+  // TODO: The 2nd item should be rejected by the verifier.
+  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
+  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
+    if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
+        /*(2)*/ destShape[rankDiff + i] <
+                    (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
+                     cstWriteIdxs[i]))
+      return false;
+  }
+
+  return true;
+}
+
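To make conditions (1) and (2) concrete, consider a third case that the
examples above do not cover: a mask that still folds away at a non-zero write
index. This is an illustrative sketch (the SSA names are hypothetical, not
taken from the patch). With destDimSize = (8, 1), maskShape = (4, 1),
maskSize = (4, 1) and writeIndex = (4, 0), both (1) 4 <= 8 and (2) 4 + 4 <= 8
hold, so the mask below would be elided:

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c4 = arith.constant 4 : index
    %mask = vector.create_mask %c4, %c1 : vector<4x1xi1>
    vector.mask %mask {
      vector.transfer_write %vecToStore, %dest[%c4, %c0]
          {in_bounds = [true, true]}
          : vector<4x1xi32>, tensor<8x1xi32>
    }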
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
 ///   %res = vector.transfer_write %vectorToStore into %dest
 ///
-/// If the leading N dimensions of the destination tensor do not match
+/// If the leading N dimensions of the vector to store do not match
 /// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
 /// masking is applied to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape)
+///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
 ///   %res = vector.mask %mask {
 ///     vector.transfer_write %vectorToStore into %dest
 ///   }
 ///
+/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// i1), and the mask values are based on the shape of the `dest` tensor.
+///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
@@ -1528,75 +1628,99 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
 ///   %res = vector.transfer_write %input into %dest
 ///     {in_bounds = in_bounds_flags}
 ///
-/// NOTE: All write offsets are set to 0.
-/// TODO: Allow specyfying write offsets.
-/// NOTE: When N < rank(input), the missing vector sizes are effectively
-/// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static.
-/// TODO: Support cases where an arbitrary dim is dynamic - this will require
-/// specifying all the vector sizes.
+/// `writeIndices` specifies the offsets to use. If empty, all indices are set
+/// to 0.
+///
+/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
+/// `vectorToStore`.
+/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes
+/// are already provided in `vectorToStore`.
 static Operation *
 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
                          Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
+                         SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
-  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
-             static_cast<int64_t>(destType.getRank()) &&
-         "Rank mismatch!");
-  (void)destType;
+  int64_t destRank = destType.getRank();
+  auto destShape = destType.getShape();
 
-  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  int64_t vecToStoreRank = vecToStoreType.getRank();
+  auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
-  SmallVector<bool> inBoundsVal(rank, true);
+  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
     assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(destType.getRank()) &&
+               static_cast<size_t>(vecToStoreType.getRank()) &&
            "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
-    for (unsigned i = 0; i < rank; i++)
+    for (unsigned i = 0; i < destRank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
+  // If missing, initialize the write indices to 0.
+  assert((writeIndices.empty() ||
+          writeIndices.size() == static_cast<size_t>(destRank)) &&
+         "Invalid number of write indices!");
+  if (writeIndices.empty()) {
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    writeIndices = SmallVector<Value>(destRank, zero);
+  }
+
   // Generate the xfer_write Op
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  Operation *write = builder.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/vectorToStore,
-      /*source=*/dest,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/inBoundsVal);
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+  Operation *write =
+      builder.create<vector::TransferWriteOp>(loc,
+                                              /*vector=*/vectorToStore,
+                                              /*source=*/dest,
+                                              /*indices=*/writeIndices,
+                                              /*inBounds=*/inBoundsVal);
 
   // If masking is disabled, exit.
   if (useInBoundsInsteadOfMasking)
     return write;
 
+  assert(llvm::none_of(
+             destShape.drop_front(inputVecSizesForLeadingDims.size()),
+             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
   // Check if masking is needed.
   bool needMaskForWrite =
       !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(inputVecSizesForLeadingDims.size()));
+                   destShape.take_front(destRank - vecToStoreRank +
+                                        inputVecSizesForLeadingDims.size()));
 
   // If masking is needed, generate the mask and mask the operation.
   if (needMaskForWrite) {
+    // Get the mask shape + type. Missing mask dimensions are taken from
+    // `vectorToStore`.
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
                           inputVecSizesForLeadingDims.end());
-    writeMaskShape.append(destShape.begin() +
-                              inputVecSizesForLeadingDims.size(),
-                          destShape.end());
+    if (vecToStoreRank >
+        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
+      writeMaskShape.append(vecToStoreShape.begin() +
+                                inputVecSizesForLeadingDims.size(),
+                            vecToStoreShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-    Value maskForWrite = builder.create<vector::CreateMaskOp>(
-        loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
+
+    SmallVector<OpFoldResult> destSizes =
+        tensor::getMixedSizes(builder, loc, dest);
+    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
+                                        destSizes.end());
+
+    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                                writeMaskShape))
+      return write;
+
+    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
+        loc, writeMaskType, maskSizes);
     write = mlir::vector::maskOperation(builder, write, maskForWrite);
   }
 
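For illustration, assume a destination of type tensor<?x16xf32>,
inputVecSizesForLeadingDims = [8], a vector<8x16xf32> to store, and write
indices left empty (value names below are hypothetical). The dynamic leading
dimension defeats isMaskTriviallyFoldable, so the helper emits a masked write
along these lines (exact constant materialization may differ):

    %c0 = arith.constant 0 : index
    %c16 = arith.constant 16 : index
    %dim0 = tensor.dim %dest, %c0 : tensor<?x16xf32>
    %mask = vector.create_mask %dim0, %c16 : vector<8x16xi1>
    %res = vector.mask %mask {
      vector.transfer_write %vec, %dest[%c0, %c0]
          {in_bounds = [true, true]}
          : vector<8x16xf32>, tensor<?x16xf32>
    } : vector<8x16xi1> -> tensor<?x16xf32>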
@@ -1700,10 +1824,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1839,10 +1963,10 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedRetShapes[0],
       shapeCastOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/writeVectorSizes,
-                               useInBoundsInsteadOfMasking);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1874,10 +1998,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, maskedRead, dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -2922,53 +3046,19 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
   // 3. Generate TransferReadOp + TransferWriteOp
-  ReifiedRankedShapedTypeDims reifiedSrcSizes;
-  Value maskOp;
-
-  // If vector sizes are user provided, make sure to mask. First, generate the
-  // mask.
-  if (!inputVectorSizes.empty()) {
-    auto *srcDefOp = source.getDefiningOp();
-    if (!srcDefOp) {
-      LDBG("Unable to get the defining Op of " << sliceOp);
-      return failure();
-    }
-
-    LogicalResult status =
-        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
-            rewriter, reifiedSrcSizes);
-    if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << srcDefOp);
-      return failure();
-    }
-
-    // Create the mask
-    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    maskOp = rewriter.create<vector::CreateMaskOp>(
-        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-  }
+  auto loc = sliceOp.getLoc();
 
+  // Create read
   SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
-
-  if (maskOp) {
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
-  }
-
-  auto writeIndices = getValueOrCreateConstantIndexOp(
-      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
-
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
-      ArrayRef<bool>{writeInBounds});
-
-  if (maskOp) {
-    write = mlir::vector::maskOperation(rewriter, write, maskOp);
-  }
+      vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
+  Value read = mlir::vector::createReadOrMaskedRead(
+      rewriter, loc, source, vecType.getShape(), padValue);
+
+  // Create write
+  auto writeIndices =
+      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
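As a worked example of this new path (a sketch with illustrative names and
static shapes), vectorizing

    %r = tensor.insert_slice %src into %dest[0, 2] [5, 1] [1, 1]
        : tensor<5x1xi32> into tensor<5x3xi32>

now yields a plain read/write pair: the write mask would have shape (5, 1),
sizes (5, 3) and indices (0, 2), which satisfies both foldability conditions
(5 <= 5 and 0 + 5 <= 5; 1 <= 3 and 2 + 1 <= 3, with the size clamped to the
mask shape), so no vector.mask is emitted:

    %c0 = arith.constant 0 : index
    %c2 = arith.constant 2 : index
    %pad = arith.constant 0 : i32
    %read = vector.transfer_read %src[%c0, %c0], %pad
        {in_bounds = [true, true]} : tensor<5x1xi32>, vector<5x1xi32>
    %r = vector.transfer_write %read, %dest[%c0, %c2]
        {in_bounds = [true, true]} : vector<5x1xi32>, tensor<5x3xi32>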