Commit 6ba0d17

[mlir][linalg] Enable scalable vectorization of linalg.unpack (WIP)
This patch updates `vectorizeAsTensorUnpackOp` to support scalable vectorization by requiring user-specified vector sizes for both the _read_ and _write_ operations involved in `linalg.unpack`. Detailed rationale and an example are provided below.

Conceptually, `linalg.unpack` consists of the following high-level steps:

1. _Read_ from the source tensor.
2. _Transpose_ the value read in step (1).
3. _Write_ the value from step (2) into the destination tensor.

Currently, when vectorizing with user-provided vector sizes, only the sizes for the _write_ operation (step 3) are required. Sizes for the _read_ operation (step 1) are inferred from static shapes and inner tile sizes. This logic breaks down when the input shapes or tile sizes are dynamic; indeed, `vectorizeUnPackOpPrecondition` currently rejects such cases and vectorization fails. This patch addresses the issue by requiring explicit vector sizes for both the read and write sides, enabling scalable vectorization in such cases.

Example:

```mlir
func.func @unpack(%in: tensor<1x1x8x?xf32>, %out: tensor<8x?xf32>) -> tensor<8x?xf32> {
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %tile_size = arith.muli %vs, %c8 : index

  %unpack = linalg.unpack %in
    inner_dims_pos = [0, 1]
    inner_tiles = [8, %tile_size]
    into %out : tensor<1x1x8x?xf32> -> tensor<8x?xf32>
  return %unpack : tensor<8x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [1, 1, 8, [8], 8, [8]] : !transform.any_op
    //                                              \          /  \     /
    //                                               read-sizes    write-sizes
    transform.yield
  }
}
```

Finally, this patch also extends `createReadOrMaskedRead` and `createWriteOrMaskedWrite` to take scalable flags.
1 parent ac4c13d commit 6ba0d17

File tree

4 files changed: +190 −62 lines


mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -228,7 +228,7 @@ bool isLinearizableVector(VectorType type);
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
                              bool useInBoundsInsteadOfMasking = false,
-                             ArrayRef<bool> scalableDims = {});
+                             ArrayRef<bool> inputScalableVecDims = {});
 
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
```

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 89 additions & 41 deletions
```diff
@@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking);
+      useInBoundsInsteadOfMasking,
+      /*inputScalableVecSizes=*/{});
 
   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1831,18 +1832,23 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   return success();
 }
 
-/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
-///   Vector::TransferReadOp - Reads a vector from the source tensor
-///   vector::TransposeOp - Transpose the Source tensor
-///   ShapeCastOp - Reshape the data based on the target.
-///   vector::TransferWriteOp. - Write the result vector back to the destination
-///   tensor.
-/// If the vector sizes are not provided:
+/// Vectorize `linalg.unpack %src into %dest` as:
+///   // Reads a vector from the source tensor
+///   %read = vector.transfer_read %src
+///   // Transpose %read as specified in `outer_dims_perm` attribute
+///   %tr = vector.transpose %read
+///   // Reshape the data based on the target
+///   %sc = vector.shape_cast %tr
+///   // Write the result vector to the destination tensor.
+///   vector.transfer_write %sc into %dest
+///
+/// If the vector sizes are not provided:
 /// * the vector sizes are determined by the input operand and attributes,
 /// * update the inBounds attribute instead of masking.
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
+                          ArrayRef<bool> inputScalableVecDims,
                           SmallVectorImpl<Value> &newResults) {
 
   // TODO: Introduce a parent class that will handle the insertion point update.
```
```diff
@@ -1859,25 +1865,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
 
   auto destSize = unpackOp.getDestRank();
 
-  if (!inputVectorSizes.empty())
-    assert(inputVectorSizes.size() == destSize &&
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
           "Incorrect number of input vector sizes");
+  }
+
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;
+  SmallVector<int64_t> readVectorSizes;
+  SmallVector<int64_t> writeVectorSizes;
 
-  // vectorSizes is the shape of the vector that will be used to do final
+  // Split input-vector-sizes into vector sizes for the read and write
+  // operations.
+  if (!inputVectorSizes.empty()) {
+    readVectorSizes.append(inputVectorSizes.begin(),
+                           inputVectorSizes.begin() + sourceShape.size());
+    writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+                            inputVectorSizes.end());
+  }
+  if (!inputScalableVecDims.empty()) {
+    readScalableVectorFlags.append(inputScalableVecDims.begin(),
+                                   inputScalableVecDims.begin() +
+                                       sourceShape.size());
+    writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+                                        sourceShape.size(),
+                                    inputScalableVecDims.end());
+  } else {
+    readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
+    writeScalableVectorFlags = SmallVector<bool>(destSize, false);
+  }
+
+  // writeVectorSizes is the shape of the vector that will be used to do final
   // write on the destination tensor. It is set like this: Let's say the
   // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
   // Thus:
-  // 1. vectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
+  // 1. writeVectorSizes = sourceShape.take_front(N)
+  // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
   // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
   //    innerTiles attribute value.
-  SmallVector<int64_t> vectorSizes(inputVectorSizes);
-  if (vectorSizes.empty()) {
-    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+  // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(sourceShape))
+      return failure();
+
+    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
     if (!outerDimsPerm.empty())
-      applyPermutationToVector(vectorSizes, outerDimsPerm);
+      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
     for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      vectorSizes[pos] *= innerTiles[i];
+      writeVectorSizes[pos] *= innerTiles[i];
 
     useInBoundsInsteadOfMasking = true;
   }
```
```diff
@@ -1901,17 +1936,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   //   After applying outer_dims_perm: [8, 16]
   //   After appending the rest of the sourceShape: [8, 16, 32, 16]
 
-  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
-
-  for (auto [index, size] : enumerate(innerTiles)) {
-    readVectorSizes[innerDimPos[index]] =
-        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-  }
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readVectorSizes, outerDimsPerm);
+  if (readVectorSizes.empty()) {
+    // Compute read-vector-sizes based on the write-vector-sizes and inner tile
+    // sizes. Note, this will only work when all sizes are static.
+    readVectorSizes = writeVectorSizes;
+    for (auto [index, size] : enumerate(innerTiles)) {
+      readVectorSizes[innerDimPos[index]] =
+          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
+    }
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector(readVectorSizes, outerDimsPerm);
+    }
+    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
+                           sourceShape.end());
   }
-  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
-                         sourceShape.end());
 
   Location loc = unpackOp->getLoc();
 
```
```diff
@@ -1923,7 +1961,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);
 
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
@@ -1942,15 +1980,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMineTensorType, packMetadata.reassociations);
   mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType(),
+                      writeScalableVectorFlags);
   vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
       rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
 
-  // writeVectorSizes had to match the shapecast shape for dynamic sizes,
+  // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
   // otherwise the validator complains that the mask size is invalid.
-  SmallVector<int64_t> writeVectorSizes(
+  // FIXME: We should not override write-vector-sizes like this.
+  SmallVector<int64_t> writeVectorSizesFinal(
       unpackOp.getDestType().hasStaticShape()
-          ? vectorSizes
+          ? writeVectorSizes
           : shapeCastOp.getResultVectorType().getShape());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
@@ -1981,7 +2021,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   assert(succeeded(status) && "failed to reify result shapes");
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
 
   // Create Xfer write Op
   Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2065,6 +2105,9 @@ static LogicalResult
 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {
 
+  // FIXME!!!
+  return success();
+
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
       })) {
@@ -2401,6 +2444,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
     LDBG() << "pad value is not constant: " << packOp;
     return failure();
   }
+
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
   bool satisfyEmptyCond = true;
   if (inputVectorSizes.empty()) {
```
```diff
@@ -2479,12 +2523,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (numOfScalableDims == 0)
     return success();
 
+  // TODO: Check the following!
   auto linalgOp = dyn_cast<LinalgOp>(op);
 
-  // Cond 1: There's been no need for scalable vectorisation of
-  // non-linalg Ops so far
-  if (!linalgOp)
-    return failure();
+  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
+  // exception of UnpackOp for which there is a dedicated hook.
+  if (!linalgOp) {
+    return isa<linalg::UnPackOp>(op) ? success() : failure();
+  }
 
   // Cond 2: There's been no need for more than 2 scalable dims so far
   if (numOfScalableDims > 2)
@@ -2582,7 +2628,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
           isa<linalg::MatmulTransposeAOp>(op) ||
           isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
           isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-          hasReductionIterator(linalgOp));
+          isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2715,7 +2761,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
           })
           .Case<linalg::UnPackOp>([&](auto unpackOp) {
             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
-                                             inputVectorSizes, results);
+                                             inputVectorSizes,
+                                             inputScalableVecDims, results);
           })
           .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
             return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
@@ -3107,7 +3154,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
       vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
       rewriter, loc, source, vecType.getShape(), padValue,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
+      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
+      /*inputScalableVecSizes=*/{});
 
   // Create write
   auto writeIndices =
```

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 12 additions & 10 deletions
```diff
@@ -279,14 +279,16 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
   // Attempt to unroll until targetRank or the first scalable dimension (which
   // cannot be unrolled).
   auto shapeToUnroll = vType.getShape().drop_back(targetRank);
-  auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
-  auto it = llvm::find(scalableDimsToUnroll, true);
-  auto firstScalableDim = it - scalableDimsToUnroll.begin();
+  auto inputScalableVecDimsToUnroll =
+      vType.getScalableDims().drop_back(targetRank);
+  auto it = llvm::find(inputScalableVecDimsToUnroll, true);
+  auto firstScalableDim = it - inputScalableVecDimsToUnroll.begin();
   if (firstScalableDim == 0)
     return {};
   // All scalable dimensions should be removed now.
-  scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
-  assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
+  inputScalableVecDimsToUnroll =
+      inputScalableVecDimsToUnroll.slice(0, firstScalableDim);
+  assert(!llvm::is_contained(inputScalableVecDimsToUnroll, true) &&
          "unexpected leading scalable dimension");
   // Create an unroll iterator for leading dimensions.
   shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
@@ -319,15 +321,15 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
                                      bool useInBoundsInsteadOfMasking,
-                                     ArrayRef<bool> scalableDims) {
+                                     ArrayRef<bool> inputScalableVecDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType =
-      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
+                                    inputScalableVecDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -356,8 +358,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
           ? memref::getMixedSizes(builder, loc, source)
           : tensor::getMixedSizes(builder, loc, source);
 
-  auto maskType =
-      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
+  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
+                                  inputScalableVecDims);
   Value mask =
       vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
```
