-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][Vector] Add patterns to lower vector.shuffle
#157611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| //===- LowerVectorShuffle.cpp - Lower 'vector.shuffle' operation ----------===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // This file implements the lowering of complex `vector.shuffle` operation to a | ||
| // set of simpler operations supported by LLVM/SPIR-V. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Dialect/Arith/IR/Arith.h" | ||
| #include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
| #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" | ||
| #include "mlir/IR/PatternMatch.h" | ||
|
|
||
| #define DEBUG_TYPE "vector-shuffle-lowering" | ||
|
|
||
| using namespace mlir; | ||
| using namespace mlir::vector; | ||
|
|
||
| namespace { | ||
|
|
||
| /// Lowers a `vector.shuffle` operation with mixed-size inputs to a new | ||
| /// `vector.shuffle` which promotes the smaller input to the larger vector size | ||
| /// and an updated version of the original `vector.shuffle`. | ||
| /// | ||
| /// Example: | ||
| /// | ||
| /// %0 = vector.shuffle %v1, %v2 [0, 2, 1, 3] : vector<2xf32>, vector<4xf32> | ||
| /// | ||
| /// is lowered to: | ||
| /// | ||
| /// %0 = vector.shuffle %v1, %v1 [0, 1, -1, -1] : | ||
| /// vector<2xf32>, vector<2xf32> | ||
| /// %1 = vector.shuffle %0, %v2 [0, 4, 1, 5] : | ||
| /// vector<4xf32>, vector<4xf32> | ||
| /// | ||
| /// Note: This transformation helps legalize vector.shuffle ops when lowering | ||
| /// to SPIR-V/LLVM, which don't support shuffle operations with mixed-size | ||
| /// inputs. | ||
| /// | ||
| struct MixedSizeInputShuffleOpRewrite final | ||
| : OpRewritePattern<vector::ShuffleOp> { | ||
| using OpRewritePattern::OpRewritePattern; | ||
|
|
||
| LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, | ||
| PatternRewriter &rewriter) const override { | ||
| auto v1Type = shuffleOp.getV1VectorType(); | ||
| auto v2Type = shuffleOp.getV2VectorType(); | ||
|
|
||
| // Only support 1-D shuffle for now. | ||
| if (v1Type.getRank() != 1 || v2Type.getRank() != 1) | ||
| return failure(); | ||
|
|
||
| // Bail out if inputs don't have mixed sizes. | ||
| int64_t v1OrigNumElems = v1Type.getNumElements(); | ||
| int64_t v2OrigNumElems = v2Type.getNumElements(); | ||
| if (v1OrigNumElems == v2OrigNumElems) | ||
| return failure(); | ||
|
|
||
| // Determine which input needs promotion. | ||
| bool promoteV1 = v1OrigNumElems < v2OrigNumElems; | ||
| Value inputToPromote = promoteV1 ? shuffleOp.getV1() : shuffleOp.getV2(); | ||
| VectorType promotedType = promoteV1 ? v2Type : v1Type; | ||
| int64_t origNumElems = promoteV1 ? v1OrigNumElems : v2OrigNumElems; | ||
| int64_t promotedNumElems = promoteV1 ? v2OrigNumElems : v1OrigNumElems; | ||
|
|
||
| // Create a shuffle with a mask that preserves existing elements and fills | ||
| // up with poison. | ||
| SmallVector<int64_t> promoteMask(promotedNumElems, ShuffleOp::kPoisonIndex); | ||
| for (int64_t i = 0; i < origNumElems; ++i) | ||
| promoteMask[i] = i; | ||
|
|
||
| Value promotedInput = rewriter.create<vector::ShuffleOp>( | ||
| shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote, | ||
| promoteMask); | ||
|
|
||
| // Create the final shuffle with the promoted inputs. | ||
| Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1(); | ||
| Value promotedV2 = promoteV1 ? shuffleOp.getV2() : promotedInput; | ||
|
|
||
| SmallVector<int64_t> newMask; | ||
| if (!promoteV1) { | ||
| newMask = to_vector(shuffleOp.getMask()); | ||
| } else { | ||
| // Adjust V2 indices to account for the new V1 size. | ||
| for (auto idx : shuffleOp.getMask()) { | ||
| int64_t newIdx = idx; | ||
| if (idx >= v1OrigNumElems) { | ||
| newIdx += promotedNumElems - v1OrigNumElems; | ||
| } | ||
| newMask.push_back(newIdx); | ||
| } | ||
| } | ||
|
|
||
| rewriter.replaceOpWithNewOp<vector::ShuffleOp>( | ||
| shuffleOp, shuffleOp.getResultVectorType(), promotedV1, promotedV2, | ||
| newMask); | ||
| return success(); | ||
| } | ||
| }; | ||
| } // namespace | ||
|
|
||
| void mlir::vector::populateVectorShuffleLoweringPatterns( | ||
| RewritePatternSet &patterns, PatternBenefit benefit) { | ||
| patterns.add<MixedSizeInputShuffleOpRewrite>(patterns.getContext(), benefit); | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| // RUN: mlir-opt %s --test-vector-shuffle-lowering --split-input-file | FileCheck %s | ||
|
|
||
| // CHECK-LABEL: func.func @shuffle_smaller_lhs_arbitrary | ||
| // CHECK-SAME: %[[LHS:.*]]: vector<2xf32>, %[[RHS:.*]]: vector<4xf32> | ||
| func.func @shuffle_smaller_lhs_arbitrary(%lhs: vector<2xf32>, %rhs: vector<4xf32>) -> vector<5xf32> { | ||
| // CHECK: %[[PROMOTE_LHS:.*]] = vector.shuffle %[[LHS]], %[[LHS]] [0, 1, -1, -1] : vector<2xf32>, vector<2xf32> | ||
| // CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_LHS]], %[[RHS]] [1, 5, 0, 6, 7] : vector<4xf32>, vector<4xf32> | ||
| // CHECK: return %[[RESULT]] : vector<5xf32> | ||
| %0 = vector.shuffle %lhs, %rhs [1, 3, 0, 4, 5] : vector<2xf32>, vector<4xf32> | ||
| return %0 : vector<5xf32> | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @shuffle_smaller_rhs_arbitrary | ||
| // CHECK-SAME: %[[LHS:.*]]: vector<4xi32>, %[[RHS:.*]]: vector<2xi32> | ||
| func.func @shuffle_smaller_rhs_arbitrary(%lhs: vector<4xi32>, %rhs: vector<2xi32>) -> vector<6xi32> { | ||
| // CHECK: %[[PROMOTE_RHS:.*]] = vector.shuffle %[[RHS]], %[[RHS]] [0, 1, -1, -1] : vector<2xi32>, vector<2xi32> | ||
| // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[PROMOTE_RHS]] [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<4xi32> | ||
| // CHECK: return %[[RESULT]] : vector<6xi32> | ||
| %0 = vector.shuffle %lhs, %rhs [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<2xi32> | ||
| return %0 : vector<6xi32> | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @shuffle_smaller_lhs_concat | ||
| // CHECK-SAME: %[[LHS:.*]]: vector<3xf64>, %[[RHS:.*]]: vector<5xf64> | ||
| func.func @shuffle_smaller_lhs_concat(%lhs: vector<3xf64>, %rhs: vector<5xf64>) -> vector<8xf64> { | ||
| // CHECK: %[[PROMOTE_LHS:.*]] = vector.shuffle %[[LHS]], %[[LHS]] [0, 1, 2, -1, -1] : vector<3xf64>, vector<3xf64> | ||
| // CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_LHS]], %[[RHS]] [0, 1, 2, 5, 6, 7, 8, 9] : vector<5xf64>, vector<5xf64> | ||
| // CHECK: return %[[RESULT]] : vector<8xf64> | ||
| %0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3, 4, 5, 6, 7] : vector<3xf64>, vector<5xf64> | ||
| return %0 : vector<8xf64> | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @shuffle_smaller_rhs_concat | ||
| // CHECK-SAME: %[[LHS:.*]]: vector<4xi16>, %[[RHS:.*]]: vector<2xi16> | ||
| func.func @shuffle_smaller_rhs_concat(%lhs: vector<4xi16>, %rhs: vector<2xi16>) -> vector<6xi16> { | ||
| // CHECK: %[[PROMOTE_RHS:.*]] = vector.shuffle %[[RHS]], %[[RHS]] [0, 1, -1, -1] : vector<2xi16>, vector<2xi16> | ||
| // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[PROMOTE_RHS]] [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<4xi16> | ||
| // CHECK: return %[[RESULT]] : vector<6xi16> | ||
| %0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<2xi16> | ||
| return %0 : vector<6xi16> | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // Test that shuffles with same size inputs are not modified. | ||
|
|
||
| // CHECK-LABEL: func.func @negative_shuffle_same_input_sizes | ||
| // CHECK-SAME: %[[LHS:.*]]: vector<4xf32>, %[[RHS:.*]]: vector<4xf32> | ||
| func.func @negative_shuffle_same_input_sizes(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<6xf32> { | ||
| // CHECK-NOT: vector.shuffle %[[LHS]], %[[LHS]] | ||
| // CHECK-NOT: vector.shuffle %[[RHS]], %[[RHS]] | ||
| // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32> | ||
| // CHECK: return %[[RESULT]] : vector<6xf32> | ||
| %0 = vector.shuffle %lhs, %rhs [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32> | ||
| return %0 : vector<6xf32> | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // Test that multi-dimensional shuffles are not modified. | ||
|
|
||
| // CHECK-LABEL: func.func @negative_shuffle_2d_vectors | ||
| // CHECK-SAME: %[[LHS:.*]]: vector<2x4xf32>, %[[RHS:.*]]: vector<3x4xf32> | ||
| func.func @negative_shuffle_2d_vectors(%lhs: vector<2x4xf32>, %rhs: vector<3x4xf32>) -> vector<4x4xf32> { | ||
| // CHECK-NOT: vector.shuffle %[[LHS]], %[[LHS]] | ||
| // CHECK-NOT: vector.shuffle %[[RHS]], %[[RHS]] | ||
| // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32> | ||
| // CHECK: return %[[RESULT]] : vector<4x4xf32> | ||
| %0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32> | ||
| return %0 : vector<4x4xf32> | ||
| } |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -994,6 +994,22 @@ struct TestEliminateVectorMasks | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| VscaleRange{vscaleMin, vscaleMax}); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| struct TestVectorShuffleLowering | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| : public PassWrapper<TestVectorShuffleLowering, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| OperationPass<func::FuncOp>> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorShuffleLowering) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| StringRef getArgument() const final { return "test-vector-shuffle-lowering"; } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| StringRef getDescription() const final { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return "Test lowering patterns for vector.shuffle with mixed-size inputs"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void runOnOperation() override { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| RewritePatternSet patterns(&getContext()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| populateVectorShuffleLoweringPatterns(patterns); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (void)applyPatternsGreedily(getOperation(), std::move(patterns)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // 1.1 Tile the linalg.pack Op so that we can decompose it into e.g. tensor.pad | |
| // and other lower-level Ops (see step 2.1) | |
| %tiled_pack_op_p, %loops_pack:2 = transform.structured.tile_using_for %pack tile_sizes [1, 1] | |
| : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) | |
| // 1.2 Tile the linalg.unpack Op so that we can decompose it into e.g. tensor.pad | |
| // and other lower-level Ops (see step 2) | |
| %tiled_unpack_op_p, %loops_unpack:2 = transform.structured.tile_using_for %unpack tile_sizes [4, 1] | |
| : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) | |
| // 2.1. Decompose tiled PackOp into lower-level Ops | |
| %func_op_pack = transform.get_parent_op %tiled_pack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func"> | |
| transform.apply_patterns to %func_op_pack { | |
| transform.apply_patterns.linalg.decompose_pack_unpack | |
| transform.apply_patterns.linalg.decompose_pad | |
| } : !transform.op<"func.func"> | |
| transform.apply_patterns to %func_op_pack { | |
| transform.apply_patterns.tensor.fold_tensor_subset_ops | |
| transform.apply_patterns.canonicalization | |
| } : !transform.op<"func.func"> | |
| // 2.1. Decompose tiled UnpackOp into lower-level Ops | |
| %func_op_unpack = transform.get_parent_op %tiled_unpack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func"> | |
| transform.apply_patterns to %func_op_unpack { | |
| transform.apply_patterns.linalg.decompose_pack_unpack | |
| } : !transform.op<"func.func"> | |
| transform.apply_patterns to %func_op_unpack { | |
| transform.apply_patterns.tensor.fold_tensor_subset_ops | |
| transform.apply_patterns.canonicalization | |
| } : !transform.op<"func.func"> | |
| // 3. Bufferize before lowering to LLVM | |
| %bufferize = transform.bufferization.one_shot_bufferize %module | |
| {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op | |
| // 4. Canonicalize | |
| %func_op_bufferized = transform.structured.match ops{["func.func"]} in %bufferize : (!transform.any_op) -> !transform.op<"func.func"> | |
| transform.apply_patterns to %func_op_bufferized { | |
| transform.apply_patterns.canonicalization | |
| } : !transform.op<"func.func"> |
Basically, I do end-up re-using TD Ops in other places. I very rarely re-use these "test passes". We don't have any official guidelines, but if you don't mind, I'd prefer a TD Op for testing :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand the preference but I'd stick to the existing guidelines. I personally don't think adding a TD op for every populate scales better or reduces boilerplate code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd stick to the existing guidelines.
Do we have guidelines for this? 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Implicit, I guess? We can ask on Discourse
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
General remark about this PR - "lowering" would imply that some high-level representation is replaced with something lower level. But that's not quite what is happening here ATM, is it?
Even if this is to evolve towards "lowering", I'd choose a name more adequate to the current state of things and then rename later as things evolve.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are lowering generic vector.shuffle ops into simpler/more constained versions of them (supported by LLVM/SPIR-V). We follow this approach/naming for other ops (e.g., shape_cast) so I was trying to be consistent with that pattern.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To me, this:
... is actually more complex than:
:) (the former involves 2 ops rather than one). Yes, the former form allows us to improve the lowering to LLVM/SPIR-V code, but IMHO that's just one possible use case.
Anyway, naming is hard 🤷🏻 If you want to keep the current name, could you expand the comment for
MixSizeInputShuffleOpRewriteto clarify that that's to help with lowering to SPIR-V/LLVM? I just genuinely don't see it as "lowering" (might be just me), so some additional note would help 🙏🏻