@@ -1841,10 +1841,6 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18411841// /
18421842// / When collapsing scalable flags, conservatively avoids cases with two
18431843// / scalable dims. We could re-visit this in the future.
1844- // /
1845- // / If the vector sizes are not provided:
1846- // / * the vector sizes are determined by the input operand and attributes,
1847- // / * update the inBounds attribute instead of masking.
18481844static VectorType getCollapsedVecType (VectorType type,
18491845 ArrayRef<AffineMap> reassociation) {
18501846 assert (type.getNumScalableDims () < 2 &&
@@ -1876,22 +1872,35 @@ static VectorType getCollapsedVecType(VectorType type,
18761872 return VectorType::get (newShape, type.getElementType (), newScalableFlags);
18771873}
18781874
1879- // / Vectorize a `linalg::UnPackOp` to these 4 Ops:
1880- // / Vector::TransferReadOp - Reads a vector from the source tensor
1881- // / vector::TransposeOp - Transpose the Source tensor
1882- // / ShapeCastOp - Reshape the data based on the target.
1883- // / vector::TransferWriteOp. - Write the result vector back to the destination
1884- // / tensor.
1885- // / If the vector sizes are not provided:
1886- // / Vectorize `linalg.unpack %src into %dest` as:
1887- // / // Reads a vector from the source tensor
1888- // / %read = vector.transfer_read %src
1889- // / // Transpose %read as specified in `outer_dims_perm` attribute
1890- // / %tr = vector.transpose %read
1891- // / // Reshape the data based on the target
1892- // / %sc = vector.shape_cast %tr
1893- // / // Write the result vector to the destination tensor.
1894- // / vector.transfer_write %sc into %dest
1875+ // / Vectorize `linalg.unpack` into:
1876+ // / * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1877+ // /
1878+ // / The input-vector-sizes specify both the read and the write vector
1879+ // / sizes and are passed as one array covering both operations, i.e.:
1880+ // /
1881+ // / input-vector-sizes = [1, 1, 8, [8], 8, [8]]
1882+ // / \ / \ /
1883+ // / read-sizes write-sizes
1884+ // /
1885+ // / (for brevity, in the diagram,
1886+ // / * input-vector-sizes = `inputVectorSizes` + `inputScalableDims`
1887+ // / )
1888+ // /
1889+ // / If the vector sizes are not provided:
1890+ // / * the vector sizes are determined by the operands,
1891+ // / * the inBounds attribute is used instead of masking.
1892+ // /
1893+ // / EXAMPLE (no vector sizes):
1894+ // / ```
1895+ // / %unpack = linalg.unpack %src
1896+ // / inner_dims_pos = [0, 1]
1897+ // / inner_tiles = [8, 8]
1898+ // / into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1899+ // / ```
1900+ // / is vectorized as (where %sc is the result of the read -> transpose ->
1900+ // / shape_cast chain):
1901+ // / ```
1902+ // / vector.transfer_write %sc into %dest : vector<8x8xf32>, tensor<8x8xf32>
1903+ // / ```
18951904static LogicalResult
18961905vectorizeAsTensorUnpackOp (RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18971906 ArrayRef<int64_t > inputVectorSizes,
@@ -1911,22 +1920,19 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19111920
19121921 RankedTensorType unpackTensorType = unpackOp.getSourceType ();
19131922
1914- ArrayRef<int64_t > innerDimPos = unpackOp.getInnerDimsPos ();
1915- ArrayRef<int64_t > innerTiles = unpackOp.getStaticInnerTiles ();
19161923 ArrayRef<int64_t > sourceShape = unpackTensorType.getShape ();
1924+ ArrayRef<int64_t > destShape = unpackOp.getDestType ().getShape ();
19171925 bool useInBoundsInsteadOfMasking = false ;
1918- ArrayRef<int64_t > outerDimsPerm = unpackOp.getOuterDimsPerm ();
19191926
1920- auto destSize = unpackOp. getDestRank ();
1927+ Location loc = unpackOp-> getLoc ();
19211928
1922- // 1. Obtain vector sizes for the read and write operation.s
1929+ // 1. Obtain vector sizes for the read and write operations.
19231930 SmallVector<int64_t > readVectorSizes;
19241931 SmallVector<int64_t > writeVectorSizes;
19251932 SmallVector<bool > readScalableVectorFlags;
19261933 SmallVector<bool > writeScalableVectorFlags;
19271934
1928- // CASE 1: Vector sizes are user-specified.
1929- // 1.0 This is the trivial case, simply split the input vector sizes.
1935+ // CASE 1.1: Vector sizes are user-specified.
19301936 if (!inputVectorSizes.empty ()) {
19311937 readVectorSizes.append (inputVectorSizes.begin (),
19321938 inputVectorSizes.begin () + sourceShape.size ());
@@ -1940,83 +1946,41 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19401946 inputScalableVecDims.end ());
19411947 }
19421948
1943- // CASE 2: Vector sizes have to be inferred.
1944- //
1945- // 1.1 Infer vector sizes for the write operation.
1946- //
1947- // Let:
1948- // * rank(source tensor) = 'M'
1949- // * rank(dest tensor) = 'N',
1950- // and N <= M. The steps are:
1951- // 1. writeVectorSizes = sourceShape.take_front(N)
1952- // 2. Multiply all the locations in writeVectorSize pointed by inner_dims_pos
1953- // by the corresponding values from the `inner_tiles` attribute value.
1954- // 3. If outer_dims_perms is present, permutate writeVectorSizes accordingly.
1955- //
1956- // Note, this will only work when all sizes are static!
1949 // CASE 1.2: Vector sizes have to be inferred.
19571950 if (writeVectorSizes.empty ()) {
1958- if (ShapedType::isDynamicShape (sourceShape))
1951+ if (ShapedType::isDynamicShape (destShape) ||
1952+ ShapedType::isDynamicShape (sourceShape))
19591953 return failure ();
19601954
1961- llvm::append_range (writeVectorSizes, sourceShape.take_front (destSize));
1962- if (!outerDimsPerm.empty ())
1963- applyPermutationToVector (writeVectorSizes, outerDimsPerm);
1964- for (auto [i, pos] : llvm::enumerate (innerDimPos))
1965- writeVectorSizes[pos] *= innerTiles[i];
1966-
1955+ readVectorSizes.assign (sourceShape.begin (), sourceShape.end ());
1956+ writeVectorSizes.assign (destShape.begin (), destShape.end ());
19671957 useInBoundsInsteadOfMasking = true ;
19681958 }
19691959
1970- // 1.2 Infer vector sizes for the read operation.
1971- //
1972- // The steps are:
1973- // 1. readVectorSizes = writeVectorSizes
1974- // 2. Take readVectorSizes from 1. and divide all locations pointed by
1975- // the inner_dims_pos attribyte by the `inner_tiles` attribute value.
1976- // 3. If outer_dims_perms is present, permutate readVectorSizes accordingly.
1977- // 4. Append the remaining sizes from the source tensor.
1978- //
1979- // Note, this will only work when all sizes are static!
1980- if (readVectorSizes.empty ()) {
1981- readVectorSizes = writeVectorSizes;
1982- for (auto [index, size] : enumerate(innerTiles)) {
1983- readVectorSizes[innerDimPos[index]] =
1984- llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1985- }
1986- if (!outerDimsPerm.empty ()) {
1987- applyPermutationToVector (readVectorSizes, outerDimsPerm);
1988- }
1989- readVectorSizes.append (sourceShape.begin () + writeVectorSizes.size (),
1990- sourceShape.end ());
1991- }
1992-
1993- Location loc = unpackOp->getLoc ();
1994-
1960+ // 2. Generate the read operation.
19951961 auto padValue = arith::ConstantOp::create (
19961962 rewriter, loc,
19971963 rewriter.getZeroAttr (unpackOp.getSourceType ().getElementType ()));
1998-
1999- // Read result, mask if necessary. If transferReadOp shape is not equal
2000- // to shape of source, then a mask is necessary.
20011964 Value readResult = vector::createReadOrMaskedRead (
20021965 rewriter, loc, unpackOp.getSource (), readVectorSizes, padValue,
20031966 /* useInBoundsInsteadOfMasking=*/ false , readScalableVectorFlags);
20041967
1968+ // 3. Generate the transpose operation.
20051969 PackingMetadata packMetadata;
20061970 SmallVector<int64_t > lastDimToInsertPosPerm =
20071971 getUnPackInverseSrcPerm (unpackOp, packMetadata);
2008- // Transpose the appropriate rows to match output.
20091972 vector::TransposeOp transposeOp = vector::TransposeOp::create (
20101973 rewriter, loc, readResult, lastDimToInsertPosPerm);
20111974
2012- // Collapse the vector to the size required by result .
1975+ // 4. Generate the shape_cast operation.
20131976 VectorType collapsedVecType = getCollapsedVecType (
20141977 transposeOp.getType (),
20151978 getSymbolLessAffineMaps (convertReassociationIndicesToExprs (
20161979 rewriter.getContext (), packMetadata.reassociations )));
20171980 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create (
20181981 rewriter, loc, collapsedVecType, transposeOp->getResult (0 ));
20191982
1983+ // 5. Generate the write operation.
20201984 Operation *write = createWriteOrMaskedWrite (
20211985 rewriter, loc, shapeCastOp.getResult (), unpackOp.getDest (),
20221986 /* writeIndices=*/ {}, useInBoundsInsteadOfMasking);
@@ -2144,24 +2108,24 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
21442108 if (!inputVectorSizes.empty ()) {
21452109 if (inputVectorSizes.size () !=
21462110 unpackOp.getDestRank () + unpackOp.getSourceRank ()) {
2147- LDBG (" Incorrect number of input vector sizes" ) ;
2111+ LDBG () << " Incorrect number of input vector sizes" ;
21482112 return failure ();
21492113 }
21502114 }
21512115
2152- // Check the vector sizes for the write operation.
2116+ // Check the vector sizes for the read operation.
21532117 if (failed (vector::isValidMaskedInputVector (
2154- unpackOp.getDestType ().getShape (),
2155- inputVectorSizes.take_back (unpackOp.getDestRank ())))) {
2156- LDBG (" Incorrect number of input vector sizes" ) ;
2118+ unpackOp.getSourceType ().getShape (),
2119+ inputVectorSizes.take_front (unpackOp.getSourceRank ())))) {
2120+ LDBG () << " Invalid vector sizes for the read operation " ;
21572121 return failure ();
21582122 }
21592123
2160- // Check the vector sizes for the read operation.
2124+ // Check the vector sizes for the write operation.
21612125 if (failed (vector::isValidMaskedInputVector (
2162- unpackOp.getSourceType ().getShape (),
2163- inputVectorSizes.take_front (unpackOp.getSourceRank ())))) {
2164- LDBG (" Incorrect number of input vector sizes" ) ;
2126+ unpackOp.getDestType ().getShape (),
2127+ inputVectorSizes.take_back (unpackOp.getDestRank ())))) {
2128+ LDBG () << " Invalid vector sizes for the write operation " ;
21652129 return failure ();
21662130 }
21672131
@@ -2551,8 +2515,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
25512515 return success ();
25522516}
25532517
2554- // / Preconditions for scalable vectors. This is quite restrictive - it models
2555- // / the fact that in practice we would only make selected dimensions scalable.
2518+ // / Preconditions for scalable vectors.
2519+ // /
2520+ // / For Ops implementing the LinalgOp interface, this is quite restrictive - it
2521+ // / models the fact that in practice we would only make selected dimensions
2522+ // / scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2523+ // / unconditionally - we are yet to identify meaningful conditions.
25562524static LogicalResult
25572525vectorizeScalableVectorPrecondition (Operation *op,
25582526 ArrayRef<int64_t > inputVectorSizes,
@@ -2571,7 +2539,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
25712539 // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
25722540 // exception of UnpackOp for which there is a dedicated hook.
25732541 if (!linalgOp) {
2574- return isa<linalg::UnPackOp>(op) ? success () : failure ( );
2542+ return success ( isa<linalg::UnPackOp>(op));
25752543 }
25762544
25772545 // Cond 2: There's been no need for more than 2 scalable dims so far
@@ -2670,7 +2638,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
26702638 isa<linalg::MatmulTransposeAOp>(op) ||
26712639 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
26722640 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2673- isa<linalg::UnPackOp>(op) || hasReductionIterator (linalgOp));
2641+ hasReductionIterator (linalgOp));
26742642}
26752643
26762644LogicalResult mlir::linalg::vectorizeOpPrecondition (
0 commit comments