@@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18051805 inputShape[innerDimsPos[idx]] *= size;
18061806 auto maskedRead = vector::createReadOrMaskedRead (
18071807 rewriter, loc, packOp.getSource (), inputShape, padValue,
1808- useInBoundsInsteadOfMasking);
1808+ useInBoundsInsteadOfMasking,
1809+ /* inputScalableVecSizes=*/ {});
18091810
18101811 // Create ShapeCastOp.
18111812 SmallVector<int64_t > destShape (inputVectorSizes);
@@ -1831,18 +1832,23 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18311832 return success ();
18321833}
18331834
1834- // / Vectorize a `linalg::UnPackOp` to these 4 Ops:
1835- // / Vector::TransferReadOp - Reads a vector from the source tensor
1836- // / vector::TransposeOp - Transpose the Source tensor
1837- // / ShapeCastOp - Reshape the data based on the target.
1838- // / vector::TransferWriteOp. - Write the result vector back to the destination
1839- // / tensor.
1840- // / If the vector sizes are not provided:
1835+ // / Vectorize `linalg.unpack %src into %dest` as:
1836+ // / // Reads a vector from the source tensor
1837+ // / %read = vector.transfer_read %src
1838+ // / // Transpose %read as specified in `outer_dims_perm` attribute
1839+ // / %tr = vector.transpose %read
1840+ // / // Reshape the data based on the target
1841+ // / %sc = vector.shape_cast %tr
1842+ // / // Write the result vector to the destination tensor.
1843+ // / vector.transfer_write %sc into %dest
1844+ // /
1845+ // / If the vector sizes are not provided:
18411846// / * the vector sizes are determined by the input operand and attributes,
18421847// / * update the inBounds attribute instead of masking.
18431848static LogicalResult
18441849vectorizeAsTensorUnpackOp (RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18451850 ArrayRef<int64_t > inputVectorSizes,
1851+ ArrayRef<bool > inputScalableVecDims,
18461852 SmallVectorImpl<Value> &newResults) {
18471853
18481854 // TODO: Introduce a parent class that will handle the insertion point update.
@@ -1859,25 +1865,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18591865
18601866 auto destSize = unpackOp.getDestRank ();
18611867
1862- if (!inputVectorSizes.empty ())
1863- assert (inputVectorSizes.size () == destSize &&
1868+ if (!inputVectorSizes.empty ()) {
1869+ assert (inputVectorSizes.size () == destSize + sourceShape.size () &&
18641870 " Incorrect number of input vector sizes" );
1871+ }
1872+
1873+ SmallVector<bool > readScalableVectorFlags;
1874+ SmallVector<bool > writeScalableVectorFlags;
1875+ SmallVector<int64_t > readVectorSizes;
1876+ SmallVector<int64_t > writeVectorSizes;
18651877
1866- // vectorSizes is the shape of the vector that will be used to do final
1878+ // Split input-vector-sizes into vector sizes for the read and write
1879+ // operations.
1880+ if (!inputVectorSizes.empty ()) {
1881+ readVectorSizes.append (inputVectorSizes.begin (),
1882+ inputVectorSizes.begin () + sourceShape.size ());
1883+ writeVectorSizes.append (inputVectorSizes.begin () + sourceShape.size (),
1884+ inputVectorSizes.end ());
1885+ }
1886+ if (!inputScalableVecDims.empty ()) {
1887+ readScalableVectorFlags.append (inputScalableVecDims.begin (),
1888+ inputScalableVecDims.begin () +
1889+ sourceShape.size ());
1890+ writeScalableVectorFlags.append (inputScalableVecDims.begin () +
1891+ sourceShape.size (),
1892+ inputScalableVecDims.end ());
1893+ } else {
1894+ readScalableVectorFlags = SmallVector<bool >(sourceShape.size (), false );
1895+ writeScalableVectorFlags = SmallVector<bool >(destSize, false );
1896+ }
1897+
1898+ // writeVectorSizes is the shape of the vector that will be used to do final
18671899 // write on the destination tensor. It is set like this: Let's say the
18681900 // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
18691901 // Thus:
1870- // 1. vectorSizes = sourceShape.take_front(N)
1871- // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
1902+ // 1. writeVectorSizes = sourceShape.take_front(N)
1903+ // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
18721904 // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
18731905 // innerTiles attribute value.
1874- SmallVector<int64_t > vectorSizes (inputVectorSizes);
1875- if (vectorSizes.empty ()) {
1876- llvm::append_range (vectorSizes, sourceShape.take_front (destSize));
1906+ // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
1907+ if (writeVectorSizes.empty ()) {
1908+ if (ShapedType::isDynamicShape (sourceShape))
1909+ return failure ();
1910+
1911+ llvm::append_range (writeVectorSizes, sourceShape.take_front (destSize));
18771912 if (!outerDimsPerm.empty ())
1878- applyPermutationToVector (vectorSizes , outerDimsPerm);
1913+ applyPermutationToVector (writeVectorSizes , outerDimsPerm);
18791914 for (auto [i, pos] : llvm::enumerate (innerDimPos))
1880- vectorSizes [pos] *= innerTiles[i];
1915+ writeVectorSizes [pos] *= innerTiles[i];
18811916
18821917 useInBoundsInsteadOfMasking = true ;
18831918 }
@@ -1901,17 +1936,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19011936 // After applying outer_dims_perm: [8, 16]
19021937 // After appending the rest of the sourceShape: [8, 16, 32, 16]
19031938
1904- SmallVector<int64_t > readVectorSizes (vectorSizes.begin (), vectorSizes.end ());
1905-
1906- for (auto [index, size] : enumerate(innerTiles)) {
1907- readVectorSizes[innerDimPos[index]] =
1908- llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1909- }
1910- if (!outerDimsPerm.empty ()) {
1911- applyPermutationToVector (readVectorSizes, outerDimsPerm);
1939+ if (readVectorSizes.empty ()) {
1940+ // Compute read-vector-sizes based on the write-vector-sizes and inner tile
1941+ // sizes. Note, this will only work when all sizes are static.
1942+ readVectorSizes = writeVectorSizes;
1943+ for (auto [index, size] : enumerate(innerTiles)) {
1944+ readVectorSizes[innerDimPos[index]] =
1945+ llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1946+ }
1947+ if (!outerDimsPerm.empty ()) {
1948+ applyPermutationToVector (readVectorSizes, outerDimsPerm);
1949+ }
1950+ readVectorSizes.append (sourceShape.begin () + writeVectorSizes.size (),
1951+ sourceShape.end ());
19121952 }
1913- readVectorSizes.append (sourceShape.begin () + vectorSizes.size (),
1914- sourceShape.end ());
19151953
19161954 Location loc = unpackOp->getLoc ();
19171955
@@ -1923,7 +1961,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19231961 // to shape of source, then a mask is necessary.
19241962 Value readResult = vector::createReadOrMaskedRead (
19251963 rewriter, loc, unpackOp.getSource (), readVectorSizes, padValue,
1926- /* useInBoundsInsteadOfMasking=*/ false );
1964+ /* useInBoundsInsteadOfMasking=*/ false , readScalableVectorFlags );
19271965
19281966 PackingMetadata packMetadata;
19291967 SmallVector<int64_t > lastDimToInsertPosPerm =
@@ -1942,15 +1980,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19421980 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType (
19431981 stripMineTensorType, packMetadata.reassociations );
19441982 mlir::VectorType vecCollapsedType =
1945- VectorType::get (collapsedType.getShape (), collapsedType.getElementType ());
1983+ VectorType::get (collapsedType.getShape (), collapsedType.getElementType (),
1984+ writeScalableVectorFlags);
19461985 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create (
19471986 rewriter, loc, vecCollapsedType, transposeOp->getResult (0 ));
19481987
1949- // writeVectorSizes had to match the shapecast shape for dynamic sizes,
1988+ // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
19501989 // otherwise the validator complains that the mask size is invalid.
1951- SmallVector<int64_t > writeVectorSizes (
1990+ // FIXME: We should not override write-vector-sizes like this.
1991+ SmallVector<int64_t > writeVectorSizesFinal (
19521992 unpackOp.getDestType ().hasStaticShape ()
1953- ? vectorSizes
1993+ ? writeVectorSizes
19541994 : shapeCastOp.getResultVectorType ().getShape ());
19551995 Operation *write = createWriteOrMaskedWrite (
19561996 rewriter, loc, shapeCastOp.getResult (), unpackOp.getDest (),
@@ -1981,7 +2021,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
19812021 assert (succeeded (status) && " failed to reify result shapes" );
19822022 auto maskedRead = vector::createReadOrMaskedRead (
19832023 rewriter, loc, padOp.getSource (), inputVectorSizes, padValue,
1984- /* useInBoundsInsteadOfMasking=*/ false );
2024+ /* useInBoundsInsteadOfMasking=*/ false , /* inputScalableVecSizes=*/ {});
19852025
19862026 // Create Xfer write Op
19872027 Value dest = tensor::EmptyOp::create (rewriter, loc, reifiedReturnShapes[0 ],
@@ -2065,6 +2105,9 @@ static LogicalResult
20652105vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
20662106 ArrayRef<int64_t > inputVectorSizes) {
20672107
2108+ // FIXME!!!
2109+ return success ();
2110+
20682111 if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
20692112 return !getConstantIntValue (res).has_value ();
20702113 })) {
@@ -2401,6 +2444,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
24012444 LDBG () << " pad value is not constant: " << packOp;
24022445 return failure ();
24032446 }
2447+
24042448 ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
24052449 bool satisfyEmptyCond = true ;
24062450 if (inputVectorSizes.empty ()) {
@@ -2479,12 +2523,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
24792523 if (numOfScalableDims == 0 )
24802524 return success ();
24812525
2526+ // TODO: Check the following!
24822527 auto linalgOp = dyn_cast<LinalgOp>(op);
24832528
2484- // Cond 1: There's been no need for scalable vectorisation of
2485- // non-linalg Ops so far
2486- if (!linalgOp)
2487- return failure ();
2529+ // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2530+ // exception of UnpackOp for which there is a dedicated hook.
2531+ if (!linalgOp) {
2532+ return isa<linalg::UnPackOp>(op) ? success () : failure ();
2533+ }
24882534
24892535 // Cond 2: There's been no need for more than 2 scalable dims so far
24902536 if (numOfScalableDims > 2 )
@@ -2582,7 +2628,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
25822628 isa<linalg::MatmulTransposeAOp>(op) ||
25832629 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
25842630 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2585- hasReductionIterator (linalgOp));
2631+ isa<linalg::UnPackOp>(op) || hasReductionIterator (linalgOp));
25862632}
25872633
25882634LogicalResult mlir::linalg::vectorizeOpPrecondition (
@@ -2715,7 +2761,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
27152761 })
27162762 .Case <linalg::UnPackOp>([&](auto unpackOp) {
27172763 return vectorizeAsTensorUnpackOp (rewriter, unpackOp,
2718- inputVectorSizes, results);
2764+ inputVectorSizes,
2765+ inputScalableVecDims, results);
27192766 })
27202767 .Case <tensor::InsertSliceOp>([&](auto sliceOp) {
27212768 return vectorizeAsInsertSliceOp (rewriter, sliceOp, inputVectorSizes,
@@ -3107,7 +3154,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
31073154 vecType.getRank (), arith::ConstantIndexOp::create (rewriter, loc, 0 ));
31083155 Value read = mlir::vector::createReadOrMaskedRead (
31093156 rewriter, loc, source, vecType.getShape (), padValue,
3110- /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty ());
3157+ /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty (),
3158+ /* inputScalableVecSizes=*/ {});
31113159
31123160 // Create write
31133161 auto writeIndices =
0 commit comments