@@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18051805 inputShape[innerDimsPos[idx]] *= size;
18061806 auto maskedRead = vector::createReadOrMaskedRead (
18071807 rewriter, loc, packOp.getSource (), inputShape, padValue,
1808- useInBoundsInsteadOfMasking);
1808+ useInBoundsInsteadOfMasking,
1809+ /* inputScalableVecSizes=*/ {});
18091810
18101811 // Create ShapeCastOp.
18111812 SmallVector<int64_t > destShape (inputVectorSizes);
@@ -1885,11 +1886,19 @@ static VectorType getCollapsedVecType(VectorType type,
18851886// / vector::TransferWriteOp. - Write the result vector back to the destination
18861887// / tensor.
18871888// / If the vector sizes are not provided:
1888- // / * the vector sizes are determined by the input operand and attributes,
1889- // / * update the inBounds attribute instead of masking.
1889+ // / Vectorize `linalg.unpack %src into %dest` as:
1890+ // / // Reads a vector from the source tensor
1891+ // / %read = vector.transfer_read %src
1892+ // / // Transpose %read as specified in `outer_dims_perm` attribute
1893+ // / %tr = vector.transpose %read
1894+ // / // Reshape the data based on the target
1895+ // / %sc = vector.shape_cast %tr
1896+ // / // Write the result vector to the destination tensor.
1897+ // / vector.transfer_write %sc into %dest
18901898static LogicalResult
18911899vectorizeAsTensorUnpackOp (RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18921900 ArrayRef<int64_t > inputVectorSizes,
1901+ ArrayRef<bool > inputScalableVecDims,
18931902 SmallVectorImpl<Value> &newResults) {
18941903
18951904 // TODO: Introduce a parent class that will handle the insertion point update.
@@ -1906,25 +1915,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19061915
19071916 auto destSize = unpackOp.getDestRank ();
19081917
1909- if (!inputVectorSizes.empty ())
1910- assert (inputVectorSizes.size () == destSize &&
1918+ if (!inputVectorSizes.empty ()) {
1919+ assert (inputVectorSizes.size () == destSize + sourceShape. size () &&
19111920 " Incorrect number of input vector sizes" );
1921+ }
1922+
1923+ SmallVector<bool > readScalableVectorFlags;
1924+ SmallVector<bool > writeScalableVectorFlags;
1925+ SmallVector<int64_t > readVectorSizes;
1926+ SmallVector<int64_t > writeVectorSizes;
1927+
1928+ // Split input-vector-sizes into vector sizes for the read and write
1929+ // operations.
1930+ if (!inputVectorSizes.empty ()) {
1931+ readVectorSizes.append (inputVectorSizes.begin (),
1932+ inputVectorSizes.begin () + sourceShape.size ());
1933+ writeVectorSizes.append (inputVectorSizes.begin () + sourceShape.size (),
1934+ inputVectorSizes.end ());
1935+ }
1936+ if (!inputScalableVecDims.empty ()) {
1937+ readScalableVectorFlags.append (inputScalableVecDims.begin (),
1938+ inputScalableVecDims.begin () +
1939+ sourceShape.size ());
1940+ writeScalableVectorFlags.append (inputScalableVecDims.begin () +
1941+ sourceShape.size (),
1942+ inputScalableVecDims.end ());
1943+ } else {
1944+ readScalableVectorFlags = SmallVector<bool >(sourceShape.size (), false );
1945+ writeScalableVectorFlags = SmallVector<bool >(destSize, false );
1946+ }
19121947
1913- // vectorSizes is the shape of the vector that will be used to do final
1948+ // writeVectorSizes is the shape of the vector that will be used to do final
19141949 // write on the destination tensor. It is set like this: Let's say the
19151950 // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
19161951 // Thus:
1917- // 1. vectorSizes = sourceShape.take_front(N)
1918- // 2. if outer_dims_perms is present: do that permutation on vectorSizes .
1952+ // 1. writeVectorSizes = sourceShape.take_front(N)
1953+ // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes .
19191954 // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
19201955 // innerTiles attribute value.
1921- SmallVector<int64_t > vectorSizes (inputVectorSizes);
1922- if (vectorSizes.empty ()) {
1923- llvm::append_range (vectorSizes, sourceShape.take_front (destSize));
1957+ if (writeVectorSizes.empty ()) {
1958+ if (ShapedType::isDynamicShape (sourceShape))
1959+ return failure ();
1960+
1961+ llvm::append_range (writeVectorSizes, sourceShape.take_front (destSize));
19241962 if (!outerDimsPerm.empty ())
1925- applyPermutationToVector (vectorSizes , outerDimsPerm);
1963+ applyPermutationToVector (writeVectorSizes , outerDimsPerm);
19261964 for (auto [i, pos] : llvm::enumerate (innerDimPos))
1927- vectorSizes [pos] *= innerTiles[i];
1965+ writeVectorSizes [pos] *= innerTiles[i];
19281966
19291967 useInBoundsInsteadOfMasking = true ;
19301968 }
@@ -1948,17 +1986,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19481986 // After applying outer_dims_perm: [8, 16]
19491987 // After appending the rest of the sourceShape: [8, 16, 32, 16]
19501988
1951- SmallVector<int64_t > readVectorSizes (vectorSizes.begin (), vectorSizes.end ());
1952-
1953- for (auto [index, size] : enumerate(innerTiles)) {
1954- readVectorSizes[innerDimPos[index]] =
1955- llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1956- }
1957- if (!outerDimsPerm.empty ()) {
1958- applyPermutationToVector (readVectorSizes, outerDimsPerm);
1989+ if (readVectorSizes.empty ()) {
1990+ // Compute read-vector-sizes based on the write-vector-sizes and inner tile
1991+ // sizes. Note, this will only work when all sizes are static.
1992+ readVectorSizes = writeVectorSizes;
1993+ for (auto [index, size] : enumerate(innerTiles)) {
1994+ readVectorSizes[innerDimPos[index]] =
1995+ llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1996+ }
1997+ if (!outerDimsPerm.empty ()) {
1998+ applyPermutationToVector (readVectorSizes, outerDimsPerm);
1999+ }
2000+ readVectorSizes.append (sourceShape.begin () + writeVectorSizes.size (),
2001+ sourceShape.end ());
19592002 }
1960- readVectorSizes.append (sourceShape.begin () + vectorSizes.size (),
1961- sourceShape.end ());
19622003
19632004 Location loc = unpackOp->getLoc ();
19642005
@@ -1970,7 +2011,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19702011 // to shape of source, then a mask is necessary.
19712012 Value readResult = vector::createReadOrMaskedRead (
19722013 rewriter, loc, unpackOp.getSource (), readVectorSizes, padValue,
1973- /* useInBoundsInsteadOfMasking=*/ false );
2014+ /* useInBoundsInsteadOfMasking=*/ false , readScalableVectorFlags );
19742015
19752016 PackingMetadata packMetadata;
19762017 SmallVector<int64_t > lastDimToInsertPosPerm =
@@ -2016,7 +2057,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
20162057 assert (succeeded (status) && " failed to reify result shapes" );
20172058 auto maskedRead = vector::createReadOrMaskedRead (
20182059 rewriter, loc, padOp.getSource (), inputVectorSizes, padValue,
2019- /* useInBoundsInsteadOfMasking=*/ false );
2060+ /* useInBoundsInsteadOfMasking=*/ false , /* inputScalableVecSizes= */ {} );
20202061
20212062 // Create Xfer write Op
20222063 Value dest = tensor::EmptyOp::create (rewriter, loc, reifiedReturnShapes[0 ],
@@ -2100,6 +2141,9 @@ static LogicalResult
21002141vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
21012142 ArrayRef<int64_t > inputVectorSizes) {
21022143
21032147 if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
21042148 return !getConstantIntValue (res).has_value ();
21052149 })) {
@@ -2436,6 +2480,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
24362480 LDBG () << " pad value is not constant: " << packOp;
24372481 return failure ();
24382482 }
2483+
24392484 ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
24402485 bool satisfyEmptyCond = true ;
24412486 if (inputVectorSizes.empty ()) {
@@ -2514,12 +2559,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
25142559 if (numOfScalableDims == 0 )
25152560 return success ();
25162561
25172563 auto linalgOp = dyn_cast<LinalgOp>(op);
25182564
2519- // Cond 1: There's been no need for scalable vectorisation of
2520- // non-linalg Ops so far
2521- if (!linalgOp)
2522- return failure ();
2565+ // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2566+ // exception of UnpackOp for which there is a dedicated hook.
2567+ if (!linalgOp) {
2568+ return isa<linalg::UnPackOp>(op) ? success () : failure ();
2569+ }
25232570
25242571 // Cond 2: There's been no need for more than 2 scalable dims so far
25252572 if (numOfScalableDims > 2 )
@@ -2617,7 +2664,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
26172664 isa<linalg::MatmulTransposeAOp>(op) ||
26182665 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
26192666 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2620- hasReductionIterator (linalgOp));
2667+ isa<linalg::UnPackOp>(op) || hasReductionIterator (linalgOp));
26212668}
26222669
26232670LogicalResult mlir::linalg::vectorizeOpPrecondition (
@@ -2750,7 +2797,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
27502797 })
27512798 .Case <linalg::UnPackOp>([&](auto unpackOp) {
27522799 return vectorizeAsTensorUnpackOp (rewriter, unpackOp,
2753- inputVectorSizes, results);
2800+ inputVectorSizes,
2801+ inputScalableVecDims, results);
27542802 })
27552803 .Case <tensor::InsertSliceOp>([&](auto sliceOp) {
27562804 return vectorizeAsInsertSliceOp (rewriter, sliceOp, inputVectorSizes,
@@ -3142,7 +3190,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
31423190 vecType.getRank (), arith::ConstantIndexOp::create (rewriter, loc, 0 ));
31433191 Value read = mlir::vector::createReadOrMaskedRead (
31443192 rewriter, loc, source, vecType.getShape (), padValue,
3145- /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty ());
3193+ /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty (),
3194+ /* inputScalableVecSizes=*/ {});
31463195
31473196 // Create write
31483197 auto writeIndices =
0 commit comments