
Commit 2b6019c

fixup! [mlir][linalg] Add support for scalable vectorization of linalg.mmt4d

Rename the bool to assumeDynamicDimsMatchVecSizes

Parent: 8ef6661 · Commit: 2b6019c

4 files changed (+15, -18 lines)


mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 1 addition & 1 deletion
@@ -2443,7 +2443,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
       Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-      OptionalAttr<UnitAttr>:$assume_scalable_sizes_match_dim_size,
+      OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 1 deletion
@@ -872,7 +872,7 @@ vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
           bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
-          bool assumeScalableSizesMultipleOfDim = false);
+          bool assumeDynamicDimsMatchVecSizes = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
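
A minimal, hypothetical call-site sketch for the renamed parameter, assuming a `RewriterBase` named `rewriter` and a candidate `op` (e.g. a tiled `linalg.mmt4d`); the vector sizes and scalable flags are illustrative and not part of this commit:

    // Illustrative only: the sizes and scalable flags below are assumptions.
    SmallVector<int64_t> vectorSizes = {8, 8, 1};
    SmallVector<bool> scalableVecDims = {false, true, false};
    // Opt in to skipping mask generation: the caller asserts that every
    // dynamic dim is known to equal the corresponding vector size.
    FailureOr<linalg::VectorizationResult> result = linalg::vectorize(
        rewriter, op, vectorSizes, scalableVecDims,
        /*vectorizeNDExtract=*/false, /*flatten1DDepthwiseConv=*/false,
        /*assumeDynamicDimsMatchVecSizes=*/true);
    if (failed(result))
      return failure();

Passing `true` here mirrors what the transform op does when the `assume_dynamic_dims_match_vec_sizes` unit attribute is set, as shown in the LinalgTransformOps.cpp hunk below.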

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 1 deletion
@@ -3922,7 +3922,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
   FailureOr<VectorizationResult> vectorResults =
       linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
                         getVectorizeNdExtract().value_or(false), false,
-                        getAssumeScalableSizesMatchDimSize().value_or(false));
+                        getAssumeDynamicDimsMatchVecSizes().value_or(false));
   if (failed(vectorResults)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "Attempted to vectorize, but failed";

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 12 additions & 15 deletions
@@ -223,7 +223,7 @@ struct VectorizationState {
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           ArrayRef<bool> inputScalableVecDims,
-                          bool assumeScalableVecSizesMatchDimSize = false);
+                          bool assumeDynamicDimsMatchVecSizes = false);
 
   /// Returns the canonical vector shape used to vectorize the iteration space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
@@ -333,16 +333,13 @@ struct VectorizationState {
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
 
-  /// Do all scalable vector sizes match the corresponding input dim sizes?
-  /// (tensor or memref)
+  /// Do all dynamic dims match the corresponding vector sizes?
   ///
-  /// At the Tensor + MemRef levels, scalable sizes are modelled using
-  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
-  /// e.g. "scalable packing + tiling" and are known to always match the
-  /// scalable vector sizes. In such cases, masking can be safely skipped,
-  /// despite the presence of dynamic shapes. Use this flag with care and
-  /// only for cases where you are confident the assumption holds.
-  bool assumeScalableVecSizesMatchDimSize = false;
+  /// When a dynamic tensor/memref dimension matches the corresponding vector
+  /// dimension, masking can be safely skipped, despite the presence of dynamic
+  /// shapes. Use this flag with care and only for cases where you are
+  /// confident the assumption holds.
+  bool assumeDynamicDimsMatchVecSizes = false;
 };
 
 LogicalResult
@@ -383,8 +380,8 @@ LogicalResult VectorizationState::initState(RewriterBase &rewriter,
                                             LinalgOp linalgOp,
                                             ArrayRef<int64_t> inputVectorSizes,
                                             ArrayRef<bool> inputScalableVecDims,
-                                            bool assumeScalableSizes) {
-  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
+                                            bool assumeDimsMatchVec) {
+  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
 
@@ -484,7 +481,7 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
-  if (assumeScalableVecSizesMatchDimSize) {
+  if (assumeDynamicDimsMatchVecSizes) {
     // Given that all _scalable vector sizes_ match the corresponding
     // memref/tensor dim sizes, masking can be skipped provided that:
     // * all vector sizes corresponding to dynamic dims are scalable.
@@ -2568,7 +2565,7 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
 FailureOr<VectorizationResult> mlir::linalg::vectorize(
     RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
-    bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2589,7 +2586,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                                inputScalableVecDims,
-                               assumeScalableSizesMultipleOfDim))) {
+                               assumeDynamicDimsMatchVecSizes))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }
