diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 09c6b2683b438..8925fec034cb9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -250,6 +250,32 @@ struct VectorizationState { LinalgOp linalgOp, std::optional maybeMaskingMap); + /// Check whether this permutation map can be used for masking. At the + /// moment we only make sure that there are no broadcast dimensions, but this + /// might change if indexing maps evolve. + bool isValidMaskingMap(AffineMap maskingMap) { + return maskingMap.getBroadcastDims().size() == 0; + } + + /// Turn the input indexing map into a valid masking map. + /// + /// The input indexing map may contain "zero" results, e.g.: + /// (d0, d1, d2, d3) -> (d2, d1, d0, 0) + /// Applying such maps to canonical vector shapes like this one: + /// (1, 16, 16, 4) + /// would yield an invalid vector shape like this: + /// (16, 16, 1, 0) + /// Instead, drop the broadcasting dims that make no sense for masking perm. + /// maps: + /// (d0, d1, d2, d3) -> (d2, d1, d0) + /// This way, the corresponding vector/mask type will be: + /// vector<16x16x1xty> + /// rather than this invalid Vector type: + /// vector<16x16x1x0xty> + AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) { + return indexingMap.dropZeroResults(); + } + // Holds the compile-time static sizes of the iteration space to vectorize. // Dynamic dimensions are represented using ShapedType::kDynamic. SmallVector iterSpaceStaticSizes; @@ -360,6 +386,10 @@ VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp, Value VectorizationState::getOrCreateMaskFor( RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, std::optional maybeMaskingMap) { + + assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) && + "Ill-formed masking map."); + // No mask is needed if the operation is not maskable. auto maskableOp = dyn_cast(opToMask); if (!maskableOp) @@ -429,20 +459,8 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, LDBG("Trying to mask: " << *opToMask << "\n"); std::optional maybeMaskingMap = std::nullopt; - // The Operand indexing map may contain "zero" results, e.g.: - // (d0, d1, d2, d3) -> (d0, d1, d2, 0) - // When applied to canonical vector shapes like these: - // (1, 16, 16, 4) - // we would get: - // (1, 16, 16, 0) - // Instead, we should extract the following map permutation map for masking: - // (d0, d1, d2, d3) -> (d0, d1, d2) - // This way, the corresponding vector/mask type will be: - // vector<1x16x16xty> - // rather than: - // vector<1x16x16x0xty> if (maybeIndexingMap) - maybeMaskingMap = maybeIndexingMap->dropZeroResults(); + maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap); // Create or retrieve mask for this operation. Value mask =