@@ -1606,63 +1606,49 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///    %res = vector.transfer_write %vectorToStore into %dest
+///    %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///    %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///    %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///    %res = vector.mask %mask {
-///      vector.transfer_write %vectorToStore into %dest
+///      vector.transfer_write %vecToStore into %dest
 ///    }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///    %write = vector.transfer_write %vectorToStore into %dest
+///    %write = vector.transfer_write %vecToStore into %dest
 ///        in_bounds_flags = (...)
 ///    %res = vector.transfer_write %input into %dest
 ///        {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
     // FIXME: This computation is too weak - it ignores the write indices.
     for (unsigned i = 0; i < vecToStoreRank; i++)
       inBoundsVal[i] =
-          (destShape[i] >= inputVecSizesForLeadingDims[i]) &&
+          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
           !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
   }
 
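For context, here is a minimal sketch of the IR the `in_bounds` path is expected to produce; the shapes and value names are hypothetical, not taken from this patch. With `%vec : vector<4x8xf32>` and a static `%dest : tensor<4x6xf32>`, the second destination dim (6) is smaller than the vector dim (8), so its flag is `false` and out-of-bounds lanes are not written:

```mlir
// Hypothetical shapes; in_bounds = [true, false] because 4 >= 4 but 6 < 8.
%c0 = arith.constant 0 : index
%res = vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, false]}
    : vector<4x8xf32>, tensor<4x6xf32>
```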
@@ -1678,7 +1664,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                               /*vector=*/vectorToStore,
+                                               /*vector=*/vecToStore,
                                                /*source=*/dest,
                                                /*indices=*/writeIndices,
                                                /*inBounds=*/inBoundsVal);
@@ -1687,46 +1673,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
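To make the new masking path concrete, here is a minimal sketch of the IR this helper now emits for a dynamically shaped destination; all shapes and value names are hypothetical, not from the patch. The mask type matches the vector shape, and the mask sizes come from the trailing dims of `%dest`:

```mlir
// Hypothetical: %vec : vector<4x8xf32>, %dest : tensor<?x?xf32>.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = tensor.dim %dest, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %dest, %c1 : tensor<?x?xf32>
%mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
%res = vector.mask %mask {
  vector.transfer_write %vec, %dest[%c0, %c0]
      : vector<4x8xf32>, tensor<?x?xf32>
} : vector<4x8xi1> -> tensor<?x?xf32>
```

When the trailing dims of `%dest` are static and equal to the vector shape, the `llvm::equal` check (or `isMaskTriviallyFoldable`) short-circuits and the plain `vector.transfer_write` is returned unmasked.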
@@ -1826,9 +1791,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1966,7 +1930,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1999,9 +1962,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3043,9 +3004,9 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
+                               writeIndices, inputVectorSizes.empty());
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
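As a usage note: when `inputVectorSizes.empty()` holds, the slice vectorizer opts into the `in_bounds` path, so for fully static shapes the expected output is an unmasked write at the slice offsets. The sketch below uses hypothetical shapes and offset names; per the FIXME above, the flags ignore the write indices.

```mlir
// Hypothetical: %read : vector<4x8xf32> written into %dest : tensor<16x16xf32>
// at offsets (%off0, %off1) derived from sliceOp.getMixedOffsets().
%res = vector.transfer_write %read, %dest[%off0, %off1]
    {in_bounds = [true, true]} : vector<4x8xf32>, tensor<16x16xf32>
```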