@@ -1887,7 +1887,7 @@ static VectorType getCollapsedVecType(VectorType type,
18871887// / sizes required here.
18881888// /
18891889// / If the vector sizes are not provided:
1890- // / * the vector sizes are determined by the operands,
1890+ // / * the vector sizes are determined from the input tensor static shape.
18911891// / * the inBounds attribute is used instead of masking.
18921892// /
18931893// / EXAMPLE (no vector sizes):
@@ -1899,7 +1899,14 @@ static VectorType getCollapsedVecType(VectorType type,
18991899// / ```
19001900// / is vectorized as:
19011901// / ```
1902- // / vector.transfer_write %sc into %dest : vector<8x8xf32>, tensor<8x8xf32>
1902+ // / %read = vector.transfer_read %src
1903+ // / : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1904+ // / %tr = vector.transpose %read, [0, 2, 1, 3]
1905+ // / : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1906+ // / %sc = vector.shape_cast %tr
1907+ // / : vector<1x8x1x8xf32> to vector<8x8xf32>
1908+ // / %vector = vector.transfer_write %sc into %dest
1909+ // / : vector<8x8xf32>, tensor<8x8xf32>
19031910// / ```
19041911static LogicalResult
19051912vectorizeAsTensorUnpackOp (RewriterBase &rewriter, linalg::UnPackOp unpackOp,
@@ -1920,60 +1927,51 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19201927 RankedTensorType unpackTensorType = unpackOp.getSourceType ();
19211928
19221929 ArrayRef<int64_t > sourceShape = unpackTensorType.getShape ();
1923- ArrayRef<int64_t > destShape = unpackOp.getDestType ().getShape ();
19241930 bool useInBoundsInsteadOfMasking = false ;
19251931
19261932 Location loc = unpackOp->getLoc ();
19271933
1928- // 1. Obtain vector sizes for the read and write operations .
1929- SmallVector<int64_t > readVectorSizes;
1930- SmallVector<bool > readScalableVectorFlags;
1934+ // Obtain vector sizes for the read operation .
1935+ SmallVector<int64_t > readVectorSizes (inputVectorSizes) ;
1936+ SmallVector<bool > readScalableVectorFlags (inputScalableVecDims) ;
19311937
1932- if (!inputVectorSizes.empty ()) {
1933- // CASE 1.1: Vector sizes are user-specified.
1934- readVectorSizes.assign (inputVectorSizes.begin (),
1935- inputVectorSizes.begin () + sourceShape.size ());
1936- readScalableVectorFlags.assign (inputScalableVecDims.begin (),
1937- inputScalableVecDims.begin () +
1938- sourceShape.size ());
1939- } else {
1940- // CASE 1.2: Vector sizes are inferred from the static input tensor
1941- // shapes.
1942- if (ShapedType::isDynamicShape (destShape) ||
1943- ShapedType::isDynamicShape (sourceShape))
1938+ // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1939+ if (inputVectorSizes.empty ()) {
1940+ if (ShapedType::isDynamicShape (sourceShape))
19441941 return failure ();
19451942
19461943 readVectorSizes.assign (sourceShape.begin (), sourceShape.end ());
19471944 useInBoundsInsteadOfMasking = true ;
19481945 }
19491946
1950- // 2. Generate the read operation.
1947+ // -- Generate the read operation --
19511948 auto padValue = arith::ConstantOp::create (
19521949 rewriter, loc,
19531950 rewriter.getZeroAttr (unpackOp.getSourceType ().getElementType ()));
19541951 Value readResult = vector::createReadOrMaskedRead (
19551952 rewriter, loc, unpackOp.getSource (), readVectorSizes, padValue,
1956- /* useInBoundsInsteadOfMasking= */ false , readScalableVectorFlags);
1953+ useInBoundsInsteadOfMasking, readScalableVectorFlags);
19571954
1958- // 3. Generate the transpose operation.
1955+ // -- Generate the transpose operation --
19591956 PackingMetadata packMetadata;
19601957 SmallVector<int64_t > lastDimToInsertPosPerm =
19611958 getUnPackInverseSrcPerm (unpackOp, packMetadata);
19621959 vector::TransposeOp transposeOp = vector::TransposeOp::create (
19631960 rewriter, loc, readResult, lastDimToInsertPosPerm);
19641961
1965- // 3. Generate the shape_cast operation.
1962+ // -- Generate the shape_cast operation --
19661963 VectorType collapsedVecType = getCollapsedVecType (
19671964 transposeOp.getType (),
19681965 getSymbolLessAffineMaps (convertReassociationIndicesToExprs (
19691966 rewriter.getContext (), packMetadata.reassociations )));
19701967 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create (
19711968 rewriter, loc, collapsedVecType, transposeOp->getResult (0 ));
19721969
1973- // 4. Generate the write operation.
1970+ // -- Generate the write operation --
19741971 Operation *write = createWriteOrMaskedWrite (
19751972 rewriter, loc, shapeCastOp.getResult (), unpackOp.getDest (),
19761973 /* writeIndices=*/ {}, useInBoundsInsteadOfMasking);
1974+
19771975 newResults.push_back (write->getResult (0 ));
19781976 return success ();
19791977}
0 commit comments