@@ -250,6 +250,32 @@ struct VectorizationState {
250250 LinalgOp linalgOp,
251251 std::optional<AffineMap> maybeMaskingMap);
252252
253+ // / Check whether this permutation map can be used for masking. At the
254+ // / moment we only make sure that there are no broadcast dimensions, but this
255+ // / might change if indexing maps evolve.
256+ bool isValidMaskingMap (AffineMap maskingMap) {
257+ return maskingMap.getBroadcastDims ().size () == 0 ;
258+ }
259+
260+ // / Turn the input indexing map into a valid masking map.
261+ // /
262+ // / The input indexing map may contain "zero" results, e.g.:
263+ // / (d0, d1, d2, d3) -> (d2, d1, d0, 0)
264+ // / Applying such maps to canonical vector shapes like this one:
265+ // / (1, 16, 16, 4)
266+ // / would yield an invalid vector shape like this:
267+ // / (16, 16, 1, 0)
268+ // / Instead, drop the broadcasting dims that make no sense for masking perm.
269+ // / maps:
270+ // / (d0, d1, d2, d3) -> (d2, d1, d0)
271+ // / This way, the corresponding vector/mask type will be:
272+ // / vector<16x16x1xty>
273+ // / rather than this invalid Vector type:
274+ // / vector<16x16x1x0xty>
275+ AffineMap getMaskingMapFromIndexingMap (AffineMap &indexingMap) {
276+ return indexingMap.dropZeroResults ();
277+ }
278+
253279 // Holds the compile-time static sizes of the iteration space to vectorize.
254280 // Dynamic dimensions are represented using ShapedType::kDynamic.
255281 SmallVector<int64_t > iterSpaceStaticSizes;
@@ -360,6 +386,10 @@ VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
360386Value VectorizationState::getOrCreateMaskFor (
361387 RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
362388 std::optional<AffineMap> maybeMaskingMap) {
389+
390+ assert ((!maybeMaskingMap || isValidMaskingMap (*maybeMaskingMap)) &&
391+ " Ill-formed masking map." );
392+
363393 // No mask is needed if the operation is not maskable.
364394 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
365395 if (!maskableOp)
@@ -429,20 +459,8 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
429459 LDBG (" Trying to mask: " << *opToMask << " \n " );
430460
431461 std::optional<AffineMap> maybeMaskingMap = std::nullopt ;
432- // The Operand indexing map may contain "zero" results, e.g.:
433- // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
434- // When applied to canonical vector shapes like these:
435- // (1, 16, 16, 4)
436- // we would get:
437- // (1, 16, 16, 0)
438- // Instead, we should extract the following map permutation map for masking:
439- // (d0, d1, d2, d3) -> (d0, d1, d2)
440- // This way, the corresponding vector/mask type will be:
441- // vector<1x16x16xty>
442- // rather than:
443- // vector<1x16x16x0xty>
444462 if (maybeIndexingMap)
445- maybeMaskingMap = maybeIndexingMap-> dropZeroResults ( );
463+ maybeMaskingMap = getMaskingMapFromIndexingMap (*maybeIndexingMap );
446464
447465 // Create or retrieve mask for this operation.
448466 Value mask =
0 commit comments