@@ -1879,19 +1879,12 @@ static VectorType getCollapsedVecType(VectorType type,
18791879 return VectorType::get (newShape, type.getElementType (), newScalableFlags);
18801880}
18811881
1882- // / Vectorize `linalg.unpack` into :
1882+ // / Vectorize `linalg.unpack` as :
18831883// / * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
18841884// /
1885- // / The input-vector-sizes specify both the read and the write vector
1886- // / sizes and are passed as one array covering both operations, i.e.:
1887- // /
1888- // / input-vector-sizes = [1, 1, 8, [8], 8, [8]]
1889- // / \ / \ /
1890- // / read-sizes write-sizes
1891- // /
1892- // / (for brefity, in the diagram,
1893- // / * input-vector-sizes = `inputVectorSizes` + `inputScalableDims`
1894- // / )
1885+ // / The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1886+ // / for the xfer_read operation). This is sufficient to infer the other vector
1887+ // / sizes required here.
18951888// /
18961889// / If the vector sizes are not provided:
18971890// / * the vector sizes are determined by the operands,
@@ -1914,8 +1907,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19141907 ArrayRef<bool > inputScalableVecDims,
19151908 SmallVectorImpl<Value> &newResults) {
19161909 if (!inputVectorSizes.empty ()) {
1917- assert (inputVectorSizes.size () ==
1918- unpackOp.getDestRank () + unpackOp.getSourceRank () &&
1910+ assert (inputVectorSizes.size () == unpackOp.getSourceRank () &&
19191911 " Invalid number of input vector sizes!" );
19201912 assert (inputVectorSizes.size () == inputScalableVecDims.size () &&
19211913 " Incompatible number of vector sizes and vector scalable flags!" );
@@ -1935,22 +1927,15 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19351927
19361928 // 1. Obtain vector sizes for the read and write operations.
19371929 SmallVector<int64_t > readVectorSizes;
1938- SmallVector<int64_t > writeVectorSizes;
19391930 SmallVector<bool > readScalableVectorFlags;
1940- SmallVector<bool > writeScalableVectorFlags;
19411931
19421932 if (!inputVectorSizes.empty ()) {
19431933 // CASE 1.1: Vector sizes are user-specified.
19441934 readVectorSizes.assign (inputVectorSizes.begin (),
19451935 inputVectorSizes.begin () + sourceShape.size ());
1946- writeVectorSizes.assign (inputVectorSizes.begin () + sourceShape.size (),
1947- inputVectorSizes.end ());
19481936 readScalableVectorFlags.assign (inputScalableVecDims.begin (),
19491937 inputScalableVecDims.begin () +
19501938 sourceShape.size ());
1951- writeScalableVectorFlags.assign (inputScalableVecDims.begin () +
1952- sourceShape.size (),
1953- inputScalableVecDims.end ());
19541939 } else {
19551940 // CASE 1.2: Vector sizes are inferred from the static input tensor
19561941 // shapes.
@@ -1959,7 +1944,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19591944 return failure ();
19601945
19611946 readVectorSizes.assign (sourceShape.begin (), sourceShape.end ());
1962- writeVectorSizes.assign (destShape.begin (), destShape.end ());
19631947 useInBoundsInsteadOfMasking = true ;
19641948 }
19651949
@@ -2109,31 +2093,21 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
21092093 unpackOp.getSourceType ().hasStaticShape ())
21102094 return success ();
21112095
2112- // The input vector sizes must be equal to:
2113- // * read-vector-rank + write-vector-rank
2096+ // The number of input vector sizes must be equal to:
2097+ // * read-vector-rank
21142098 if (!inputVectorSizes.empty () &&
2115- (inputVectorSizes.size () !=
2116- unpackOp.getDestRank () + unpackOp.getSourceRank ())) {
2099+ (inputVectorSizes.size () != unpackOp.getSourceRank ())) {
21172100 LDBG () << " Incorrect number of input vector sizes" ;
21182101 return failure ();
21192102 }
21202103
21212104 // Check the vector sizes for the read operation.
21222105 if (failed (vector::isValidMaskedInputVector (
2123- unpackOp.getSourceType ().getShape (),
2124- inputVectorSizes.take_front (unpackOp.getSourceRank ())))) {
2106+ unpackOp.getSourceType ().getShape (), inputVectorSizes))) {
21252107 LDBG () << " Invalid vector sizes for the read operation" ;
21262108 return failure ();
21272109 }
21282110
2129- // Check the vector sizes for the write operation.
2130- if (failed (vector::isValidMaskedInputVector (
2131- unpackOp.getDestType ().getShape (),
2132- inputVectorSizes.take_back (unpackOp.getDestRank ())))) {
2133- LDBG () << " Invalid vector sizes for the write operation" ;
2134- return failure ();
2135- }
2136-
21372111 return success ();
21382112}
21392113
0 commit comments