
Commit b073854

Simplify code as per comments from HanHan
1 parent b8dddce · commit b073854


2 files changed: +60 −93 lines


mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 59 additions & 91 deletions
@@ -1841,10 +1841,6 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
 ///
 /// When collapsing scalable flags, conservatively avoids cases with two
 /// scalable dims. We could re-visit this in the future.
-///
-/// If the vector sizes are not provided:
-///   * the vector sizes are determined by the input operand and attributes,
-///   * update the inBounds attribute instead of masking.
 static VectorType getCollapsedVecType(VectorType type,
                                       ArrayRef<AffineMap> reassociation) {
   assert(type.getNumScalableDims() < 2 &&
@@ -1876,22 +1872,35 @@ static VectorType getCollapsedVecType(VectorType type,
   return VectorType::get(newShape, type.getElementType(), newScalableFlags);
 }
 
-/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
-///   Vector::TransferReadOp - Reads a vector from the source tensor
-///   vector::TransposeOp - Transpose the Source tensor
-///   ShapeCastOp - Reshape the data based on the target.
-///   vector::TransferWriteOp. - Write the result vector back to the destination
-///   tensor.
-/// If the vector sizes are not provided:
-/// Vectorize `linalg.unpack %src into %dest` as:
-///   // Reads a vector from the source tensor
-///   %read = vector.transfer_read %src
-///   // Transpose %read as specified in `outer_dims_perm` attribute
-///   %tr = vector.transpose %read
-///   // Reshape the data based on the target
-///   %sc = vector.shape_cast %tr
-///   // Write the result vector to the destination tensor.
-///   vector.transfer_write %sc into %dest
+/// Vectorize `linalg.unpack` into:
+///   * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
+///
+/// The input-vector-sizes specify both the read and the write vector
+/// sizes and are passed as one array covering both operations, i.e.:
+///
+///   input-vector-sizes = [1, 1, 8, [8], 8, [8]]
+///                         \           /  \    /
+///                          read-sizes    write-sizes
+///
+/// (for brevity, in the diagram,
+///   * input-vector-sizes = `inputVectorSizes` + `inputScalableDims`
+/// )
+///
+/// If the vector sizes are not provided:
+///   * the vector sizes are determined by the operands,
+///   * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+///   %unpack = linalg.unpack %src
+///     inner_dims_pos = [0, 1]
+///     inner_tiles = [8, 8]
+///     into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
+/// ```
+/// is vectorized as:
+/// ```
+///   vector.transfer_write %sc into %dest : vector<8x8xf32>, tensor<8x8xf32>
+/// ```
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
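For reference, the `%sc` in the example above is the result of the shape_cast step. A minimal sketch of the full lowering for that 1x1x8x8 -> 8x8 example, with illustrative value names and indices (the permutation `[0, 2, 1, 3]` is an assumption derived from `inner_dims_pos = [0, 1]`, not taken from the commit):

```
// Illustrative sketch, not the committed example.
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
// Read the full source tensor as a vector.
%read = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad
    : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
// Interleave the outer and inner tile dims.
%tr = vector.transpose %read, [0, 2, 1, 3]
    : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
// Collapse (1x8, 1x8) to the destination shape.
%sc = vector.shape_cast %tr : vector<1x8x1x8xf32> to vector<8x8xf32>
// Write the result into the destination tensor.
vector.transfer_write %sc, %dest[%c0, %c0]
    : vector<8x8xf32>, tensor<8x8xf32>
```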
@@ -1911,22 +1920,19 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
 
   RankedTensorType unpackTensorType = unpackOp.getSourceType();
 
-  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
-  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
   bool useInBoundsInsteadOfMasking = false;
-  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
 
-  auto destSize = unpackOp.getDestRank();
+  Location loc = unpackOp->getLoc();
 
-  // 1. Obtain vector sizes for the read and write operation.s
+  // 1. Obtain vector sizes for the read and write operations.
   SmallVector<int64_t> readVectorSizes;
   SmallVector<int64_t> writeVectorSizes;
   SmallVector<bool> readScalableVectorFlags;
   SmallVector<bool> writeScalableVectorFlags;
 
-  // CASE 1: Vector sizes are user-specified.
-  // 1.0 This is the trivial case, simply split the input vector sizes.
+  // CASE 1.1: Vector sizes are user-specified.
   if (!inputVectorSizes.empty()) {
     readVectorSizes.append(inputVectorSizes.begin(),
                            inputVectorSizes.begin() + sourceShape.size());
@@ -1940,83 +1946,41 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                            inputScalableVecDims.end());
   }
 
-  // CASE 2: Vector sizes have to be inferred.
-  //
-  // 1.1 Infer vector sizes for the write operation.
-  //
-  // Let:
-  //    * rank(source tensor) = 'M'
-  //    * rank(dest tensor) = 'N',
-  // and N <= M. The steps are:
-  //  1. writeVectorSizes = sourceShape.take_front(N)
-  //  2. Multiply all the locations in writeVectorSize pointed by inner_dims_pos
-  //     by the corresponding values from the `inner_tiles` attribute value.
-  //  3. If outer_dims_perms is present, permutate writeVectorSizes accordingly.
-  //
-  // Note, this will only work when all sizes are static!
+  // CASE 1.2: Vector sizes have to be inferred.
   if (writeVectorSizes.empty()) {
-    if (ShapedType::isDynamicShape(sourceShape))
+    if (ShapedType::isDynamicShape(destShape) ||
+        ShapedType::isDynamicShape(sourceShape))
       return failure();
 
-    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
-    if (!outerDimsPerm.empty())
-      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
-    for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      writeVectorSizes[pos] *= innerTiles[i];
-
+    readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
+    writeVectorSizes.assign(destShape.begin(), destShape.end());
     useInBoundsInsteadOfMasking = true;
   }
 
-  // 1.2 Infer vector sizes for the read operation.
-  //
-  // The steps are:
-  //  1. readVectorSizes = writeVectorSizes
-  //  2. Take readVectorSizes from 1. and divide all locations pointed by
-  //     the inner_dims_pos attribyte by the `inner_tiles` attribute value.
-  //  3. If outer_dims_perms is present, permutate readVectorSizes accordingly.
-  //  4. Append the remaining sizes from the source tensor.
-  //
-  // Note, this will only work when all sizes are static!
-  if (readVectorSizes.empty()) {
-    readVectorSizes = writeVectorSizes;
-    for (auto [index, size] : enumerate(innerTiles)) {
-      readVectorSizes[innerDimPos[index]] =
-          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-    }
-    if (!outerDimsPerm.empty()) {
-      applyPermutationToVector(readVectorSizes, outerDimsPerm);
-    }
-    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
-                           sourceShape.end());
-  }
-
-  Location loc = unpackOp->getLoc();
-
+  // 2. Generate the read operation.
   auto padValue = arith::ConstantOp::create(
       rewriter, loc,
       rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
-
-  // Read result, mask if necessary. If transferReadOp shape is not equal
-  // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);
 
+  // 3. Generate the transpose operation.
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
       getUnPackInverseSrcPerm(unpackOp, packMetadata);
-  // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = vector::TransposeOp::create(
       rewriter, loc, readResult, lastDimToInsertPosPerm);
 
-  // Collapse the vector to the size required by result.
+  // 4. Generate the shape_cast operation.
   VectorType collapsedVecType = getCollapsedVecType(
       transposeOp.getType(),
       getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
           rewriter.getContext(), packMetadata.reassociations)));
   vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
       rewriter, loc, collapsedVecType, transposeOp->getResult(0));
 
+  // 5. Generate the write operation.
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
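The inference in CASE 1.2 is now trivial compared with the deleted logic: when all shapes are static, the read sizes are simply the source shape and the write sizes the dest shape. A hedged illustration (the op below is assumed for illustration, not quoted from the commit):

```
// For this static unpack, the inferred sizes would be:
//   readVectorSizes  = [1, 1, 8, 8]  (source shape)
//   writeVectorSizes = [8, 8]        (dest shape)
// and in_bounds attributes are used instead of masks.
%0 = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 8]
    into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
```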
@@ -2144,24 +2108,24 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
   if (!inputVectorSizes.empty()) {
     if (inputVectorSizes.size() !=
         unpackOp.getDestRank() + unpackOp.getSourceRank()) {
-      LDBG("Incorrect number of input vector sizes");
+      LDBG() << "Incorrect number of input vector sizes";
       return failure();
     }
   }
 
-  // Check the vector sizes for the write operation.
+  // Check the vector sizes for the read operation.
   if (failed(vector::isValidMaskedInputVector(
-          unpackOp.getDestType().getShape(),
-          inputVectorSizes.take_back(unpackOp.getDestRank())))) {
-    LDBG("Incorrect number of input vector sizes");
+          unpackOp.getSourceType().getShape(),
+          inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
+    LDBG() << "Invalid vector sizes for the read operation";
     return failure();
   }
 
-  // Check the vector sizes for the read operation.
+  // Check the vector sizes for the write operation.
   if (failed(vector::isValidMaskedInputVector(
-          unpackOp.getSourceType().getShape(),
-          inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
-    LDBG("Incorrect number of input vector sizes");
+          unpackOp.getDestType().getShape(),
+          inputVectorSizes.take_back(unpackOp.getDestRank())))) {
+    LDBG() << "Invalid vector sizes for the write operation";
     return failure();
   }

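With the checks reordered, the leading `sourceRank` entries are validated against the source shape and the trailing `destRank` entries against the dest shape. A sketch of supplying such a combined size array through the transform dialect, assuming the 1x1x8x8 -> 8x8 example (`%unpack` is an assumed handle to the op):

```
// 4 read sizes (source rank) followed by 2 write sizes (dest rank).
transform.structured.vectorize %unpack vector_sizes [1, 1, 8, 8, 8, 8]
    : !transform.any_op
```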
@@ -2551,8 +2515,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
   return success();
 }
 
-/// Preconditions for scalable vectors. This is quite restrictive - it models
-/// the fact that in practice we would only make selected dimensions scalable.
+/// Preconditions for scalable vectors.
+///
+/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
+/// models the fact that in practice we would only make selected dimensions
+/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
+/// unconditionally - we are yet to identify meaningful conditions.
 static LogicalResult
 vectorizeScalableVectorPrecondition(Operation *op,
                                     ArrayRef<int64_t> inputVectorSizes,
@@ -2571,7 +2539,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
   // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
   // exception of UnpackOp for which there is a dedicated hook.
   if (!linalgOp) {
-    return isa<linalg::UnPackOp>(op) ? success() : failure();
+    return success(isa<linalg::UnPackOp>(op));
   }
 
   // Cond 2: There's been no need for more than 2 scalable dims so far
@@ -2670,7 +2638,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
          isa<linalg::MatmulTransposeAOp>(op) ||
          isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
          isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-         isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
+         hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
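Because the scalable-vector precondition now succeeds unconditionally for `linalg.unpack`, scalable entries may appear in the combined size array. A sketch mirroring the `[1, 1, 8, [8], 8, [8]]` diagram from the doc comment, where bracketed entries mark scalable dims (again assuming a `%unpack` handle):

```
// One scalable dim in the read sizes and one in the write sizes.
transform.structured.vectorize %unpack vector_sizes [1, 1, 8, [8], 8, [8]]
    : !transform.any_op
```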

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 1 addition & 2 deletions
@@ -387,8 +387,7 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
                staticSize <= inputSize;
          })) {
     LDBG() << "Input vector sizes must be greater than or equal to iteration "
-              "space "
-              "static sizes";
+              "space static sizes";
     return failure();
   }
   return success();
