@@ -1506,29 +1506,67 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }
 
-/// Given an input, the mixed destSizes, and the vector sizes for vectorization,
-/// create an empty destination tensor and create a TransferWriteOp from the
-/// input to the empty tensor. If the destination shape is not the same as the
-/// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
-/// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
-/// inBounds attribute of the transfer write op instead of masking.
-static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
-                                           Value input,
-                                           SmallVector<OpFoldResult> destSizes,
-                                           ArrayRef<int64_t> inputVectorSizes,
-                                           bool useInBoundsInsteadOfMasking) {
+/// Creates a TransferWriteOp to write `input` into a newly initialized
+/// output tensor.
+///
+/// Given:
+/// - an input vector to write,
+/// - the mixed destination sizes for the output tensor,
+/// - and the vector sizes used for vectorization (i.e., the leading N dims,
+///   for some value of N),
+///
+/// this function generates the following sequence of ops:
+///
+///   %dest = tensor.empty(%destSizes)
+///   %res = vector.transfer_write %input into %dest
+///
+/// If the leading N dimensions of the destination tensor do not match
+/// `inputVecSizesForLeadingDims` (where N =
+/// rank(`inputVecSizesForLeadingDims`)), masking is applied to ensure
+/// correctness:
+///
+///   %dest = tensor.empty(%destSizes)
+///   %write = vector.transfer_write %input into %dest
+///   %mask = vector.create_mask(%destSizes)
+///   %res = vector.mask %mask { %write }
+///
+/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds`
+/// attribute is used instead of masking:
+///
+///   %dest = tensor.empty(%destSizes)
+///   in_bounds_flags = (...)
+///   %res = vector.transfer_write %input into %dest
+///       {in_bounds = in_bounds_flags}
+///
+/// NOTE: All write offsets are set to 0.
+/// NOTE: When N < rank(input), the missing vector sizes are effectively
+/// extracted from the trailing sizes of `destSizes`. This means those sizes
+/// must be static. Supporting dynamic sizes will require the user to specify
+/// the remaining vector sizes. This is left as a TODO.
+static Operation *
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
+                         SmallVector<OpFoldResult> destSizes,
+                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
+                         bool useInBoundsInsteadOfMasking = false) {
 
   auto inputType = cast<VectorType>(input.getType());
+  assert(inputType.getRank() == static_cast<int64_t>(destSizes.size()) &&
+         "Rank mismatch!");
+
   Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                                inputType.getElementType());
   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   auto destShape = cast<ShapedType>(dest.getType()).getShape();
   SmallVector<bool> inBoundsVal(rank, true);
   if (useInBoundsInsteadOfMasking) {
+    // In this case, assume that all the required vector sizes have been
+    // provided.
+    assert(inputVecSizesForLeadingDims.size() == destSizes.size() &&
+           "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
     for (unsigned i = 0; i < rank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
+      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
   Operation *write = builder.create<vector::TransferWriteOp>(
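To make the masked path described in the new comment concrete, here is a minimal hand-written MLIR sketch of what the helper emits when the destination's leading dim is dynamic; the shapes, %input, and %d0 are invented for illustration and are not taken from the patch:

    %c0 = arith.constant 0 : index
    %c16 = arith.constant 16 : index
    %dest = tensor.empty(%d0) : tensor<?x16xf32>
    // Rows at or beyond %d0 are masked off; the trailing dim (16) is
    // fully in bounds.
    %mask = vector.create_mask %d0, %c16 : vector<8x16xi1>
    %res = vector.mask %mask {
      vector.transfer_write %input, %dest[%c0, %c0]
          : vector<8x16xf32>, tensor<?x16xf32>
    } : vector<8x16xi1> -> tensor<?x16xf32>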
@@ -1538,17 +1576,20 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
       /*indices=*/SmallVector<Value>(rank, zero),
       /*inBounds=*/inBoundsVal);
   assert(llvm::none_of(
-             destShape.drop_front(inputVectorSizes.size()),
+             destShape.drop_front(inputVecSizesForLeadingDims.size()),
              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVectorSizes may be dynamic");
+         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
   if (useInBoundsInsteadOfMasking)
     return write;
-  bool needMaskForWrite = !llvm::equal(
-      inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
+  bool needMaskForWrite =
+      !llvm::equal(inputVecSizesForLeadingDims,
+                   destShape.take_front(inputVecSizesForLeadingDims.size()));
   if (needMaskForWrite) {
     SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
-    writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
+    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
+                          inputVecSizesForLeadingDims.end());
+    writeMaskShape.append(destShape.begin() +
+                              inputVecSizesForLeadingDims.size(),
                           destShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
     Value maskForWrite =
@@ -1558,9 +1599,11 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
   return write;
 }
 
-/// Vectorize linalg::PackOp with (1) static innerTiles (2) constant
+/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
 /// padding value and (3) input vector sizes into:
-/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
+///
+///   masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
+///
 /// As in the following example:
 ///   %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
 ///       into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
@@ -1582,8 +1625,14 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
 ///       : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
 ///
 /// If the (3) input vector sizes are not provided, the vector sizes are
-/// determined by the result tensor shape. Also, we update the inBounds
-/// attribute instead of masking.
+/// determined by the result tensor shape and the `in_bounds` attribute is
+/// used instead of masking to mark out-of-bounds accesses.
+///
+/// NOTE: The input vector sizes specify the dimensions corresponding to the
+/// outer dimensions of the output tensor. The remaining dimensions are
+/// computed based on, e.g., the static inner tiles.
+/// Supporting dynamic inner tiles will require the user to specify the
+/// missing vector sizes. This is left as a TODO.
 static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
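As a usage sketch for the NOTE above: when this path is driven via the transform dialect, the user supplies vector sizes only for the outer dimensions of the packed tensor. For the tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32> example in the comment, that might look as follows; %pack is assumed to be a handle to the tensor.pack op, and the op syntax comes from the transform dialect, not this patch:

    // Vector sizes cover the three outer dims (32, 4, 1); the trailing
    // inner-tile dims (16, 2) are inferred from the static inner_tiles.
    transform.structured.vectorize %pack vector_sizes [32, 4, 1]
        : !transform.any_op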
@@ -1644,9 +1693,11 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
       loc, shapeCastOp.getResult(), destPermutation);
 
   // Create TransferWriteOp.
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
-      inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
+                               /*destSizes=*/reifiedReturnShapes[0],
+                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1780,8 +1831,9 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
           ? vectorSizes
           : shapeCastOp.getResultVectorType().getShape());
   Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
-      writeVectorSizes, useInBoundsInsteadOfMasking);
+      rewriter, loc, shapeCastOp.getResult(), /*destSizes=*/reifiedRetShapes[0],
+      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+      useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
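For the useInBoundsInsteadOfMasking path taken here when the write vector sizes are inferred from the result shape, no mask op is emitted; the write is simply annotated. A minimal sketch, with %vec and the static 4x16 shape invented for illustration:

    %c0 = arith.constant 0 : index
    %dest = tensor.empty() : tensor<4x16xf32>
    // Every destination dim is static and equal to the corresponding
    // vector size, so all dims are marked in-bounds.
    %res = vector.transfer_write %vec, %dest[%c0, %c0]
        {in_bounds = [true, true]} : vector<4x16xf32>, tensor<4x16xf32>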
@@ -1810,7 +1862,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
   Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
+      rewriter, loc, maskedRead, reifiedReturnShapes[0],
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes,
       /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
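In the pad case the two helpers compose: a masked read of the unpadded source (using the pad value as the transfer_read padding) feeds the write created above. A hypothetical tensor<5x?xf32> padded to tensor<8x16xf32> would vectorize roughly as follows; %src, %pv, and the dynamic size %d1 are assumed SSA values:

    %c0 = arith.constant 0 : index
    %c5 = arith.constant 5 : index
    %mask = vector.create_mask %c5, %d1 : vector<8x16xi1>
    // Out-of-bounds lanes read the pad value %pv.
    %read = vector.mask %mask {
      vector.transfer_read %src[%c0, %c0], %pv
          : tensor<5x?xf32>, vector<8x16xf32>
    } : vector<8x16xi1> -> vector<8x16xf32>
    %dest = tensor.empty() : tensor<8x16xf32>
    %res = vector.transfer_write %read, %dest[%c0, %c0]
        {in_bounds = [true, true]} : vector<8x16xf32>, tensor<8x16xf32>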