@@ -2135,24 +2135,45 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
21352135 return success ();
21362136}
21372137
2138- // / Need to check if the inner-tiles are static/constant.
2138+ // // This hook considers two cases:
2139+ // / (1) If the input-vector-sizes are empty, then the vector sizes will be
2140+ // / infered. This is only possible when all shapes are static.
2141+ // / (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2142+ // / carry out basic sanity-checking.
21392143static LogicalResult
21402144vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
21412145 ArrayRef<int64_t > inputVectorSizes) {
2146+ // If there are no input vector sizes and all shapes are static, there is
2147+ // nothing left to check.
2148+ if (inputVectorSizes.empty () && unpackOp.getDestType ().hasStaticShape () &&
2149+ unpackOp.getSourceType ().hasStaticShape ())
2150+ return success ();
21422151
2143- if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
2144- return !getConstantIntValue (res).has_value ();
2145- })) {
2146- LDBG () << " Inner-tiles must be constant: " << unpackOp;
2152+ // The input vector sizes must be equal to:
2153+ // * read-vector-rank + write-vector-rank
2154+ if (!inputVectorSizes.empty ()) {
2155+ if (inputVectorSizes.size () !=
2156+ unpackOp.getDestRank () + unpackOp.getSourceRank ()) {
2157+ LDBG (" Incorrect number of input vector sizes" );
2158+ return failure ();
2159+ }
2160+ }
2161+
2162+ // Check the vector sizes for the write operation.
2163+ if (failed (vector::isValidMaskedInputVector (
2164+ unpackOp.getDestType ().getShape (),
2165+ inputVectorSizes.take_back (unpackOp.getDestRank ())))) {
2166+ LDBG (" Incorrect number of input vector sizes" );
21472167 return failure ();
21482168 }
2149- ArrayRef< int64_t > resultShape = unpackOp. getDestType (). getShape ();
2150- bool satisfyEmptyCond = inputVectorSizes. empty () &&
2151- unpackOp. getDestType (). hasStaticShape () &&
2152- unpackOp.getSourceType ().hasStaticShape ();
2153- if (!satisfyEmptyCond &&
2154- failed ( vector::isValidMaskedInputVector (resultShape, inputVectorSizes)))
2169+
2170+ // Check the vector sizes for the read operation.
2171+ if ( failed ( vector::isValidMaskedInputVector (
2172+ unpackOp.getSourceType ().getShape (),
2173+ inputVectorSizes. take_front (unpackOp. getSourceRank ())))) {
2174+ LDBG ( " Incorrect number of input vector sizes " );
21552175 return failure ();
2176+ }
21562177
21572178 return success ();
21582179}
0 commit comments