@@ -222,9 +222,11 @@ struct VectorizationState {
   /// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeScalableVecSizesMatchDimSize = false);
 
-  /// Returns the canonical vector shape used to vectorize the iteration space.
+  /// Returns the canonical vector shape used to vectorize the iteration
+  /// space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
 
   /// Returns the vector dimensions that are scalable in the canonical vector
@@ -233,8 +235,8 @@ struct VectorizationState {
 
   /// Returns a vector type of the provided `elementType` with the canonical
   /// vector shape and the corresponding fixed/scalable dimensions bit. If
-  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
-  /// accordingly.
+  /// `dimPermutation` is provided, the canonical vector dimensions are
+  /// permuted accordingly.
   VectorType getCanonicalVecType(
       Type elementType,
       std::optional<AffineMap> dimPermutation = std::nullopt) const {
@@ -254,9 +256,9 @@ struct VectorizationState {
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
-  /// masking. Returns the masked operation or the original operation if masking
-  /// is not needed. If provided, the canonical mask for this operation is
-  /// permuted using `maybeIndexingMap`.
+  /// masking. Returns the masked operation or the original operation if
+  /// masking is not needed. If provided, the canonical mask for this
+  /// operation is permuted using `maybeIndexingMap`.
   Operation *
   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
@@ -276,15 +278,15 @@ struct VectorizationState {
 
   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
-  /// permuted using that permutation map. If a new mask is created, it will be
-  /// cached for future users.
+  /// permuted using that permutation map. If a new mask is created, it will
+  /// be cached for future users.
   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                            LinalgOp linalgOp,
                            std::optional<AffineMap> maybeMaskingMap);
 
   /// Check whether this permutation map can be used for masking. At the
-  /// moment we only make sure that there are no broadcast dimensions, but this
-  /// might change if indexing maps evolve.
+  /// moment we only make sure that there are no broadcast dimensions, but
+  /// this might change if indexing maps evolve.
   bool isValidMaskingMap(AffineMap maskingMap) {
     return maskingMap.getBroadcastDims().size() == 0;
   }
@@ -324,13 +326,24 @@ struct VectorizationState {
   /// shape.
   SmallVector<bool> scalableVecDims;
 
-  /// Holds the active masks for permutations of the canonical vector iteration
-  /// space.
+  /// Holds the active masks for permutations of the canonical vector
+  /// iteration space.
   DenseMap<AffineMap, Value> activeMaskCache;
 
   /// Global vectorization guard for the incoming rewriter. It's initialized
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
+
+  /// Do all scalable vector sizes match the corresponding input (tensor or
+  /// memref) dim sizes?
+  ///
+  /// At the Tensor + MemRef levels, scalable sizes are modelled using
+  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
+  /// e.g. "scalable packing + tiling" and are known to always match the
+  /// scalable vector sizes. In such cases, masking can be safely skipped,
+  /// despite the presence of dynamic shapes. Use this flag with care and
+  /// only for cases where you are confident the assumption holds.
+  bool assumeScalableVecSizesMatchDimSize = false;
 };
 
 LogicalResult
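The flag documented above relies on a convention worth spelling out: at the tensor/memref level a scalable tile size such as `4 * vscale` can only be represented as a dynamic dimension (`?`), while the canonical vector type records that same dimension as scalable. A minimal sketch of that pairing using the MLIR type APIs follows; the shapes are illustrative and not taken from this patch.

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);

  // A tile of size `4 * vscale` shows up as a dynamic ('?') tensor dim,
  // i.e. tensor<?x8xf32>.
  auto tiledTensor = mlir::RankedTensorType::get(
      {mlir::ShapedType::kDynamic, 8}, b.getF32Type());

  // The canonical vector shape records the same dim as scalable instead,
  // i.e. vector<[4]x8xf32>.
  auto canonicalVec = mlir::VectorType::get({4, 8}, b.getF32Type(),
                                            /*scalableDims=*/{true, false});

  // `assumeScalableVecSizesMatchDimSize` asserts that the '?' above always
  // equals the corresponding scalable vector size, making masking redundant.
  (void)tiledTensor;
  (void)canonicalVec;
  return 0;
}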
@@ -367,10 +380,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
 /// Initializes the vectorization state, including the computation of the
 /// canonical vector shape for vectorization.
 // TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeScalableSizes) {
+  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
 
@@ -470,6 +485,21 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
+  if (assumeScalableVecSizesMatchDimSize) {
+    // Given that all _scalable vector sizes_ match the corresponding
+    // memref/tensor dim sizes, masking can be skipped provided that:
+    //  * all vector sizes corresponding to dynamic dims are scalable.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? std::get<1>(it)
+                                  : false;
+                     })) {
+      LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+      activeMaskCache[maskingMap] = Value();
+      return Value();
+    }
+  }
   // Permute the iteration space value sizes to compute the mask upper bounds.
   SmallVector<Value> upperBounds =
       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
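Pulled out of its diff context, the early exit above boils down to: drop the mask only when every (permuted) iteration-space size is dynamic and covered by a scalable vector dimension; anything else keeps the mask. Below is a stand-alone sketch of that predicate in plain C++; the `kDynamic` sentinel and `canSkipMasking` helper are illustrative stand-ins, not the MLIR API.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Mirrors the lambda above: skip masking only if every iteration size is
// dynamic *and* the matching vector dimension is scalable. A static size or
// a dynamic dim paired with a fixed vector size keeps the mask.
static bool canSkipMasking(const std::vector<int64_t> &permutedStaticSizes,
                           const std::vector<bool> &scalableDims) {
  assert(permutedStaticSizes.size() == scalableDims.size());
  for (size_t i = 0; i < permutedStaticSizes.size(); ++i)
    if (permutedStaticSizes[i] != kDynamic || !scalableDims[i])
      return false;
  return true;
}

int main() {
  // tensor<?x?xf32> vectorized as vector<[8]x[4]xf32>: all dims are dynamic
  // and scalable, so the mask can be dropped.
  assert(canSkipMasking({kDynamic, kDynamic}, {true, true}));
  // tensor<?x8xf32> vectorized as vector<4x8xf32>: the dynamic dim maps to a
  // fixed vector size, so masking is still required.
  assert(!canSkipMasking({kDynamic, 8}, {false, false}));
  return 0;
}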
@@ -2479,7 +2509,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2535,11 +2566,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
              tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2559,7 +2589,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   VectorizationState state(rewriter);
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeScalableSizesMultipleOfDim))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }
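For downstream users, the flag is simply the new trailing parameter of `linalg::vectorize` shown above. A hedged call-site sketch follows; the op, vector sizes, and scalable-dim choices are assumptions for illustration (and `using namespace mlir;` is assumed), and the flag should only be set when the surrounding pipeline, e.g. scalable tiling, guarantees that every dynamic dim equals its scalable vector size.

// Inside a rewrite pattern or pass that has already applied scalable tiling
// to `linalgOp`, so each dynamic dim is known to match its scalable size.
SmallVector<int64_t> vectorSizes = {8, 4};
SmallVector<bool> scalableVecDims = {false, true}; // canonical shape 8x[4]

FailureOr<VectorizationResult> result =
    linalg::vectorize(rewriter, linalgOp, vectorSizes, scalableVecDims,
                      /*vectorizeNDExtract=*/false,
                      /*flatten1DDepthwiseConv=*/false,
                      /*assumeScalableSizesMultipleOfDim=*/true);
if (failed(result))
  return failure();
// On success, `*result` carries the vectorized values that replace the
// original op's results.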