
Commit 1992e9a

[mlir][linalg] Enable scalable vectorization of linalg.unpack (WIP)
This patch updates `vectorizeAsTensorUnpackOp` to support scalable vectorization by requiring user-specified vector sizes for both the _read_ and _write_ operations involved in `linalg.unpack`. Detailed rationale and an example are provided below.

Conceptually, `linalg.unpack` consists of the following high-level steps:

1. _Read_ from the source tensor.
2. Transpose the value read in step (1).
3. _Write_ the value from step (2) into the destination tensor.

Currently, when vectorizing with user-provided vector sizes, only the sizes for the _write_ operation (step 3) are required. Sizes for the _read_ operation (step 1) are inferred from static shapes and inner tile sizes. This logic breaks down when the input shapes or tile sizes are dynamic; indeed, `vectorizeUnPackOpPrecondition` currently rejects such cases and vectorization fails. This patch addresses the issue by requiring explicit vector sizes for both the read and write sides, enabling scalable vectorization in such cases.

Example:

```mlir
func.func @unpack(%in: tensor<1x1x8x?xf32>, %out: tensor<8x?xf32>) -> tensor<8x?xf32> {
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %tile_size = arith.muli %vs, %c8 : index

  %unpack = linalg.unpack %in
    inner_dims_pos = [0, 1]
    inner_tiles = [8, %tile_size]
    into %out : tensor<1x1x8x?xf32> -> tensor<8x?xf32>
  return %unpack : tensor<8x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [1, 1, 8, [8], 8, [8]]
      : !transform.any_op
    //                                             \          /  \     /
    //                                              read-sizes    write-sizes
    transform.yield
  }
}
```

Finally, this patch also extends `createReadOrMaskedRead` and `createWriteOrMaskedWrite` to take scalable flags.
1 parent 77363fb commit 1992e9a

File tree

4 files changed (+182, −53 lines)

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -228,7 +228,7 @@ bool isLinearizableVector(VectorType type);
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
                              bool useInBoundsInsteadOfMasking = false,
-                             ArrayRef<bool> scalableDims = {});
+                             ArrayRef<bool> inputScalableVecDims = {});

 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
```

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 81 additions & 32 deletions
```diff
@@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking);
+      useInBoundsInsteadOfMasking,
+      /*inputScalableVecSizes=*/{});

   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1885,11 +1886,19 @@ static VectorType getCollapsedVecType(VectorType type,
 /// vector::TransferWriteOp. - Write the result vector back to the destination
 /// tensor.
 /// If the vector sizes are not provided:
-///   * the vector sizes are determined by the input operand and attributes,
-///   * update the inBounds attribute instead of masking.
+/// Vectorize `linalg.unpack %src into %dest` as:
+///   // Reads a vector from the source tensor
+///   %read = vector.transfer_read %src
+///   // Transpose %read as specified in `outer_dims_perm` attribute
+///   %tr = vector.transpose %read
+///   // Reshape the data based on the target
+///   %sc = vector.shape_cast %tr
+///   // Write the result vector to the destination tensor.
+///   vector.transfer_write %sc into %dest
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
+                          ArrayRef<bool> inputScalableVecDims,
                           SmallVectorImpl<Value> &newResults) {

   // TODO: Introduce a parent class that will handle the insertion point update.
```
```diff
@@ -1906,25 +1915,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,

   auto destSize = unpackOp.getDestRank();

-  if (!inputVectorSizes.empty())
-    assert(inputVectorSizes.size() == destSize &&
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
            "Incorrect number of input vector sizes");
+  }
+
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;
+  SmallVector<int64_t> readVectorSizes;
+  SmallVector<int64_t> writeVectorSizes;
+
+  // Split input-vector-sizes into vector sizes for the read and write
+  // operations.
+  if (!inputVectorSizes.empty()) {
+    readVectorSizes.append(inputVectorSizes.begin(),
+                           inputVectorSizes.begin() + sourceShape.size());
+    writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+                            inputVectorSizes.end());
+  }
+  if (!inputScalableVecDims.empty()) {
+    readScalableVectorFlags.append(inputScalableVecDims.begin(),
+                                   inputScalableVecDims.begin() +
+                                       sourceShape.size());
+    writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+                                        sourceShape.size(),
+                                    inputScalableVecDims.end());
+  } else {
+    readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
+    writeScalableVectorFlags = SmallVector<bool>(destSize, false);
+  }

-  // vectorSizes is the shape of the vector that will be used to do final
+  // writeVectorSizes is the shape of the vector that will be used to do final
   // write on the destination tensor. It is set like this: Let's say the
   // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
   // Thus:
-  // 1. vectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
+  // 1. writeVectorSizes = sourceShape.take_front(N)
+  // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
   // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
   //    innerTiles attribute value.
-  SmallVector<int64_t> vectorSizes(inputVectorSizes);
-  if (vectorSizes.empty()) {
-    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+  // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(sourceShape))
+      return failure();
+
+    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
     if (!outerDimsPerm.empty())
-      applyPermutationToVector(vectorSizes, outerDimsPerm);
+      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
     for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      vectorSizes[pos] *= innerTiles[i];
+      writeVectorSizes[pos] *= innerTiles[i];

     useInBoundsInsteadOfMasking = true;
   }
```
```diff
@@ -1948,17 +1986,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   //   After applying outer_dims_perm: [8, 16]
   //   After appending the rest of the sourceShape: [8, 16, 32, 16]

-  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
-
-  for (auto [index, size] : enumerate(innerTiles)) {
-    readVectorSizes[innerDimPos[index]] =
-        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-  }
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readVectorSizes, outerDimsPerm);
+  if (readVectorSizes.empty()) {
+    // Compute read-vector-sizes based on the write-vector-sizes and inner tile
+    // sizes. Note, this will only work when all sizes are static.
+    readVectorSizes = writeVectorSizes;
+    for (auto [index, size] : enumerate(innerTiles)) {
+      readVectorSizes[innerDimPos[index]] =
+          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
+    }
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector(readVectorSizes, outerDimsPerm);
+    }
+    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
+                           sourceShape.end());
   }
-  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
-                         sourceShape.end());

   Location loc = unpackOp->getLoc();

```
```diff
@@ -1970,7 +2011,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   //    to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);

   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
@@ -2016,7 +2057,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   assert(succeeded(status) && "failed to reify result shapes");
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});

   // Create Xfer write Op
   Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2100,6 +2141,9 @@ static LogicalResult
 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {

+  // FIXME!!!
+  return success();
+
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
       })) {
```
```diff
@@ -2436,6 +2480,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
     LDBG() << "pad value is not constant: " << packOp;
     return failure();
   }
+
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
   bool satisfyEmptyCond = true;
   if (inputVectorSizes.empty()) {
@@ -2514,12 +2559,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (numOfScalableDims == 0)
     return success();

+  // TODO: Check the following!
   auto linalgOp = dyn_cast<LinalgOp>(op);

-  // Cond 1: There's been no need for scalable vectorisation of
-  // non-linalg Ops so far
-  if (!linalgOp)
-    return failure();
+  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
+  // exception of UnpackOp for which there is a dedicated hook.
+  if (!linalgOp) {
+    return isa<linalg::UnPackOp>(op) ? success() : failure();
+  }

   // Cond 2: There's been no need for more than 2 scalable dims so far
   if (numOfScalableDims > 2)
@@ -2617,7 +2664,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
           isa<linalg::MatmulTransposeAOp>(op) ||
           isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
           isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-          hasReductionIterator(linalgOp));
+          isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
 }

 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2750,7 +2797,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
         })
         .Case<linalg::UnPackOp>([&](auto unpackOp) {
           return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
-                                           inputVectorSizes, results);
+                                           inputVectorSizes,
+                                           inputScalableVecDims, results);
         })
         .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
           return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
@@ -3142,7 +3190,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
       vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
       rewriter, loc, source, vecType.getShape(), padValue,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
+      /*inputScalableVecSizes=*/{});

   // Create write
   auto writeIndices =
```

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 12 additions & 10 deletions
```diff
@@ -279,14 +279,16 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
   // Attempt to unroll until targetRank or the first scalable dimension (which
   // cannot be unrolled).
   auto shapeToUnroll = vType.getShape().drop_back(targetRank);
-  auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
-  auto it = llvm::find(scalableDimsToUnroll, true);
-  auto firstScalableDim = it - scalableDimsToUnroll.begin();
+  auto inputScalableVecDimsToUnroll =
+      vType.getScalableDims().drop_back(targetRank);
+  auto it = llvm::find(inputScalableVecDimsToUnroll, true);
+  auto firstScalableDim = it - inputScalableVecDimsToUnroll.begin();
   if (firstScalableDim == 0)
     return {};
   // All scalable dimensions should be removed now.
-  scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
-  assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
+  inputScalableVecDimsToUnroll =
+      inputScalableVecDimsToUnroll.slice(0, firstScalableDim);
+  assert(!llvm::is_contained(inputScalableVecDimsToUnroll, true) &&
          "unexpected leading scalable dimension");
   // Create an unroll iterator for leading dimensions.
   shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
@@ -319,15 +321,15 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
                                      bool useInBoundsInsteadOfMasking,
-                                     ArrayRef<bool> scalableDims) {
+                                     ArrayRef<bool> inputScalableVecDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType =
-      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
+                                    inputScalableVecDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -356,8 +358,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
              ? memref::getMixedSizes(builder, loc, source)
              : tensor::getMixedSizes(builder, loc, source);

-  auto maskType =
-      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
+  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
+                                  inputScalableVecDims);
   Value mask =
       vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
```
