@@ -1879,22 +1879,35 @@ static VectorType getCollapsedVecType(VectorType type,
   return VectorType::get(newShape, type.getElementType(), newScalableFlags);
 }
 
-/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
-///   Vector::TransferReadOp - Reads a vector from the source tensor
-///   vector::TransposeOp - Transpose the Source tensor
-///   ShapeCastOp - Reshape the data based on the target.
-///   vector::TransferWriteOp. - Write the result vector back to the destination
-///   tensor.
-/// If the vector sizes are not provided:
-/// Vectorize `linalg.unpack %src into %dest` as:
-///   // Reads a vector from the source tensor
-///   %read = vector.transfer_read %src
-///   // Transpose %read as specified in `outer_dims_perm` attribute
-///   %tr = vector.transpose %read
-///   // Reshape the data based on the target
-///   %sc = vector.shape_cast %tr
-///   // Write the result vector to the destination tensor.
-///   vector.transfer_write %sc into %dest
+/// Vectorize `linalg.unpack` into:
+///   * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
+///
+/// The input-vector-sizes specify both the read and the write vector
+/// sizes and are passed as one array covering both operations, i.e.:
+///
+///   input-vector-sizes = [1, 1, 8, [8],  8, [8]]
+///                         \          /   \    /
+///                          read-sizes   write-sizes
+///
+/// (for brevity, in the diagram,
+///   * input-vector-sizes = `inputVectorSizes` + `inputScalableDims`
+/// )
+///
+/// If the vector sizes are not provided:
+///   * the vector sizes are determined by the operands,
+///   * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+/// %unpack = linalg.unpack %src
+///     inner_dims_pos = [0, 1]
+///     inner_tiles = [8, 8]
+///     into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
+/// ```
+/// is vectorized (showing only the final write) as:
+/// ```
+/// vector.transfer_write %sc into %dest : vector<8x8xf32>, tensor<8x8xf32>
+/// ```
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
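The combined read-then-write sizes described in the comment above can be supplied from the transform dialect. A minimal sketch, assuming the upstream `transform.structured.vectorize` op and an already-matched handle `%unpack` (a placeholder name); scalable entries use the bracketed form from the diagram:

```mlir
// The first sourceRank (= 4) entries are the read sizes, the remaining
// destRank (= 2) entries are the write sizes.
transform.structured.vectorize %unpack
    vector_sizes [1, 1, 8, [8], 8, [8]] : !transform.any_op
```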
@@ -1914,22 +1927,19 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
 
   RankedTensorType unpackTensorType = unpackOp.getSourceType();
 
-  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
-  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
   bool useInBoundsInsteadOfMasking = false;
-  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
 
-  auto destSize = unpackOp.getDestRank();
+  Location loc = unpackOp->getLoc();
 
-  // 1. Obtain vector sizes for the read and write operation.s
+  // 1. Obtain vector sizes for the read and write operations.
   SmallVector<int64_t> readVectorSizes;
   SmallVector<int64_t> writeVectorSizes;
   SmallVector<bool> readScalableVectorFlags;
   SmallVector<bool> writeScalableVectorFlags;
 
-  // CASE 1: Vector sizes are user-specified.
-  // 1.0 This is the trivial case, simply split the input vector sizes.
+  // CASE 1.1: Vector sizes are user-specified.
   if (!inputVectorSizes.empty()) {
     readVectorSizes.append(inputVectorSizes.begin(),
                            inputVectorSizes.begin() + sourceShape.size());
@@ -1943,83 +1953,41 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                             inputScalableVecDims.end());
   }
 
-  // CASE 2: Vector sizes have to be inferred.
-  //
-  // 1.1 Infer vector sizes for the write operation.
-  //
-  // Let:
-  //    * rank(source tensor) = 'M'
-  //    * rank(dest tensor) = 'N',
-  // and N <= M. The steps are:
-  //  1. writeVectorSizes = sourceShape.take_front(N)
-  //  2. Multiply all the locations in writeVectorSize pointed by inner_dims_pos
-  //     by the corresponding values from the `inner_tiles` attribute value.
-  //  3. If outer_dims_perms is present, permutate writeVectorSizes accordingly.
-  //
-  // Note, this will only work when all sizes are static!
+  // CASE 1.2: Vector sizes have to be inferred.
   if (writeVectorSizes.empty()) {
-    if (ShapedType::isDynamicShape(sourceShape))
+    if (ShapedType::isDynamicShape(destShape) ||
+        ShapedType::isDynamicShape(sourceShape))
       return failure();
 
-    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
-    if (!outerDimsPerm.empty())
-      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
-    for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      writeVectorSizes[pos] *= innerTiles[i];
-
+    readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
+    writeVectorSizes.assign(destShape.begin(), destShape.end());
     useInBoundsInsteadOfMasking = true;
   }
 
-  // 1.2 Infer vector sizes for the read operation.
-  //
-  // The steps are:
-  //  1. readVectorSizes = writeVectorSizes
-  //  2. Take readVectorSizes from 1. and divide all locations pointed by
-  //     the inner_dims_pos attribyte by the `inner_tiles` attribute value.
-  //  3. If outer_dims_perms is present, permutate readVectorSizes accordingly.
-  //  4. Append the remaining sizes from the source tensor.
-  //
-  // Note, this will only work when all sizes are static!
-  if (readVectorSizes.empty()) {
-    readVectorSizes = writeVectorSizes;
-    for (auto [index, size] : enumerate(innerTiles)) {
-      readVectorSizes[innerDimPos[index]] =
-          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-    }
-    if (!outerDimsPerm.empty()) {
-      applyPermutationToVector(readVectorSizes, outerDimsPerm);
-    }
-    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
-                           sourceShape.end());
-  }
-
-  Location loc = unpackOp->getLoc();
-
+  // 2. Generate the read operation.
   auto padValue = arith::ConstantOp::create(
       rewriter, loc,
       rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
-
-  // Read result, mask if necessary. If transferReadOp shape is not equal
-  // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);
 
+  // 3. Generate the transpose operation.
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
       getUnPackInverseSrcPerm(unpackOp, packMetadata);
-  // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = vector::TransposeOp::create(
       rewriter, loc, readResult, lastDimToInsertPosPerm);
 
-  // Collapse the vector to the size required by result.
+  // 4. Generate the shape_cast operation.
   VectorType collapsedVecType = getCollapsedVecType(
       transposeOp.getType(),
       getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
           rewriter.getContext(), packMetadata.reassociations)));
   vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
       rewriter, loc, collapsedVecType, transposeOp->getResult(0));
 
+  // 5. Generate the write operation.
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
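A rough sketch of the IR the numbered steps above produce for the `tensor<1x1x8x8xf32> -> tensor<8x8xf32>` example from the doc comment (inferred-sizes path; the transpose permutation and attribute spelling below are inferred from that example, not copied from a test):

```mlir
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
// 2. Read the packed source: readVectorSizes = [1, 1, 8, 8].
%read = vector.transfer_read %src[%c0, %c0, %c0, %c0], %cst
    {in_bounds = [true, true, true, true]}
    : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
// 3. Interleave outer dims with their inner tiles (inner_dims_pos = [0, 1]).
%tr = vector.transpose %read, [0, 2, 1, 3]
    : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
// 4. Collapse to the unpacked shape: writeVectorSizes = [8, 8].
%sc = vector.shape_cast %tr : vector<1x8x1x8xf32> to vector<8x8xf32>
// 5. Write the result into the destination tensor.
%res = vector.transfer_write %sc, %dest[%c0, %c0]
    {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32>
```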
@@ -2147,24 +2115,24 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
   if (!inputVectorSizes.empty()) {
     if (inputVectorSizes.size() !=
         unpackOp.getDestRank() + unpackOp.getSourceRank()) {
-      LDBG("Incorrect number of input vector sizes");
+      LDBG() << "Incorrect number of input vector sizes";
       return failure();
     }
   }
 
-  // Check the vector sizes for the write operation.
+  // Check the vector sizes for the read operation.
   if (failed(vector::isValidMaskedInputVector(
-          unpackOp.getDestType().getShape(),
-          inputVectorSizes.take_back(unpackOp.getDestRank())))) {
-    LDBG("Incorrect number of input vector sizes");
+          unpackOp.getSourceType().getShape(),
+          inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
+    LDBG() << "Invalid vector sizes for the read operation";
     return failure();
   }
 
-  // Check the vector sizes for the read operation.
+  // Check the vector sizes for the write operation.
   if (failed(vector::isValidMaskedInputVector(
-          unpackOp.getSourceType().getShape(),
-          inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
-    LDBG("Incorrect number of input vector sizes");
+          unpackOp.getDestType().getShape(),
+          inputVectorSizes.take_back(unpackOp.getDestRank())))) {
+    LDBG() << "Invalid vector sizes for the write operation";
     return failure();
   }
 
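To make the ordering of the two checks concrete: for the same `tensor<1x1x8x8xf32> -> tensor<8x8xf32>` unpack, the preconditions expect sourceRank + destRank = 6 sizes, read sizes first. A hypothetical invocation (handle name assumed):

```mlir
// First 4 entries are validated against the source shape (read),
// last 2 entries against the dest shape (write).
transform.structured.vectorize %unpack
    vector_sizes [1, 1, 8, 8, 8, 8] : !transform.any_op
```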
@@ -2554,8 +2522,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
   return success();
 }
 
-/// Preconditions for scalable vectors. This is quite restrictive - it models
-/// the fact that in practice we would only make selected dimensions scalable.
+/// Preconditions for scalable vectors.
+///
+/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
+/// models the fact that in practice we would only make selected dimensions
+/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
+/// unconditionally - we have yet to identify meaningful conditions.
 static LogicalResult
 vectorizeScalableVectorPrecondition(Operation *op,
                                     ArrayRef<int64_t> inputVectorSizes,
@@ -2574,7 +2546,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
   // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
   // exception of UnpackOp for which there is a dedicated hook.
   if (!linalgOp) {
-    return isa<linalg::UnPackOp>(op) ? success() : failure();
+    return success(isa<linalg::UnPackOp>(op));
   }
 
   // Cond 2: There's been no need for more than 2 scalable dims so far
@@ -2673,7 +2645,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
          isa<linalg::MatmulTransposeAOp>(op) ||
          isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
          isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-         isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
+         hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(