@@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18051805 inputShape[innerDimsPos[idx]] *= size;
18061806 auto maskedRead = vector::createReadOrMaskedRead (
18071807 rewriter, loc, packOp.getSource (), inputShape, padValue,
1808- useInBoundsInsteadOfMasking);
1808+ useInBoundsInsteadOfMasking,
1809+ /* inputScalableVecSizes=*/ {});
18091810
18101811 // Create ShapeCastOp.
18111812 SmallVector<int64_t > destShape (inputVectorSizes);
@@ -1840,6 +1841,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18401841// /
18411842// / When collapsing scalable flags, conservatively avoids cases with two
18421843// / scalable dims. We could re-visit this in the future.
1844+ // /
1845+ // / NOTE(review): this helper only collapses an existing vector type per the
1846+ // / reassociation; it takes no vector sizes. The added bullets about
1847+ // / inferring vector sizes / inBounds-vs-masking describe
1848+ // / vectorizeAsTensorUnpackOp and belong on that function, not here.
18431848static VectorType getCollapsedVecType (VectorType type,
18441849 ArrayRef<AffineMap> reassociation) {
18451850 assert (type.getNumScalableDims () < 2 &&
@@ -1878,11 +1883,19 @@ static VectorType getCollapsedVecType(VectorType type,
18781883// / vector::TransferWriteOp. - Write the result vector back to the destination
18791884// / tensor.
1881- // / * the vector sizes are determined by the input operand and attributes,
1882- // / * update the inBounds attribute instead of masking.
1886+ // / If the vector sizes are not provided, they are inferred from the operand
1887+ // / shapes and attributes, and the inBounds attribute is updated instead of
1888+ // / masking.
1889+ // /
1890+ // / Vectorize `linalg.unpack %src into %dest` as:
1887+ // / // Reads a vector from the source tensor
1888+ // / %read = vector.transfer_read %src
1889+ // / // Transpose %read as specified in `outer_dims_perm` attribute
1890+ // / %tr = vector.transpose %read
1891+ // / // Reshape the data based on the target
1892+ // / %sc = vector.shape_cast %tr
1893+ // / // Write the result vector to the destination tensor.
1894+ // / vector.transfer_write %sc into %dest
18831895static LogicalResult
18841896vectorizeAsTensorUnpackOp (RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18851897 ArrayRef<int64_t > inputVectorSizes,
1898+ ArrayRef<bool > inputScalableVecDims,
18861899 SmallVectorImpl<Value> &newResults) {
18871900
18881901 // TODO: Introduce a parent class that will handle the insertion point update.
@@ -1899,25 +1912,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18991912
19001913 auto destSize = unpackOp.getDestRank ();
19011914
1902- if (!inputVectorSizes.empty ())
1903- assert (inputVectorSizes.size () == destSize &&
1915+ if (!inputVectorSizes.empty ()) {
1916+ assert (inputVectorSizes.size () == destSize + sourceShape. size () &&
19041917 " Incorrect number of input vector sizes" );
1918+ }
1919+
1920+ SmallVector<bool > readScalableVectorFlags;
1921+ SmallVector<bool > writeScalableVectorFlags;
1922+ SmallVector<int64_t > readVectorSizes;
1923+ SmallVector<int64_t > writeVectorSizes;
19051924
1906- // vectorSizes is the shape of the vector that will be used to do final
1925+ // Split input-vector-sizes into vector sizes for the read and write
1926+ // operations.
1927+ if (!inputVectorSizes.empty ()) {
1928+ readVectorSizes.append (inputVectorSizes.begin (),
1929+ inputVectorSizes.begin () + sourceShape.size ());
1930+ writeVectorSizes.append (inputVectorSizes.begin () + sourceShape.size (),
1931+ inputVectorSizes.end ());
1932+ }
1933+ if (!inputScalableVecDims.empty ()) {
1934+ readScalableVectorFlags.append (inputScalableVecDims.begin (),
1935+ inputScalableVecDims.begin () +
1936+ sourceShape.size ());
1937+ writeScalableVectorFlags.append (inputScalableVecDims.begin () +
1938+ sourceShape.size (),
1939+ inputScalableVecDims.end ());
1940+ } else {
1941+ readScalableVectorFlags = SmallVector<bool >(sourceShape.size (), false );
1942+ writeScalableVectorFlags = SmallVector<bool >(destSize, false );
1943+ }
1944+
1945+ // writeVectorSizes is the shape of the vector that will be used to do final
19071946 // write on the destination tensor. It is set like this: Let's say the
19081947 // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
19091948 // Thus:
1910- // 1. vectorSizes = sourceShape.take_front(N)
1911- // 2. if outer_dims_perms is present: do that permutation on vectorSizes .
1949+ // 1. writeVectorSizes = sourceShape.take_front(N)
1950+ // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes .
19121951 // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
19131952 // innerTiles attribute value.
1914- SmallVector<int64_t > vectorSizes (inputVectorSizes);
1915- if (vectorSizes.empty ()) {
1916- llvm::append_range (vectorSizes, sourceShape.take_front (destSize));
1953+ // If no write sizes were provided, infer writeVectorSizes from the shapes:
1954+ if (writeVectorSizes.empty ()) {
1955+ if (ShapedType::isDynamicShape (sourceShape))
1956+ return failure ();
1957+
1958+ llvm::append_range (writeVectorSizes, sourceShape.take_front (destSize));
19171959 if (!outerDimsPerm.empty ())
1918- applyPermutationToVector (vectorSizes , outerDimsPerm);
1960+ applyPermutationToVector (writeVectorSizes , outerDimsPerm);
19191961 for (auto [i, pos] : llvm::enumerate (innerDimPos))
1920- vectorSizes [pos] *= innerTiles[i];
1962+ writeVectorSizes [pos] *= innerTiles[i];
19211963
19221964 useInBoundsInsteadOfMasking = true ;
19231965 }
@@ -1941,17 +1983,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19411983 // After applying outer_dims_perm: [8, 16]
19421984 // After appending the rest of the sourceShape: [8, 16, 32, 16]
19431985
1944- SmallVector<int64_t > readVectorSizes (vectorSizes.begin (), vectorSizes.end ());
1945-
1946- for (auto [index, size] : enumerate(innerTiles)) {
1947- readVectorSizes[innerDimPos[index]] =
1948- llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1949- }
1950- if (!outerDimsPerm.empty ()) {
1951- applyPermutationToVector (readVectorSizes, outerDimsPerm);
1986+ if (readVectorSizes.empty ()) {
1987+ // Compute read-vector-sizes based on the write-vector-sizes and inner tile
1988+ // sizes. Note, this will only work when all sizes are static.
1989+ readVectorSizes = writeVectorSizes;
1990+ for (auto [index, size] : enumerate(innerTiles)) {
1991+ readVectorSizes[innerDimPos[index]] =
1992+ llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1993+ }
1994+ if (!outerDimsPerm.empty ()) {
1995+ applyPermutationToVector (readVectorSizes, outerDimsPerm);
1996+ }
1997+ readVectorSizes.append (sourceShape.begin () + writeVectorSizes.size (),
1998+ sourceShape.end ());
19521999 }
1953- readVectorSizes.append (sourceShape.begin () + vectorSizes.size (),
1954- sourceShape.end ());
19552000
19562001 Location loc = unpackOp->getLoc ();
19572002
@@ -1963,7 +2008,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19632008 // to shape of source, then a mask is necessary.
19642009 Value readResult = vector::createReadOrMaskedRead (
19652010 rewriter, loc, unpackOp.getSource (), readVectorSizes, padValue,
1966- /* useInBoundsInsteadOfMasking=*/ false );
2011+ /* useInBoundsInsteadOfMasking=*/ false , readScalableVectorFlags );
19672012
19682013 PackingMetadata packMetadata;
19692014 SmallVector<int64_t > lastDimToInsertPosPerm =
@@ -2009,7 +2054,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
20092054 assert (succeeded (status) && " failed to reify result shapes" );
20102055 auto maskedRead = vector::createReadOrMaskedRead (
20112056 rewriter, loc, padOp.getSource (), inputVectorSizes, padValue,
2012- /* useInBoundsInsteadOfMasking=*/ false );
2057+ /* useInBoundsInsteadOfMasking=*/ false , /* inputScalableVecSizes= */ {} );
20132058
20142059 // Create Xfer write Op
20152060 Value dest = tensor::EmptyOp::create (rewriter, loc, reifiedReturnShapes[0 ],
@@ -2093,6 +2138,9 @@ static LogicalResult
20932138vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
20942139 ArrayRef<int64_t > inputVectorSizes) {
20952140
2141+ // FIXME: the unconditional early return below bypasses every unpack
2141+ // vectorization precondition check (dynamic inner-tile rejection,
2141+ // vector-size validation). Leftover from prototyping — remove before merge.
2142+ return success ();
2143+
20962144 if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
20972145 return !getConstantIntValue (res).has_value ();
20982146 })) {
@@ -2429,6 +2477,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
24292477 LDBG () << " pad value is not constant: " << packOp;
24302478 return failure ();
24312479 }
2480+
24322481 ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
24332482 bool satisfyEmptyCond = true ;
24342483 if (inputVectorSizes.empty ()) {
@@ -2507,12 +2556,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
25072556 if (numOfScalableDims == 0 )
25082557 return success ();
25092558
2559+ // TODO: Verify the non-LinalgOp path below: UnPackOp is accepted without
2559+ // validating its scalable vector sizes, unlike the LinalgOp cases.
25102560 auto linalgOp = dyn_cast<LinalgOp>(op);
25112561
2512- // Cond 1: There's been no need for scalable vectorisation of
2513- // non-linalg Ops so far
2514- if (!linalgOp)
2515- return failure ();
2562+ // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2563+ // exception of UnpackOp for which there is a dedicated hook.
2564+ if (!linalgOp) {
2565+ return isa<linalg::UnPackOp>(op) ? success () : failure ();
2566+ }
25162567
25172568 // Cond 2: There's been no need for more than 2 scalable dims so far
25182569 if (numOfScalableDims > 2 )
@@ -2610,7 +2661,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
26102661 isa<linalg::MatmulTransposeAOp>(op) ||
26112662 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
26122663 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2613- hasReductionIterator (linalgOp));
2664+ isa<linalg::UnPackOp>(op) || hasReductionIterator (linalgOp));
26142665}
26152666
26162667LogicalResult mlir::linalg::vectorizeOpPrecondition (
@@ -2743,7 +2794,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
27432794 })
27442795 .Case <linalg::UnPackOp>([&](auto unpackOp) {
27452796 return vectorizeAsTensorUnpackOp (rewriter, unpackOp,
2746- inputVectorSizes, results);
2797+ inputVectorSizes,
2798+ inputScalableVecDims, results);
27472799 })
27482800 .Case <tensor::InsertSliceOp>([&](auto sliceOp) {
27492801 return vectorizeAsInsertSliceOp (rewriter, sliceOp, inputVectorSizes,
@@ -3135,7 +3187,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
31353187 vecType.getRank (), arith::ConstantIndexOp::create (rewriter, loc, 0 ));
31363188 Value read = mlir::vector::createReadOrMaskedRead (
31373189 rewriter, loc, source, vecType.getShape (), padValue,
3138- /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty ());
3190+ /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty (),
3191+ /* inputScalableVecSizes=*/ {});
31393192
31403193 // Create write
31413194 auto writeIndices =
0 commit comments