[mlir][linalg] Add support for scalable vectorization of linalg.mmt4d #146531
Changes from 1 commit
Contributor:
Some changes on comments are not relevant, can you revert them?

Contributor (Author):
Not sure how that crept in. Fixed in this commit.
```diff
@@ -222,9 +222,11 @@ struct VectorizationState {
   /// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeScalableVecSizesMatchDimSize = false);
 
-  /// Returns the canonical vector shape used to vectorize the iteration space.
+  /// Returns the canonical vector shape used to vectorize the iteration
+  /// space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
 
   /// Returns the vector dimensions that are scalable in the canonical vector
```
```diff
@@ -233,8 +235,8 @@ struct VectorizationState {
   /// Returns a vector type of the provided `elementType` with the canonical
   /// vector shape and the corresponding fixed/scalable dimensions bit. If
-  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
-  /// accordingly.
+  /// `dimPermutation` is provided, the canonical vector dimensions are
+  /// permuted accordingly.
   VectorType getCanonicalVecType(
       Type elementType,
       std::optional<AffineMap> dimPermutation = std::nullopt) const {
```
```diff
@@ -254,9 +256,9 @@ struct VectorizationState {
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
-  /// masking. Returns the masked operation or the original operation if masking
-  /// is not needed. If provided, the canonical mask for this operation is
-  /// permuted using `maybeIndexingMap`.
+  /// masking. Returns the masked operation or the original operation if
+  /// masking is not needed. If provided, the canonical mask for this
+  /// operation is permuted using `maybeIndexingMap`.
   Operation *
   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
```
```diff
@@ -276,15 +278,15 @@ struct VectorizationState {
   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
-  /// permuted using that permutation map. If a new mask is created, it will be
-  /// cached for future users.
+  /// permuted using that permutation map. If a new mask is created, it will
+  /// be cached for future users.
   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                            LinalgOp linalgOp,
                            std::optional<AffineMap> maybeMaskingMap);
 
   /// Check whether this permutation map can be used for masking. At the
-  /// moment we only make sure that there are no broadcast dimensions, but this
-  /// might change if indexing maps evolve.
+  /// moment we only make sure that there are no broadcast dimensions, but
+  /// this might change if indexing maps evolve.
   bool isValidMaskingMap(AffineMap maskingMap) {
     return maskingMap.getBroadcastDims().size() == 0;
   }
```
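As an aside, the broadcast restriction in `isValidMaskingMap` is easy to demonstrate. The sketch below is not part of the patch; it builds an ad-hoc affine map with a constant-zero result, which this check would reject:

```cpp
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

// (d0, d1) -> (d0, 0): the constant-0 result is a broadcast dimension, so
// a mask cannot be permuted with this map.
bool isRejectedAsMaskingMap(mlir::MLIRContext *ctx) {
  mlir::AffineMap map = mlir::AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/0,
      {mlir::getAffineDimExpr(0, ctx), mlir::getAffineConstantExpr(0, ctx)},
      ctx);
  // Mirrors the check in isValidMaskingMap.
  return map.getBroadcastDims().size() != 0; // true -> invalid for masking
}
```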
```diff
@@ -324,13 +326,24 @@ struct VectorizationState {
   /// shape.
   SmallVector<bool> scalableVecDims;
 
-  /// Holds the active masks for permutations of the canonical vector iteration
-  /// space.
+  /// Holds the active masks for permutations of the canonical vector
+  /// iteration space.
   DenseMap<AffineMap, Value> activeMaskCache;
 
   /// Global vectorization guard for the incoming rewriter. It's initialized
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
+
+  /// Do all scalable vector sizes match the corresponding input dim sizes
+  /// (tensor or memref)?
+  ///
+  /// At the Tensor + MemRef levels, scalable sizes are modelled using
+  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
+  /// e.g. "scalable packing + tiling" and are known to always match the
+  /// scalable vector sizes. In such cases, masking can be safely skipped,
+  /// despite the presence of dynamic shapes. Use this flag with care and
+  /// only for cases where you are confident the assumption holds.
+  bool assumeScalableVecSizesMatchDimSize = false;
 };
 
 LogicalResult
```
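The semantics are easiest to see with the types involved. Below is a rough illustration (not from the patch; the shapes are made up) of the pairing the new comment describes, written against MLIR's C++ type API:

```cpp
#include "mlir/IR/BuiltinTypes.h"

// After "scalable packing + tiling", an inner tile dim is dynamic at the
// tensor level, but its runtime size is known to be vscale * 8.
void illustrateAssumption(mlir::MLIRContext *ctx) {
  mlir::Type f32 = mlir::Float32Type::get(ctx);
  // tensor<?x8xf32>: the `?` hides a size equal to vscale * 8.
  auto tensorTy =
      mlir::RankedTensorType::get({mlir::ShapedType::kDynamic, 8}, f32);
  // vector<[8]x8xf32>: dim 0 is scalable, i.e. also vscale * 8 elements.
  auto vecTy =
      mlir::VectorType::get({8, 8}, f32, /*scalableDims=*/{true, false});
  // With assumeScalableVecSizesMatchDimSize = true, the vectorizer trusts
  // that these two sizes coincide and skips mask generation for dim 0.
  (void)tensorTy;
  (void)vecTy;
}
```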
```diff
@@ -367,10 +380,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
 /// Initializes the vectorization state, including the computation of the
 /// canonical vector shape for vectorization.
 // TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeScalableSizes) {
+  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
```
```diff
@@ -470,6 +485,21 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
+  if (assumeScalableVecSizesMatchDimSize) {
+    // Given that all _scalable vector sizes_ match the corresponding
+    // memref/tensor dim sizes, masking can be skipped provided that:
+    // * all vector sizes corresponding to dynamic dims are scalable.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? std::get<1>(it)
+                                  : false;
+                     })) {
+      LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+      activeMaskCache[maskingMap] = Value();
+      return Value();
+    }
+  }
 
   // Permute the iteration space value sizes to compute the mask upper bounds.
   SmallVector<Value> upperBounds =
       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
```
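Note how strict the predicate is as written: masking is skipped only when every entry pairs a dynamic dim with a scalable vector size, so a single static dim keeps the conservative masked path. A minimal standalone restatement of the check (the function name and example values are ours, not the patch's):

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"

// Mirrors the llvm::all_of above: a dynamic dim passes only if the matching
// vector dim is scalable; any static dim fails the whole check.
bool allDimsDynamicAndScalable(llvm::ArrayRef<int64_t> staticSizes,
                               llvm::ArrayRef<bool> scalableDims) {
  return llvm::all_of(llvm::zip(staticSizes, scalableDims), [](auto it) {
    return std::get<0>(it) == mlir::ShapedType::kDynamic ? std::get<1>(it)
                                                         : false;
  });
}

// {?, ?}  + {scalable, scalable} -> true, mask skipped.
// {?, 16} + {scalable, fixed}    -> false, mask still built.
```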
```diff
@@ -2479,7 +2509,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
```

Contributor:
I was wondering if there is a particular reason why this wouldn't work for […]

Contributor (Author):
It should. I am not enabling it just yet to keep this PR relatively small. Adding […]

Contributor:
Alright, makes sense.
```diff
@@ -2535,11 +2566,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
                   tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
```
```diff
@@ -2559,7 +2589,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   VectorizationState state(rewriter);
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeScalableSizesMultipleOfDim))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }
```
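For completeness, here is a hypothetical call site for the extended entry point. Everything in this sketch (the helper, the op handle, the sizes and scalable flags) is illustrative only and not part of the patch:

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

// Sketch: vectorize a linalg.mmt4d with one scalable dim. The sizes are
// made up; there is one entry per dim of the op's iteration space, paired
// with the corresponding entry in the scalable-dims array.
mlir::LogicalResult vectorizeMmt4dScalable(mlir::RewriterBase &rewriter,
                                           mlir::Operation *mmt4dOp) {
  mlir::FailureOr<mlir::linalg::VectorizationResult> result =
      mlir::linalg::vectorize(
          rewriter, mmt4dOp,
          /*inputVectorSizes=*/{2, 2, 2, 8, 8, 1},
          /*inputScalableVecDims=*/{false, false, false, false, true, false},
          /*vectorizeNDExtract=*/false,
          /*flatten1DDepthwiseConv=*/false,
          /*assumeScalableSizesMultipleOfDim=*/true);
  return mlir::failure(mlir::failed(result));
}
```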
Contributor:
We have `assumeScalableSizesMultipleOfDim` and `getAssumeScalableSizesMatchDimSize`. Should we have only one? Also, shouldn't it be "dim multiple of vector sizes"?

Contributor (Author):
Thanks for the suggestion, I went with `assumeDynamicDimsMatchVecSizes`, see this commit. As mentioned in my other comment, for now I am "assuming" equality rather than divisibility. Not sure whether we will need the latter?