[mlir][linalg] Update vectorization of linalg.pack #163539
Open
banach-space wants to merge 4 commits into main from users/banach-space/linalg/vectorize_pack
```diff
@@ -1564,13 +1564,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
-/// Given a linalg::PackOp, return the `dest` shape before any packing
-/// permutations.
-static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
-                                              ArrayRef<int64_t> destShape) {
-  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
-}
-
 /// Determines whether a mask for xfer_write is trivially "all true"
 ///
 /// Given all the inputs required to generate a mask (mask sizes and shapes),
```
```diff
@@ -1761,99 +1754,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
   return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
-/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
-/// padding value and (3) input vector sizes into:
-///
-///   masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
-///
-/// As in the following example:
-/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
-///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
-///
-/// This pack would be vectorized to:
-///
-/// %load = vector.mask %mask {
-///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
-///         {in_bounds = [true, true, true]} :
-///         tensor<32x7x16xf32>, vector<32x8x16xf32>
-/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
-/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
-///                                         to vector<32x4x2x1x16xf32>
-/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
-///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-/// %write = vector.transfer_write %transpose,
-///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
-///     {in_bounds = [true, true, true, true, true]}
-///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
-///
-/// If the (3) input vector sizes are not provided, the vector sizes are
-/// determined by the result tensor shape and the `in_bounds`
-/// attribute is used instead of masking to mark out-of-bounds accesses.
-///
-/// NOTE: The input vector sizes specify the dimensions corresponding to the
-/// outer dimensions of the output tensor. The remaining dimensions are
-/// computed based on, e.g., the static inner tiles.
-/// Supporting dynamic inner tiles will require the user to specify the
-/// missing vector sizes. This is left as a TODO.
-static LogicalResult
-vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        SmallVectorImpl<Value> &newResults) {
-  // TODO: Introduce a parent class that will handle the insertion point update.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(packOp);
-
-  Location loc = packOp.getLoc();
-  std::optional<Value> padValue = packOp.getPaddingValue()
-                                      ? std::optional(packOp.getPaddingValue())
-                                      : std::nullopt;
-
-  // If the input vector sizes are not provided, then the vector sizes are
-  // determined by the result tensor shape. In case the vector sizes aren't
-  // provided, we update the inBounds attribute instead of masking.
-  bool useInBoundsInsteadOfMasking = false;
-  if (inputVectorSizes.empty()) {
-    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
-    useInBoundsInsteadOfMasking = true;
-  }
-
-  // Create masked TransferReadOp.
-  SmallVector<int64_t> inputShape(inputVectorSizes);
-  auto innerTiles = packOp.getStaticInnerTiles();
-  auto innerDimsPos = packOp.getInnerDimsPos();
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
-  if (!outerDimsPerm.empty())
-    applyPermutationToVector(inputShape,
-                             invertPermutationVector(outerDimsPerm));
-  for (auto [idx, size] : enumerate(innerTiles))
-    inputShape[innerDimsPos[idx]] *= size;
-  auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking,
-      /*inputScalableVecSizes=*/{});
-
-  // Create ShapeCastOp.
-  SmallVector<int64_t> destShape(inputVectorSizes);
-  destShape.append(innerTiles.begin(), innerTiles.end());
-  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
-                                       packOp.getDestType().getElementType());
-  auto shapeCastOp =
-      vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
-
-  // Create TransposeOp.
-  auto destPermutation =
-      invertPermutationVector(getPackInverseDestPerm(packOp));
-  auto transposeOp = vector::TransposeOp::create(
-      rewriter, loc, shapeCastOp.getResult(), destPermutation);
-
-  // Create TransferWriteOp.
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), packOp.getDest());
-  newResults.push_back(write->getResult(0));
-  return success();
-}
-
 /// Given the re-associations, "collapses" the input Vector type
 ///
 /// This is similar to CollapseShapeOp::inferCollapsedType with two notable
```
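Note the user-visible convention change: the removed implementation above took one vector size per *source* dimension (cf. `take_front(packOp.getSourceRank())`), while the replacement below expects one size per *destination* dimension. A minimal sketch of how a caller might request this, assuming the upstream transform-dialect `vectorize` interface and reusing the shapes from the doc-comment example:

```mlir
// A sketch, assuming the transform-dialect interface: vectorize the example
// linalg.pack with explicit *write* vector sizes, i.e. one size per
// destination dimension (rank 5), rather than the previous
// per-source-dimension convention (rank 3).
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["linalg.pack"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
    // Write vector sizes match the destination tensor<32x4x1x16x2xf32>.
    transform.structured.vectorize %pack vector_sizes [32, 4, 1, 16, 2]
        : !transform.any_op
    transform.yield
  }
}
```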
```diff
@@ -1901,12 +1801,120 @@ static VectorType getCollapsedVecType(VectorType type,
   return VectorType::get(newShape, type.getElementType(), newScalableFlags);
 }
 
+/// Vectorize `linalg.pack` as:
+///   * xfer_read -> shape_cast -> transpose -> xfer_write
+///
+/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
+/// sizes for the xfer_write operation). This is sufficient to infer the other
+/// vector sizes required here.
+///
+/// If the vector sizes are not provided:
+///   * the vector sizes are determined from the destination tensor static
+///     shape.
+///   * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+///   %pack = linalg.pack %src
+///     inner_dims_pos = [2, 1]
+///     inner_tiles = [16, 2]
+///     into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ```
+/// is vectorized as:
+/// ```
+///   %read = vector.transfer_read %src
+///     : tensor<32x7x16xf32>, vector<32x8x16xf32>
+///   %sc = vector.shape_cast %read
+///     : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+///   %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
+///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+///   %write = vector.transfer_write %tr into %dst
+///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+/// ```
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        SmallVectorImpl<Value> &newResults) {
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == packOp.getDestRank() &&
+           "Invalid number of input vector sizes!");
+  }
+
+  // TODO: Introduce a parent class that will handle the insertion point update.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  Location loc = packOp.getLoc();
+  std::optional<Value> padValue = packOp.getPaddingValue()
+                                      ? std::optional(packOp.getPaddingValue())
+                                      : std::nullopt;
+
+  SmallVector<int64_t> destShape =
+      SmallVector<int64_t>(packOp.getDestType().getShape());
+
+  // This is just a convenience alias to clearly communicate that the input
+  // vector sizes determine the _write_ sizes.
+  ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
+
+  // In the absence of input-vector-sizes, use the _static_ destination tensor
+  // shape. In addition, use the inBounds attribute instead of masking.
+  bool useInBoundsInsteadOfMasking = false;
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(destShape))
+      return rewriter.notifyMatchFailure(packOp,
+                                         "unable to infer vector sizes");
+
+    writeVectorSizes = destShape;
+    useInBoundsInsteadOfMasking = true;
+  }
+
+  // Compute the vector type for the _read_ operation. The required dims are
+  // determined based on the _write_ vector sizes. This is done in two
+  // steps:
+  //  1) Invert the permutation/transposition that's part of the Pack
+  //     operation.
+  //  2) Collapse the tiled sizes/dims to "return" to the unpacked domain.
+  PackingMetadata packMetadata;
+  auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
+
+  SmallVector<int64_t> writeVecSizesUnpermuted(writeVectorSizes);
+  applyPermutationToVector(writeVecSizesUnpermuted, destInvPermutation);
+
+  VectorType readVecType = getCollapsedVecType(
```
Inline review comment: nit: this is a little hard to follow at first glance IMO
```diff
+      VectorType::get(writeVecSizesUnpermuted,
+                      packOp.getType().getElementType()),
+      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+          rewriter.getContext(), packMetadata.reassociations)));
+
+  // Create masked TransferReadOp.
+  auto maskedRead = vector::createReadOrMaskedRead(
+      rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
+      useInBoundsInsteadOfMasking,
+      /*inputScalableVecSizes=*/{});
+
+  // Create ShapeCastOp.
+  auto expandedVecType = VectorType::get(writeVecSizesUnpermuted,
+                                         packOp.getType().getElementType());
+  auto shapeCastOp =
+      vector::ShapeCastOp::create(rewriter, loc, expandedVecType, maskedRead);
+
+  // Create TransposeOp.
+  auto destPermutation = invertPermutationVector(destInvPermutation);
+  auto transposeOp = vector::TransposeOp::create(
+      rewriter, loc, shapeCastOp.getResult(), destPermutation);
+
+  // Create TransferWriteOp.
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), packOp.getDest());
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize `linalg.unpack` as:
 ///   * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
 ///
-/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
-/// for the xfer_read operation). This is sufficient to infer the other vector
-/// sizes required here.
+/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
+/// sizes for the xfer_read operation). This is sufficient to infer the other
+/// vector sizes required here.
 ///
 /// If the vector sizes are not provided:
 ///   * the vector sizes are determined from the input tensor static shape.
```
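Stepping back to the new `vectorizeAsTensorPackOp`: here is a worked trace of its two-step read-type computation for the doc-comment example. The concrete permutation and reassociation values are derived by hand and assumed to match what `getPackInverseDestPerm` produces:

```mlir
// Worked trace for the example pack (inner_dims_pos = [2, 1],
// inner_tiles = [16, 2], no outer_dims_perm), with no user vector sizes:
//
//   writeVectorSizes        = [32, 4, 1, 16, 2]    // dest shape
//   destInvPermutation      = [0, 1, 4, 2, 3]      // step 1: undo the transpose
//   writeVecSizesUnpermuted = [32, 4, 2, 1, 16]    // each tile next to its dim
//   reassociations          = [[0], [1, 2], [3, 4]]
//   readVecType             = vector<32x8x16xf32>  // step 2: 4*2 = 8, 1*16 = 16
//
//   destPermutation = invert([0, 1, 4, 2, 3]) = [0, 1, 3, 4, 2]
//   (the map used by vector.transpose in the doc-comment example)
```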
```diff
@@ -1960,7 +1968,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // In the absence of input-vector-sizes, use the _static_ input tensor shape.
   if (inputVectorSizes.empty()) {
     if (ShapedType::isDynamicShape(sourceShape))
-      return failure();
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "Unable to infer vector sizes!");
 
     readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
     useInBoundsInsteadOfMasking = true;
```
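For illustration, a dynamically shaped `linalg.unpack` that would hit this new diagnostic when no vector sizes are supplied; the shapes here are assumptions for the sketch, not taken from the patch:

```mlir
// Hypothetical input: the source shape is dynamic, so read vector sizes
// cannot be inferred; without user-provided sizes, vectorization now reports
// "Unable to infer vector sizes!" via notifyMatchFailure instead of a bare
// failure().
%unpacked = linalg.unpack %src
    inner_dims_pos = [1] inner_tiles = [8]
    into %dst : tensor<?x4x8xf32> -> tensor<?x32xf32>
```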
```diff
@@ -2443,6 +2452,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
   Attribute cstAttr;
+  // TODO: Relax this condition
   if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG() << "pad value is not constant: " << packOp;
     return failure();
```
Inline review comment: picky nit: I'd add a blank line before this comment. It reads more easily to me; it is a new chunk of the declaration.
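For reference, the kind of case the new TODO in `vectorizePackOpPrecondition` refers to: a `linalg.pack` whose padding value is a function argument rather than a constant. A hypothetical example, not from the patch:

```mlir
// Hypothetical example: %pad is not produced by an arith.constant, so
// vectorizePackOpPrecondition currently rejects the op ("pad value is not
// constant"); the TODO above is about relaxing this restriction.
func.func @pack_with_dynamic_pad(%src: tensor<32x7x16xf32>,
                                 %dst: tensor<32x4x1x16x2xf32>,
                                 %pad: f32) -> tensor<32x4x1x16x2xf32> {
  %0 = linalg.pack %src padding_value(%pad : f32)
      inner_dims_pos = [2, 1] inner_tiles = [16, 2]
      into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
  return %0 : tensor<32x4x1x16x2xf32>
}
```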