
Commit f215a61

[mlir][linalg][vector] Refine create{Read|Write}OrMasked{Read|Write} (nfc) (llvm#135350)
The semantics of `createReadOrMaskedRead` and `createWriteOrMaskedWrite` are currently a bit inconsistent and not fully documented:

* The input vector sizes are passed as `readShape` and `inputVectorSizes`, respectively, i.e. the naming is inconsistent.
* Currently, the input vector sizes in `createWriteOrMaskedWrite` are not required to be complete: any missing trailing sizes are inferred from the destination tensor. This only works when the destination tensor is statically shaped.
* Unlike `createReadOrMaskedRead`, the documentation for `createWriteOrMaskedWrite` does not specify that write offsets are hard-coded to 0.

This PR only updates the documentation and unifies the naming. As such, it is NFC.

A follow-up PR will generalize and unify the implementation to support, for example, dynamically shaped destination tensors, a requirement for enabling scalable vectorization of `linalg.pack` and `linalg.unpack`.
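To make the two behaviors concrete, here is a hand-written MLIR sketch (hypothetical shapes and values, not taken from this patch; `%vec` and `%vec2` are assumed inputs) contrasting a masked write with an `in_bounds`-annotated write:

```mlir
// Masked variant: a mask built from the destination sizes guards the
// out-of-bounds lanes of the (larger) input vector. All write offsets are 0.
%c0 = arith.constant 0 : index
%c5 = arith.constant 5 : index
%c3 = arith.constant 3 : index
%dest = tensor.empty() : tensor<5x3xf32>
%mask = vector.create_mask %c5, %c3 : vector<8x4xi1>
%res = vector.mask %mask {
  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]}
      : vector<8x4xf32>, tensor<5x3xf32>
} : vector<8x4xi1> -> tensor<5x3xf32>

// `in_bounds` variant: no mask is created; the flags record which dims are
// statically known to stay within the destination bounds.
%res2 = vector.transfer_write %vec2, %dest[%c0, %c0]
    {in_bounds = [true, true]} : vector<5x3xf32>, tensor<5x3xf32>
```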
1 parent 4780658 commit f215a61

File tree

3 files changed: +103 -51 lines changed


mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 11 additions & 14 deletions
@@ -211,21 +211,18 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
 /// are not linearizable.
 bool isLinearizableVector(VectorType type);
 
-/// Create a TransferReadOp from `source` with static shape `readShape`. If the
-/// vector type for the read is not the same as the type of `source`, then a
-/// mask is created on the read, if use of mask is specified or the bounds on a
-/// dimension are different.
-///
-/// `useInBoundsInsteadOfMasking` if false, the inBoundsVal values are set
-/// properly, based on
-/// the rank dimensions of the source and destination tensors. And that is
-/// what determines if masking is done.
-///
-/// Note that the internal `vector::TransferReadOp` always read at indices zero
-/// for each dimension of the passed in tensor.
+/// Creates a TransferReadOp from `source`.
+///
+/// The shape of the vector to read is specified via `inputVectorSizes`. If the
+/// shape of the output vector differs from the shape of the value being read,
+/// masking is used to avoid out-of-bounds accesses. Set
+/// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute
+/// instead of explicit masks.
+///
+/// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
-                             ArrayRef<int64_t> readShape, Value padValue,
-                             bool useInBoundsInsteadOfMasking);
+                             ArrayRef<int64_t> inputVectorSizes, Value padValue,
+                             bool useInBoundsInsteadOfMasking = false);
 
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
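For illustration, the masked path of `createReadOrMaskedRead` produces IR along these lines (a sketch with hypothetical shapes; `%src` and `%pad` are assumed to be given). Note that, per the new doc comment, all read offsets are 0:

```mlir
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%d0 = tensor.dim %src, %c0 : tensor<?x16xf32>
%mask = vector.create_mask %d0, %c16 : vector<8x16xi1>
%read = vector.mask %mask {
  vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
      : tensor<?x16xf32>, vector<8x16xf32>
} : vector<8x16xi1> -> vector<8x16xf32>
```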

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 81 additions & 28 deletions
@@ -1506,29 +1506,67 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }
 
-/// Given an input, the mixed destSizes, and the vector sizes for vectorization,
-/// create an empty destination tensor and create a TransferWriteOp from the
-/// input to the empty tensor. If the destination shape is not the same as the
-/// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
-/// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
-/// inBounds attribute of the transfer write op instead of masking.
-static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
-                                           Value input,
-                                           SmallVector<OpFoldResult> destSizes,
-                                           ArrayRef<int64_t> inputVectorSizes,
-                                           bool useInBoundsInsteadOfMasking) {
+/// Creates a TransferWriteOp to write `input` into a newly initialized
+/// output tensor.
+///
+/// Given:
+/// - an input vector to write,
+/// - the mixed destination sizes for the output tensor,
+/// - and the vector sizes used for vectorization (i.e., the leading N dims,
+///   for some value of N),
+///
+/// this function generates the following sequence of ops:
+///
+///   %dest = tensor.empty(%destSizes)
+///   %res = vector.transfer_write %input into %dest
+///
+/// If the leading N dimensions of the destination tensor do not match
+/// `inputVecSizesForLeadingDims` (where N =
+/// rank(`inputVecSizesForLeadingDims`)), masking is applied to ensure
+/// correctness:
+///
+///   %dest = tensor.empty(%destSizes)
+///   %write = vector.transfer_write %input into %dest
+///   %mask = vector.create_mask(%destSizes)
+///   %res = vector.mask %mask { %write }
+///
+/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
+/// is used instead of masking:
+///
+///   %dest = tensor.empty(%destSizes)
+///   in_bounds_flags = (...)
+///   %res = vector.transfer_write %input into %dest
+///       {in_bounds = in_bounds_flags}
+///
+/// NOTE: all write offsets are set to 0.
+/// NOTE: When N < rank(input), the missing vector sizes are effectively
+/// extracted from the trailing sizes of `destSizes`. This means those sizes
+/// must be static. Supporting dynamic sizes will require the user to specify
+/// the remaining vector sizes. This is left as a TODO.
+static Operation *
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
+                         SmallVector<OpFoldResult> destSizes,
+                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
+                         bool useInBoundsInsteadOfMasking = false) {
 
   auto inputType = cast<VectorType>(input.getType());
+  assert(inputType.getRank() == static_cast<int64_t>(destSizes.size()) &&
+         "Rank mismatch!");
+
   Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                                inputType.getElementType());
   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   auto destShape = cast<ShapedType>(dest.getType()).getShape();
   SmallVector<bool> inBoundsVal(rank, true);
   if (useInBoundsInsteadOfMasking) {
+    // In this case, assume that all the required vector sizes have been
+    // provided.
+    assert(inputVecSizesForLeadingDims.size() == destSizes.size() &&
+           "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
     for (unsigned i = 0; i < rank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
+      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
   Operation *write = builder.create<vector::TransferWriteOp>(
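A concrete instance of the masked sequence documented above might look like this (a sketch; the dynamic size `%d0` and the input vector `%input` are hypothetical). The dynamic dim 0 is covered by the provided vector sizes, as the assertion below requires:

```mlir
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%dest = tensor.empty(%d0) : tensor<?x16xf32>
%mask = vector.create_mask %d0, %c16 : vector<4x16xi1>
%res = vector.mask %mask {
  vector.transfer_write %input, %dest[%c0, %c0] {in_bounds = [true, true]}
      : vector<4x16xf32>, tensor<?x16xf32>
} : vector<4x16xi1> -> tensor<?x16xf32>
```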
@@ -1538,17 +1576,20 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
       /*indices=*/SmallVector<Value>(rank, zero),
       /*inBounds=*/inBoundsVal);
   assert(llvm::none_of(
-             destShape.drop_front(inputVectorSizes.size()),
+             destShape.drop_front(inputVecSizesForLeadingDims.size()),
              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVectorSizes may be dynamic");
+         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
   if (useInBoundsInsteadOfMasking)
     return write;
-  bool needMaskForWrite = !llvm::equal(
-      inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
+  bool needMaskForWrite =
+      !llvm::equal(inputVecSizesForLeadingDims,
+                   destShape.take_front(inputVecSizesForLeadingDims.size()));
   if (needMaskForWrite) {
     SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
-    writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
+    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
+                          inputVecSizesForLeadingDims.end());
+    writeMaskShape.append(destShape.begin() +
+                              inputVecSizesForLeadingDims.size(),
                           destShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
     Value maskForWrite =
@@ -1558,9 +1599,11 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
   return write;
 }
 
-/// Vectorize linalg::PackOp with (1) static innerTiles (2) constant
+/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
 /// padding value and (3) input vector sizes into:
-/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
+///
+///   masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
+///
 /// As in the following example:
 ///   %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
 ///       into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
@@ -1582,8 +1625,14 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
 ///       : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
 ///
 /// If the (3) input vector sizes are not provided, the vector sizes are
-/// determined by the result tensor shape. Also, we update the inBounds
-/// attribute instead of masking.
+/// determined by the result tensor shape and the `in_bounds`
+/// attribute is used instead of masking to mark out-of-bounds accesses.
+///
+/// NOTE: The input vector sizes specify the dimensions corresponding to the
+/// outer dimensions of the output tensor. The remaining dimensions are
+/// computed based on, e.g., the static inner tiles.
+/// Supporting dynamic inner tiles will require the user to specify the
+/// missing vector sizes. This is left as a TODO.
 static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
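For the `tensor.pack` example in the doc comment above, the caller would pass vector sizes only for the outer dimensions, say [32, 4, 1], and the helper completes them from the static inner tiles [16, 2]. A sketch of the resulting write (value names are illustrative):

```mlir
// Leading (user-provided) vector sizes:            [32, 4, 1]
// Trailing sizes taken from the static inner tiles: [16, 2]
%res = vector.transfer_write %transposed, %dest[%c0, %c0, %c0, %c0, %c0]
    {in_bounds = [true, true, true, true, true]}
    : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
```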
@@ -1644,9 +1693,11 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
       loc, shapeCastOp.getResult(), destPermutation);
 
   // Create TransferWriteOp.
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
-      inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
+                               /*destSizes=*/reifiedReturnShapes[0],
+                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1780,8 +1831,9 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
           ? vectorSizes
           : shapeCastOp.getResultVectorType().getShape());
   Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
-      writeVectorSizes, useInBoundsInsteadOfMasking);
+      rewriter, loc, shapeCastOp.getResult(), /*destSizes=*/reifiedRetShapes[0],
+      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+      useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1810,7 +1862,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
   Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
+      rewriter, loc, maskedRead, reifiedReturnShapes[0],
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes,
       /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
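The pad path above composes the two helpers: a masked read of the source, padded with the pad value, feeding a masked write into the reified destination. A sketch of the resulting chain (hypothetical masks, shapes, and value names):

```mlir
// Read from the (possibly smaller) source, padding out-of-bounds lanes.
%read = vector.mask %read_mask {
  vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
      : tensor<?x?xf32>, vector<8x16xf32>
} : vector<8x16xi1> -> vector<8x16xf32>
// Write the padded vector into the reified destination tensor.
%res = vector.mask %write_mask {
  vector.transfer_write %read, %init[%c0, %c0] {in_bounds = [true, true]}
      : vector<8x16xf32>, tensor<?x?xf32>
} : vector<8x16xi1> -> tensor<?x?xf32>
```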

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 11 additions & 9 deletions
@@ -327,26 +327,28 @@ bool vector::isLinearizableVector(VectorType type) {
 }
 
 Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
-                                     Value source, ArrayRef<int64_t> readShape,
+                                     Value source,
+                                     ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
                                      bool useInBoundsInsteadOfMasking) {
-  assert(llvm::none_of(readShape,
+  assert(llvm::none_of(inputVectorSizes,
                        [](int64_t s) { return s == ShapedType::kDynamic; }) &&
-         "expected static shape");
+         "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
-  assert(sourceShape.size() == readShape.size() && "expected same ranks.");
-  auto maskType = VectorType::get(readShape, builder.getI1Type());
-  auto vectorType = VectorType::get(readShape, padValue.getType());
+  assert(sourceShape.size() == inputVectorSizes.size() &&
+         "expected same ranks.");
+  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
-  int64_t readRank = readShape.size();
+  int64_t readRank = inputVectorSizes.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   SmallVector<bool> inBoundsVal(readRank, true);
   if (useInBoundsInsteadOfMasking) {
     // Update the inBounds attribute.
     for (unsigned i = 0; i < readRank; i++)
-      inBoundsVal[i] = (sourceShape[i] == readShape[i]) &&
+      inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
                        !ShapedType::isDynamic(sourceShape[i]);
   }
   auto transferReadOp = builder.create<vector::TransferReadOp>(
@@ -357,7 +359,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
       /*padding=*/padValue,
       /*inBounds=*/inBoundsVal);
 
-  if (llvm::equal(readShape, sourceShape) || useInBoundsInsteadOfMasking)
+  if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
       tensor::getMixedSizes(builder, loc, source);
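With `useInBoundsInsteadOfMasking` set, the code above emits no mask; dimensions where the static source size differs from the requested vector size are simply marked out-of-bounds, so the padding value fills the excess lanes. A sketch with hypothetical shapes:

```mlir
// Source dim 0 (4) < vector size (8): marked out-of-bounds, padded with %pad.
%read = vector.transfer_read %src[%c0, %c0], %pad
    {in_bounds = [false, true]} : tensor<4x16xf32>, vector<8x16xf32>
```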
