@@ -1590,61 +1590,46 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
15901590// / Creates an optionally masked TransferWriteOp
15911591// /
15921592// / Generates the following operation:
1593- // / %res = vector.transfer_write %vectorToStore into %dest
1593+ // / %res = vector.transfer_write %vecToStore into %dest
15941594// /
1595- // / If the leading N dimensions of the vector to store do not match
1596- // / `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
1597- // / masking is applied to ensure correctness:
1595+ // / If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
15981596// /
1599- // / %mask = vector.create_mask(%destShape) : %vectorToStoreShape
1597+ // / %mask = vector.create_mask(%destShape) : %vecToStoreShape
16001598// / %res = vector.mask %mask {
1601- // / vector.transfer_write %vectorToStore into %dest
1599+ // / vector.transfer_write %vecToStore into %dest
16021600// / }
16031601// /
1604- // / The mask shape is identical to `vectorToStore ` (with the element type ==
1602+ // / The mask shape is identical to `vecToStore ` (with the element type ==
16051603// / i1), and the mask values are based on the shape of the `dest` tensor.
16061604// /
16071605// / If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
16081606// / is used instead of masking:
16091607// /
1610- // / %write = vector.transfer_write %vectorToStore into %dest
1608+ // / %write = vector.transfer_write %vecToStore into %dest
16111609// / in_bounds_flags = (...)
16121610// / %res = vector.transfer_write %input into %dest
16131611// / {in_bounds = in_bounds_flags}
16141612// /
1615- // / `writeIndices` specifies the offsets to use. If empty, all indices are set
1616- // / to 0.
1617- // /
1618- // / NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
1619- // / `valueToStore`.
1620- // / TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
1621- // / already provided in `vectorToStore`.
1613+ // / Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1614+ // / are set to 0.
16221615static Operation *
1623- createWriteOrMaskedWrite (OpBuilder &builder, Location loc, Value vectorToStore,
1624- Value dest,
1625- ArrayRef<int64_t > inputVecSizesForLeadingDims,
1626- SmallVector<Value> writeIndices = {},
1616+ createWriteOrMaskedWrite (OpBuilder &builder, Location loc, Value vecToStore,
1617+ Value dest, SmallVector<Value> writeIndices = {},
16271618 bool useInBoundsInsteadOfMasking = false ) {
16281619
16291620 ShapedType destType = cast<ShapedType>(dest.getType ());
16301621 int64_t destRank = destType.getRank ();
16311622 auto destShape = destType.getShape ();
16321623
1633- VectorType vecToStoreType = cast<VectorType>(vectorToStore .getType ());
1624+ VectorType vecToStoreType = cast<VectorType>(vecToStore .getType ());
16341625 int64_t vecToStoreRank = vecToStoreType.getRank ();
16351626 auto vecToStoreShape = vecToStoreType.getShape ();
16361627
16371628 // Compute the in_bounds attribute
16381629 SmallVector<bool > inBoundsVal (vecToStoreRank, true );
16391630 if (useInBoundsInsteadOfMasking) {
1640- // In this case, assume that all the required vector sizes have been
1641- // provided.
1642- assert (inputVecSizesForLeadingDims.size () ==
1643- static_cast <size_t >(vecToStoreType.getRank ()) &&
1644- " Insufficient number of input vector sizes!" );
1645- // Update the inBounds attribute.
16461631 for (unsigned i = 0 ; i < destRank; i++)
1647- inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims [i]) &&
1632+ inBoundsVal[i] = (destShape[i] == vecToStoreShape [i]) &&
16481633 !ShapedType::isDynamic (destShape[i]);
16491634 }
16501635
@@ -1660,7 +1645,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
16601645 // Generate the xfer_write Op
16611646 Operation *write =
16621647 builder.create <vector::TransferWriteOp>(loc,
1663- /* vector=*/ vectorToStore ,
1648+ /* vector=*/ vecToStore ,
16641649 /* source=*/ dest,
16651650 /* indices=*/ writeIndices,
16661651 /* inBounds=*/ inBoundsVal);
@@ -1669,46 +1654,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
16691654 if (useInBoundsInsteadOfMasking)
16701655 return write;
16711656
1672- assert (llvm::none_of (
1673- destShape.drop_front (inputVecSizesForLeadingDims.size ()),
1674- [](int64_t size) { return size == ShapedType::kDynamic ; }) &&
1675- " Only dims aligned with inputVecSizesForLeadingDims may be dynamic" );
1676-
1677- // Check if masking is needed.
1678- bool needMaskForWrite =
1679- !llvm::equal (inputVecSizesForLeadingDims,
1680- destShape.take_front (destRank - vecToStoreRank +
1681- inputVecSizesForLeadingDims.size ()));
1682-
1683- // If masking is needed, generate the mask and mask the operation.
1684- if (needMaskForWrite) {
1685- // Get the mask shape + type. Missing mask dimensions are taken from
1686- // `vectorToStore`.
1687- SmallVector<int64_t > writeMaskShape;
1688- writeMaskShape.append (inputVecSizesForLeadingDims.begin (),
1689- inputVecSizesForLeadingDims.end ());
1690- if (vecToStoreRank >
1691- static_cast <int64_t >(inputVecSizesForLeadingDims.size ()))
1692- writeMaskShape.append (vecToStoreShape.begin () +
1693- inputVecSizesForLeadingDims.size (),
1694- vecToStoreShape.end ());
1695- auto writeMaskType = VectorType::get (writeMaskShape, builder.getI1Type ());
1696-
1697- SmallVector<OpFoldResult> destSizes =
1698- tensor::getMixedSizes (builder, loc, dest);
1699- SmallVector<OpFoldResult> maskSizes (destSizes.end () - writeMaskShape.size (),
1700- destSizes.end ());
1701-
1702- if (isMaskTriviallyFoldable (maskSizes, writeIndices, destShape,
1703- writeMaskShape))
1704- return write;
1705-
1706- Value maskForWrite = builder.createOrFold <vector::CreateMaskOp>(
1707- loc, writeMaskType, maskSizes);
1708- write = mlir::vector::maskOperation (builder, write, maskForWrite);
1709- }
1657+ // Check if masking is needed. If not, exit.
1658+ if (llvm::equal (vecToStoreShape, destShape.take_back (vecToStoreRank)))
1659+ return write;
1660+
1661+ // Compute the mask and mask the write Op.
1662+ auto writeMaskType = VectorType::get (vecToStoreShape, builder.getI1Type ());
1663+
1664+ SmallVector<OpFoldResult> destSizes =
1665+ tensor::getMixedSizes (builder, loc, dest);
1666+ SmallVector<OpFoldResult> maskSizes (destSizes.end () - vecToStoreRank,
1667+ destSizes.end ());
1668+
1669+ if (isMaskTriviallyFoldable (maskSizes, writeIndices, destShape,
1670+ vecToStoreShape))
1671+ return write;
17101672
1711- return write;
1673+ Value maskForWrite =
1674+ builder.createOrFold <vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1675+ return mlir::vector::maskOperation (builder, write, maskForWrite);
17121676}
17131677
17141678// / Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1808,10 +1772,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18081772 Value dest = rewriter.create <tensor::EmptyOp>(
18091773 loc, reifiedReturnShapes[0 ],
18101774 transposeOp.getResult ().getType ().getElementType ());
1811- Operation *write = createWriteOrMaskedWrite (
1812- rewriter, loc, transposeOp.getResult (), dest,
1813- /* inputVecSizesForLeadingDims= */ inputVectorSizes, /* writeIndices=*/ {},
1814- /* useInBoundsInsteadOfMasking=*/ false );
1775+ Operation *write =
1776+ createWriteOrMaskedWrite ( rewriter, loc, transposeOp.getResult (), dest,
1777+ /* writeIndices=*/ {},
1778+ /* useInBoundsInsteadOfMasking=*/ false );
18151779 newResults.push_back (write->getResult (0 ));
18161780 return success ();
18171781}
@@ -1949,7 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19491913 shapeCastOp.getResult ().getType ().getElementType ());
19501914 Operation *write = createWriteOrMaskedWrite (
19511915 rewriter, loc, shapeCastOp.getResult (), dest,
1952- /* inputVecSizesForLeadingDims=*/ writeVectorSizes,
19531916 /* writeIndices=*/ {}, useInBoundsInsteadOfMasking);
19541917 newResults.push_back (write->getResult (0 ));
19551918 return success ();
@@ -1982,10 +1945,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
19821945 // Create Xfer write Op
19831946 Value dest = rewriter.create <tensor::EmptyOp>(
19841947 loc, reifiedReturnShapes[0 ], padOp.getResultType ().getElementType ());
1985- Operation *write = createWriteOrMaskedWrite (
1986- rewriter, loc, maskedRead, dest,
1987- /* inputVecSizesForLeadingDims=*/ inputVectorSizes, {},
1988- /* useInBoundsInsteadOfMasking=*/ false );
1948+ Operation *write =
1949+ createWriteOrMaskedWrite (rewriter, loc, maskedRead, dest, {},
1950+ /* useInBoundsInsteadOfMasking=*/ false );
19891951 newResults.push_back (write->getResult (0 ));
19901952 return success ();
19911953}
@@ -3041,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
30413003 // Create write
30423004 auto writeIndices =
30433005 getValueOrCreateConstantIndexOp (rewriter, loc, sliceOp.getMixedOffsets ());
3044- Operation *write = createWriteOrMaskedWrite (
3045- rewriter, loc, read, sliceOp.getDest (), vecType. getShape (), writeIndices);
3006+ Operation *write = createWriteOrMaskedWrite (rewriter, loc, read,
3007+ sliceOp.getDest (), writeIndices);
30463008
30473009 // 4. Finalize
30483010 newResults.push_back (write->getResult (0 ));
0 commit comments