Skip to content

Commit 656f7ef

Browse files
committed
Final tweaks
1 parent 6548876 commit 656f7ef

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,7 +1887,7 @@ static VectorType getCollapsedVecType(VectorType type,
18871887
/// sizes required here.
18881888
///
18891889
/// If the vector sizes are not provided:
1890-
/// * the vector sizes are determined by the operands,
1890+
/// * the vector sizes are determined from the input tensor static shape.
18911891
/// * the inBounds attribute is used instead of masking.
18921892
///
18931893
/// EXAMPLE (no vector sizes):
@@ -1899,7 +1899,14 @@ static VectorType getCollapsedVecType(VectorType type,
18991899
/// ```
19001900
/// is vectorized as:
19011901
/// ```
1902-
/// vector.transfer_write %sc into %dest : vector<8x8xf32>, tensor<8x8xf32>
1902+
/// %read = vector.transfer_read %src
1903+
/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1904+
/// %tr = vector.transpose %read, [0, 2, 1, 3]
1905+
/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1906+
/// %sc = vector.shape_cast %tr
1907+
/// : vector<1x8x1x8xf32> to vector<8x8xf32>
1908+
/// %vector = vector.transfer_write %sc into %dest
1909+
/// : vector<8x8xf32>, tensor<8x8xf32>
19031910
/// ```
19041911
static LogicalResult
19051912
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
@@ -1920,60 +1927,51 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19201927
RankedTensorType unpackTensorType = unpackOp.getSourceType();
19211928

19221929
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1923-
ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
19241930
bool useInBoundsInsteadOfMasking = false;
19251931

19261932
Location loc = unpackOp->getLoc();
19271933

1928-
// 1. Obtain vector sizes for the read and write operations.
1929-
SmallVector<int64_t> readVectorSizes;
1930-
SmallVector<bool> readScalableVectorFlags;
1934+
// Obtain vector sizes for the read operation.
1935+
SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1936+
SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
19311937

1932-
if (!inputVectorSizes.empty()) {
1933-
// CASE 1.1: Vector sizes are user-specified.
1934-
readVectorSizes.assign(inputVectorSizes.begin(),
1935-
inputVectorSizes.begin() + sourceShape.size());
1936-
readScalableVectorFlags.assign(inputScalableVecDims.begin(),
1937-
inputScalableVecDims.begin() +
1938-
sourceShape.size());
1939-
} else {
1940-
// CASE 1.2: Vector sizes are inferred from the static input tensor
1941-
// shapes.
1942-
if (ShapedType::isDynamicShape(destShape) ||
1943-
ShapedType::isDynamicShape(sourceShape))
1938+
// In the absence of input-vector-sizes, use the _static_ input tensor shape.
1939+
if (inputVectorSizes.empty()) {
1940+
if (ShapedType::isDynamicShape(sourceShape))
19441941
return failure();
19451942

19461943
readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
19471944
useInBoundsInsteadOfMasking = true;
19481945
}
19491946

1950-
// 2. Generate the read operation.
1947+
// -- Generate the read operation --
19511948
auto padValue = arith::ConstantOp::create(
19521949
rewriter, loc,
19531950
rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
19541951
Value readResult = vector::createReadOrMaskedRead(
19551952
rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1956-
/*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);
1953+
useInBoundsInsteadOfMasking, readScalableVectorFlags);
19571954

1958-
// 3. Generate the transpose operation.
1955+
// -- Generate the transpose operation --
19591956
PackingMetadata packMetadata;
19601957
SmallVector<int64_t> lastDimToInsertPosPerm =
19611958
getUnPackInverseSrcPerm(unpackOp, packMetadata);
19621959
vector::TransposeOp transposeOp = vector::TransposeOp::create(
19631960
rewriter, loc, readResult, lastDimToInsertPosPerm);
19641961

1965-
// 3. Generate the shape_cast operation.
1962+
// -- Generate the shape_cast operation --
19661963
VectorType collapsedVecType = getCollapsedVecType(
19671964
transposeOp.getType(),
19681965
getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
19691966
rewriter.getContext(), packMetadata.reassociations)));
19701967
vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
19711968
rewriter, loc, collapsedVecType, transposeOp->getResult(0));
19721969

1973-
// 4. Generate the write operation.
1970+
// -- Generate the write operation --
19741971
Operation *write = createWriteOrMaskedWrite(
19751972
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
19761973
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
1974+
19771975
newResults.push_back(write->getResult(0));
19781976
return success();
19791977
}

0 commit comments

Comments
 (0)