@@ -2099,24 +2099,45 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
20992099 return success ();
21002100}
21012101
2102- // / Need to check if the inner-tiles are static/constant.
2102+ // // This hook considers two cases:
2103+ // / (1) If the input-vector-sizes are empty, then the vector sizes will be
2104+ // / infered. This is only possible when all shapes are static.
2105+ // / (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2106+ // / carry out basic sanity-checking.
21032107static LogicalResult
21042108vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
21052109 ArrayRef<int64_t > inputVectorSizes) {
2110+ // If there are no input vector sizes and all shapes are static, there is
2111+ // nothing left to check.
2112+ if (inputVectorSizes.empty () && unpackOp.getDestType ().hasStaticShape () &&
2113+ unpackOp.getSourceType ().hasStaticShape ())
2114+ return success ();
21062115
2107- if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
2108- return !getConstantIntValue (res).has_value ();
2109- })) {
2110- LDBG () << " Inner-tiles must be constant: " << unpackOp;
2116+ // The input vector sizes must be equal to:
2117+ // * read-vector-rank + write-vector-rank
2118+ if (!inputVectorSizes.empty ()) {
2119+ if (inputVectorSizes.size () !=
2120+ unpackOp.getDestRank () + unpackOp.getSourceRank ()) {
2121+ LDBG (" Incorrect number of input vector sizes" );
2122+ return failure ();
2123+ }
2124+ }
2125+
2126+ // Check the vector sizes for the write operation.
2127+ if (failed (vector::isValidMaskedInputVector (
2128+ unpackOp.getDestType ().getShape (),
2129+ inputVectorSizes.take_back (unpackOp.getDestRank ())))) {
2130+ LDBG (" Incorrect number of input vector sizes" );
21112131 return failure ();
21122132 }
2113- ArrayRef< int64_t > resultShape = unpackOp. getDestType (). getShape ();
2114- bool satisfyEmptyCond = inputVectorSizes. empty () &&
2115- unpackOp. getDestType (). hasStaticShape () &&
2116- unpackOp.getSourceType ().hasStaticShape ();
2117- if (!satisfyEmptyCond &&
2118- failed ( vector::isValidMaskedInputVector (resultShape, inputVectorSizes)))
2133+
2134+ // Check the vector sizes for the read operation.
2135+ if ( failed ( vector::isValidMaskedInputVector (
2136+ unpackOp.getSourceType ().getShape (),
2137+ inputVectorSizes. take_front (unpackOp. getSourceRank ())))) {
2138+ LDBG ( " Incorrect number of input vector sizes " );
21192139 return failure ();
2140+ }
21202141
21212142 return success ();
21222143}
0 commit comments