Commit 075a71d

[mlir][linalg] Enable scalable vectorization of linalg.unpack (WIP)
This patch updates `vectorizeAsTensorUnpackOp` to support scalable vectorization by requiring user-specified vector sizes for both the _read_ and _write_ operations involved in `linalg.unpack`. Detailed rationale and an example are provided below.

Conceptually, `linalg.unpack` consists of the following high-level steps:

1. _Read_ from the source tensor.
2. Transpose the value read in step (1).
3. _Write_ the value from step (2) into the destination tensor.

Currently, when vectorizing with user-provided vector sizes, only the sizes for the _write_ operation (step 3) are required. Sizes for the _read_ operation (step 1) are inferred from static shapes and inner tile sizes. This logic breaks when the input shapes or tile sizes are dynamic (indeed, `vectorizeUnPackOpPrecondition` rejects such cases ATM and the vectorization fails). This patch addresses the issue by requiring explicit vector sizes for both the read and write sides, enabling scalable vectorization in such cases.

Example:

```mlir
func.func @unpack(%in: tensor<1x1x8x?xf32>, %out: tensor<8x?xf32>) -> tensor<8x?xf32> {
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %tile_size = arith.muli %vs, %c8 : index

  %unpack = linalg.unpack %in
    inner_dims_pos = [0, 1]
    inner_tiles = [8, %tile_size]
    into %out : tensor<1x1x8x?xf32> -> tensor<8x?xf32>
  return %unpack : tensor<8x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [1, 1, 8, [8], 8, [8]] : !transform.any_op
    //                                              \          /  \     /
    //                                               read-sizes    write-sizes
    transform.yield
  }
}
```

Finally, this patch also extends `createReadOrMaskedRead` and `createWriteOrMaskedWrite` to take scalable flags.
1 parent 33fae27 commit 075a71d
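The new calling convention concatenates the read sizes (one per source dim) and write sizes (one per dest dim) into a single `vector_sizes` list, which the patch then splits. A minimal Python sketch of that split (hypothetical helper name `split_vector_sizes`; the real logic lives in `vectorizeAsTensorUnpackOp`):

```python
def split_vector_sizes(input_sizes, input_scalable, source_rank, dest_rank):
    """Split combined user-provided vector sizes/scalable flags into the
    read (source-side) and write (dest-side) halves."""
    # The combined list must cover both the read and the write operation.
    assert len(input_sizes) == source_rank + dest_rank, \
        "Incorrect number of input vector sizes"
    read_sizes = input_sizes[:source_rank]
    write_sizes = input_sizes[source_rank:]
    if input_scalable:
        read_scalable = input_scalable[:source_rank]
        write_scalable = input_scalable[source_rank:]
    else:
        # No scalable flags provided: all dims are fixed-width.
        read_scalable = [False] * source_rank
        write_scalable = [False] * dest_rank
    return (read_sizes, read_scalable), (write_sizes, write_scalable)


# For the example above, `vector_sizes [1, 1, 8, [8], 8, [8]]` becomes:
read, write = split_vector_sizes(
    [1, 1, 8, 8, 8, 8],
    [False, False, False, True, False, True],
    source_rank=4, dest_rank=2)
```

Here `read` is `([1, 1, 8, 8], [False, False, False, True])` and `write` is `([8, 8], [False, True])`, matching the `read-sizes` / `write-sizes` annotation in the example.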

4 files changed (+186, −53 lines changed)


mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -228,7 +228,7 @@ bool isLinearizableVector(VectorType type);
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
                              bool useInBoundsInsteadOfMasking = false,
-                             ArrayRef<bool> scalableDims = {});
+                             ArrayRef<bool> inputScalableVecDims = {});

 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
```

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 85 additions & 32 deletions
```diff
@@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking);
+      useInBoundsInsteadOfMasking,
+      /*inputScalableVecSizes=*/{});

   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1840,6 +1841,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
 ///
 /// When collapsing scalable flags, conservatively avoids cases with two
 /// scalable dims. We could re-visit this in the future.
+///
+/// If the vector sizes are not provided:
+///  * the vector sizes are determined by the input operand and attributes,
+///  * update the inBounds attribute instead of masking.
 static VectorType getCollapsedVecType(VectorType type,
                                       ArrayRef<AffineMap> reassociation) {
   assert(type.getNumScalableDims() < 2 &&
@@ -1878,11 +1883,19 @@ static VectorType getCollapsedVecType(VectorType type,
 /// vector::TransferWriteOp. - Write the result vector back to the destination
 /// tensor.
 /// If the vector sizes are not provided:
-///  * the vector sizes are determined by the input operand and attributes,
-///  * update the inBounds attribute instead of masking.
+/// Vectorize `linalg.unpack %src into %dest` as:
+///   // Reads a vector from the source tensor
+///   %read = vector.transfer_read %src
+///   // Transpose %read as specified in `outer_dims_perm` attribute
+///   %tr = vector.transpose %read
+///   // Reshape the data based on the target
+///   %sc = vector.shape_cast %tr
+///   // Write the result vector to the destination tensor.
+///   vector.transfer_write %sc into %dest
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
+                          ArrayRef<bool> inputScalableVecDims,
                           SmallVectorImpl<Value> &newResults) {

   // TODO: Introduce a parent class that will handle the insertion point update.
@@ -1899,25 +1912,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,

   auto destSize = unpackOp.getDestRank();

-  if (!inputVectorSizes.empty())
-    assert(inputVectorSizes.size() == destSize &&
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
            "Incorrect number of input vector sizes");
+  }
+
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;
+  SmallVector<int64_t> readVectorSizes;
+  SmallVector<int64_t> writeVectorSizes;

-  // vectorSizes is the shape of the vector that will be used to do final
+  // Split input-vector-sizes into vector sizes for the read and write
+  // operations.
+  if (!inputVectorSizes.empty()) {
+    readVectorSizes.append(inputVectorSizes.begin(),
+                           inputVectorSizes.begin() + sourceShape.size());
+    writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+                            inputVectorSizes.end());
+  }
+  if (!inputScalableVecDims.empty()) {
+    readScalableVectorFlags.append(inputScalableVecDims.begin(),
+                                   inputScalableVecDims.begin() +
+                                       sourceShape.size());
+    writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+                                        sourceShape.size(),
+                                    inputScalableVecDims.end());
+  } else {
+    readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
+    writeScalableVectorFlags = SmallVector<bool>(destSize, false);
+  }
+
+  // writeVectorSizes is the shape of the vector that will be used to do final
   // write on the destination tensor. It is set like this: Let's say the
   // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
   // Thus:
-  // 1. vectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
+  // 1. writeVectorSizes = sourceShape.take_front(N)
+  // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
   // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
   //    innerTiles attribute value.
-  SmallVector<int64_t> vectorSizes(inputVectorSizes);
-  if (vectorSizes.empty()) {
-    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+  // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(sourceShape))
+      return failure();
+
+    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
     if (!outerDimsPerm.empty())
-      applyPermutationToVector(vectorSizes, outerDimsPerm);
+      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
     for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      vectorSizes[pos] *= innerTiles[i];
+      writeVectorSizes[pos] *= innerTiles[i];

     useInBoundsInsteadOfMasking = true;
   }
@@ -1941,17 +1983,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   //   After applying outer_dims_perm: [8, 16]
   //   After appending the rest of the sourceShape: [8, 16, 32, 16]

-  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
-
-  for (auto [index, size] : enumerate(innerTiles)) {
-    readVectorSizes[innerDimPos[index]] =
-        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-  }
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readVectorSizes, outerDimsPerm);
+  if (readVectorSizes.empty()) {
+    // Compute read-vector-sizes based on the write-vector-sizes and inner tile
+    // sizes. Note, this will only work when all sizes are static.
+    readVectorSizes = writeVectorSizes;
+    for (auto [index, size] : enumerate(innerTiles)) {
+      readVectorSizes[innerDimPos[index]] =
+          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
+    }
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector(readVectorSizes, outerDimsPerm);
+    }
+    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
+                           sourceShape.end());
   }
-  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
-                         sourceShape.end());

   Location loc = unpackOp->getLoc();
@@ -1963,7 +2008,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);

   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
@@ -2009,7 +2054,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   assert(succeeded(status) && "failed to reify result shapes");
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});

   // Create Xfer write Op
   Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2093,6 +2138,9 @@ static LogicalResult
 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {

+  // FIXME!!!
+  return success();
+
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
       })) {
@@ -2429,6 +2477,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
     LDBG() << "pad value is not constant: " << packOp;
     return failure();
   }
+
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
   bool satisfyEmptyCond = true;
   if (inputVectorSizes.empty()) {
@@ -2507,12 +2556,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (numOfScalableDims == 0)
     return success();

+  // TODO: Check the following!
   auto linalgOp = dyn_cast<LinalgOp>(op);

-  // Cond 1: There's been no need for scalable vectorisation of
-  // non-linalg Ops so far
-  if (!linalgOp)
-    return failure();
+  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
+  // exception of UnpackOp for which there is a dedicated hook.
+  if (!linalgOp) {
+    return isa<linalg::UnPackOp>(op) ? success() : failure();
+  }

   // Cond 2: There's been no need for more than 2 scalable dims so far
   if (numOfScalableDims > 2)
@@ -2610,7 +2661,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
           isa<linalg::MatmulTransposeAOp>(op) ||
           isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
           isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-          hasReductionIterator(linalgOp));
+          isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
 }

 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2743,7 +2794,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
           })
           .Case<linalg::UnPackOp>([&](auto unpackOp) {
             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
-                                             inputVectorSizes, results);
+                                             inputVectorSizes,
+                                             inputScalableVecDims, results);
           })
           .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
             return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
@@ -3135,7 +3187,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
       vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
       rewriter, loc, source, vecType.getShape(), padValue,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
+      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
+      /*inputScalableVecSizes=*/{});

   // Create write
   auto writeIndices =
```
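The static-shape fallback retained above (the `readVectorSizes.empty()` branch) derives the read-side sizes from the write-side ones. A Python sketch of that inference (hypothetical helper name `infer_read_sizes`; static sizes only, matching the `[8, 16, 32, 16]` worked example in the diff's comments):

```python
import math

def infer_read_sizes(write_sizes, inner_dim_pos, inner_tiles,
                     outer_dims_perm, source_shape):
    """Derive the read-vector-sizes from the write-vector-sizes, mirroring
    the static-shape fallback in vectorizeAsTensorUnpackOp."""
    read_sizes = list(write_sizes)
    # Each tiled dim is shrunk by its inner tile size (ceil division).
    for pos, tile in zip(inner_dim_pos, inner_tiles):
        read_sizes[pos] = math.ceil(read_sizes[pos] / tile)
    # outer_dims_perm permutes the outer dims of the packed source.
    if outer_dims_perm:
        read_sizes = [read_sizes[p] for p in outer_dims_perm]
    # The trailing source dims (the inner tiles themselves) are read whole.
    read_sizes += list(source_shape[len(write_sizes):])
    return read_sizes
```

For example, write sizes `[512, 128]` with `inner_tiles = [32, 16]`, `outer_dims_perm = [1, 0]`, and a rank-4 source ending in `32x16` yield `[8, 16, 32, 16]`. None of this works when a shape or tile size is dynamic, which is exactly why the patch lets the user supply the read sizes explicitly.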

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 12 additions & 10 deletions
```diff
@@ -279,14 +279,16 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
   // Attempt to unroll until targetRank or the first scalable dimension (which
   // cannot be unrolled).
   auto shapeToUnroll = vType.getShape().drop_back(targetRank);
-  auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
-  auto it = llvm::find(scalableDimsToUnroll, true);
-  auto firstScalableDim = it - scalableDimsToUnroll.begin();
+  auto inputScalableVecDimsToUnroll =
+      vType.getScalableDims().drop_back(targetRank);
+  auto it = llvm::find(inputScalableVecDimsToUnroll, true);
+  auto firstScalableDim = it - inputScalableVecDimsToUnroll.begin();
   if (firstScalableDim == 0)
     return {};
   // All scalable dimensions should be removed now.
-  scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
-  assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
+  inputScalableVecDimsToUnroll =
+      inputScalableVecDimsToUnroll.slice(0, firstScalableDim);
+  assert(!llvm::is_contained(inputScalableVecDimsToUnroll, true) &&
          "unexpected leading scalable dimension");
   // Create an unroll iterator for leading dimensions.
   shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
@@ -319,15 +321,15 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
                                      bool useInBoundsInsteadOfMasking,
-                                     ArrayRef<bool> scalableDims) {
+                                     ArrayRef<bool> inputScalableVecDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType =
-      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
+                                    inputScalableVecDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -356,8 +358,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                ? memref::getMixedSizes(builder, loc, source)
                : tensor::getMixedSizes(builder, loc, source);

-  auto maskType =
-      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
+  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
+                                  inputScalableVecDims);
   Value mask =
       vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
```
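Roughly, `createReadOrMaskedRead` emits a plain in-bounds `transfer_read` only when it can prove the vector fits the source; otherwise it wraps the read in a `vector.create_mask`. A hedged Python sketch of that decision (hypothetical helper name `needs_mask`; the exact condition in the C++ also accounts for the in-bounds override shown above):

```python
DYNAMIC = -1  # stand-in for ShapedType::kDynamic

def needs_mask(source_shape, vector_sizes, use_in_bounds_instead_of_masking):
    """A mask is required when the read vector may extend past the source:
    some source dim is dynamic, or differs from the vector size."""
    if use_in_bounds_instead_of_masking:
        # Caller promises the access is in bounds; skip the mask.
        return False
    return any(dim == DYNAMIC or dim != size
               for dim, size in zip(source_shape, vector_sizes))
```

For a static `8x16` source read with an `8x16` vector no mask is needed, while a dynamic trailing dim (as in the scalable `linalg.unpack` example) forces one.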
