@@ -223,7 +223,7 @@ struct VectorizationState {
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           ArrayRef<bool> inputScalableVecDims,
-                          bool assumeScalableVecSizesMatchDimSize = false);
+                          bool assumeDynamicDimsMatchVecSizes = false);
 
   /// Returns the canonical vector shape used to vectorize the iteration space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
@@ -333,16 +333,13 @@ struct VectorizationState {
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
 
-  /// Do all scalable vector sizes match the corresponding input dim sizes?
-  /// (tensor or memref)
+  /// Do all dynamic dims match the corresponding vector sizes?
   ///
-  /// At the Tensor + MemRef levels, scalable sizes are modelled using
-  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
-  /// e.g. "scalable packing + tiling" and are known to always match the
-  /// scalable vector sizes. In such cases, masking can be safely skipped,
-  /// despite the presence of dynamic shapes. Use this flag with care and
-  /// only for cases where you are confident the assumption holds.
-  bool assumeScalableVecSizesMatchDimSize = false;
+  /// When a dynamic tensor/memref dimension matches the corresponding vector
+  /// dimension, masking can be safely skipped, despite the presence of dynamic
+  /// shapes. Use this flag with care and only for cases where you are
+  /// confident the assumption holds.
+  bool assumeDynamicDimsMatchVecSizes = false;
 };
 
 LogicalResult
@@ -383,8 +380,8 @@ LogicalResult VectorizationState::initState(RewriterBase &rewriter,
                                             LinalgOp linalgOp,
                                             ArrayRef<int64_t> inputVectorSizes,
                                             ArrayRef<bool> inputScalableVecDims,
-                                            bool assumeScalableSizes) {
-  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
+                                            bool assumeDimsMatchVec) {
+  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
 
@@ -484,7 +481,7 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
-  if (assumeScalableVecSizesMatchDimSize) {
+  if (assumeDynamicDimsMatchVecSizes) {
     // Given that all _scalable vector sizes_ match the corresponding
     // memref/tensor dim sizes, masking can be skipped provided that:
     // * all vector sizes corresponding to dynamic dims are scalable.
@@ -2568,7 +2565,7 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
 FailureOr<VectorizationResult> mlir::linalg::vectorize(
     RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
-    bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2589,7 +2586,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                                inputScalableVecDims,
-                               assumeScalableSizesMultipleOfDim))) {
+                               assumeDynamicDimsMatchVecSizes))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }
0 commit comments