Skip to content

Commit 1ec6935

Browse files
committed
fixup! fixup! [mlir][linalg] Add support for scalable vectorization of linalg.mmt4d
Fix the condition that checks whether masks are needed, fix test
1 parent 2b6019c commit 1ec6935

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -482,18 +482,19 @@ Value VectorizationState::getOrCreateMaskFor(
482482
}
483483

484484
if (assumeDynamicDimsMatchVecSizes) {
485-
// Given that all _scalable vector sizes_ match the corresponding
486-
// memref/tensor dim sizes, masking can be skipped provided that:
487-
// * all vector sizes corresponding to dynamic dims are scalable.
488-
if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
485+
// While we can _assume_ that for dynamic dim sizes the corresponding
486+
// vector sizes match, we still need to check the static dim sizes to be
487+
// 100% sure that masking is indeed not required.
488+
if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
489489
[](auto it) {
490490
return std::get<0>(it) == ShapedType::kDynamic
491-
? std::get<1>(it)
492-
: false;
493-
}))
491+
? true
492+
: std::get<0>(it) == std::get<1>(it);
493+
})) {
494494
LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
495-
activeMaskCache[maskingMap] = Value();
496-
return Value();
495+
activeMaskCache[maskingMap] = Value();
496+
return Value();
497+
}
497498
}
498499

499500
// Permute the iteration space value sizes to compute the mask upper bounds.

mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -928,7 +928,7 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1
928928
module attributes {transform.with_named_sequence} {
929929
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
930930
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
931-
transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_scalable_sizes_match_dim_size} : !transform.any_op
931+
transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
932932
transform.yield
933933
}
934934
}

0 commit comments

Comments
 (0)