Skip to content

Commit 7bdd88c

Browse files
authored
[mlir][Vector] Add patterns to lower vector.shuffle (#157611)
This PR adds patterns that lower `vector.shuffle` operations whose two inputs have different vector sizes more efficiently. The current LLVM lowering for these cases degenerates into a sequence of `vector.extract` and `vector.insert` operations. With this PR, the smaller input is first promoted to the larger vector size by introducing an extra `vector.shuffle`, so that the final shuffle operates on two equally sized inputs.
1 parent 87bceae commit 7bdd88c

File tree

5 files changed

+209
-0
lines changed

5 files changed

+209
-0
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,9 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
293293
int64_t targetRank = 1,
294294
PatternBenefit benefit = 1);
295295

296+
/// Populates `patterns` with a rewrite that lowers 1-D `vector.shuffle` ops
/// whose two inputs have different sizes: the smaller input is first widened
/// to the size of the larger one with an extra `vector.shuffle`, and the
/// original shuffle is re-emitted on the two equally sized inputs. `benefit`
/// is forwarded to the added pattern.
void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns,
297+
PatternBenefit benefit = 1);
298+
296299
/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
297300
/// n > 1.
298301
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
1010
LowerVectorMultiReduction.cpp
1111
LowerVectorScan.cpp
1212
LowerVectorShapeCast.cpp
13+
LowerVectorShuffle.cpp
1314
LowerVectorStep.cpp
1415
LowerVectorToElements.cpp
1516
LowerVectorToFromElementsToShuffleTree.cpp
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
//===- LowerVectorShuffle.cpp - Lower 'vector.shuffle' operation ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the lowering of complex `vector.shuffle` operation to a
10+
// set of simpler operations supported by LLVM/SPIR-V.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
16+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
19+
#define DEBUG_TYPE "vector-shuffle-lowering"
20+
21+
using namespace mlir;
22+
using namespace mlir::vector;
23+
24+
namespace {
25+
26+
/// Lowers a `vector.shuffle` operation with mixed-size inputs to a new
27+
/// `vector.shuffle` which promotes the smaller input to the larger vector size
28+
/// and an updated version of the original `vector.shuffle`.
29+
///
30+
/// Example:
31+
///
32+
/// %0 = vector.shuffle %v1, %v2 [0, 2, 1, 3] : vector<2xf32>, vector<4xf32>
33+
///
34+
/// is lowered to:
35+
///
36+
/// %0 = vector.shuffle %v1, %v1 [0, 1, -1, -1] :
37+
/// vector<2xf32>, vector<2xf32>
38+
/// %1 = vector.shuffle %0, %v2 [0, 4, 1, 5] :
39+
/// vector<4xf32>, vector<4xf32>
40+
///
41+
/// Note: This transformation helps legalize vector.shuffle ops when lowering
42+
/// to SPIR-V/LLVM, which don't support shuffle operations with mixed-size
43+
/// inputs.
44+
///
45+
struct MixedSizeInputShuffleOpRewrite final
46+
: OpRewritePattern<vector::ShuffleOp> {
47+
using OpRewritePattern::OpRewritePattern;
48+
49+
LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp,
50+
PatternRewriter &rewriter) const override {
51+
auto v1Type = shuffleOp.getV1VectorType();
52+
auto v2Type = shuffleOp.getV2VectorType();
53+
54+
// Only support 1-D shuffle for now.
55+
if (v1Type.getRank() != 1 || v2Type.getRank() != 1)
56+
return failure();
57+
58+
// Bail out if inputs don't have mixed sizes.
59+
int64_t v1OrigNumElems = v1Type.getNumElements();
60+
int64_t v2OrigNumElems = v2Type.getNumElements();
61+
if (v1OrigNumElems == v2OrigNumElems)
62+
return failure();
63+
64+
// Determine which input needs promotion.
65+
bool promoteV1 = v1OrigNumElems < v2OrigNumElems;
66+
Value inputToPromote = promoteV1 ? shuffleOp.getV1() : shuffleOp.getV2();
67+
VectorType promotedType = promoteV1 ? v2Type : v1Type;
68+
int64_t origNumElems = promoteV1 ? v1OrigNumElems : v2OrigNumElems;
69+
int64_t promotedNumElems = promoteV1 ? v2OrigNumElems : v1OrigNumElems;
70+
71+
// Create a shuffle with a mask that preserves existing elements and fills
72+
// up with poison.
73+
SmallVector<int64_t> promoteMask(promotedNumElems, ShuffleOp::kPoisonIndex);
74+
for (int64_t i = 0; i < origNumElems; ++i)
75+
promoteMask[i] = i;
76+
77+
Value promotedInput = rewriter.create<vector::ShuffleOp>(
78+
shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote,
79+
promoteMask);
80+
81+
// Create the final shuffle with the promoted inputs.
82+
Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1();
83+
Value promotedV2 = promoteV1 ? shuffleOp.getV2() : promotedInput;
84+
85+
SmallVector<int64_t> newMask;
86+
if (!promoteV1) {
87+
newMask = to_vector(shuffleOp.getMask());
88+
} else {
89+
// Adjust V2 indices to account for the new V1 size.
90+
for (auto idx : shuffleOp.getMask()) {
91+
int64_t newIdx = idx;
92+
if (idx >= v1OrigNumElems) {
93+
newIdx += promotedNumElems - v1OrigNumElems;
94+
}
95+
newMask.push_back(newIdx);
96+
}
97+
}
98+
99+
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
100+
shuffleOp, shuffleOp.getResultVectorType(), promotedV1, promotedV2,
101+
newMask);
102+
return success();
103+
}
104+
};
105+
} // namespace
106+
107+
/// Adds the mixed-size-input `vector.shuffle` rewrite to `patterns`, registered
/// with the given `benefit`.
void mlir::vector::populateVectorShuffleLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<MixedSizeInputShuffleOpRewrite>(context, benefit);
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// RUN: mlir-opt %s --test-vector-shuffle-lowering --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @shuffle_smaller_lhs_arbitrary
4+
// CHECK-SAME: %[[LHS:.*]]: vector<2xf32>, %[[RHS:.*]]: vector<4xf32>
5+
// Narrower LHS (2 lanes) with an arbitrary mask: LHS is expected to be widened
// to 4 lanes with poison (-1) padding, and every mask index that selects from
// RHS (>= 2) is expected to be shifted up by 2.
func.func @shuffle_smaller_lhs_arbitrary(%lhs: vector<2xf32>, %rhs: vector<4xf32>) -> vector<5xf32> {
6+
// CHECK: %[[PROMOTE_LHS:.*]] = vector.shuffle %[[LHS]], %[[LHS]] [0, 1, -1, -1] : vector<2xf32>, vector<2xf32>
7+
// CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_LHS]], %[[RHS]] [1, 5, 0, 6, 7] : vector<4xf32>, vector<4xf32>
8+
// CHECK: return %[[RESULT]] : vector<5xf32>
9+
%0 = vector.shuffle %lhs, %rhs [1, 3, 0, 4, 5] : vector<2xf32>, vector<4xf32>
10+
return %0 : vector<5xf32>
11+
}
12+
13+
// -----
14+
15+
// CHECK-LABEL: func.func @shuffle_smaller_rhs_arbitrary
16+
// CHECK-SAME: %[[LHS:.*]]: vector<4xi32>, %[[RHS:.*]]: vector<2xi32>
17+
// Narrower RHS (2 lanes) with an arbitrary mask: RHS is expected to be widened
// to 4 lanes with poison (-1) padding. The mask is expected to stay unchanged
// because the LHS size, and therefore the first RHS index, does not move.
func.func @shuffle_smaller_rhs_arbitrary(%lhs: vector<4xi32>, %rhs: vector<2xi32>) -> vector<6xi32> {
18+
// CHECK: %[[PROMOTE_RHS:.*]] = vector.shuffle %[[RHS]], %[[RHS]] [0, 1, -1, -1] : vector<2xi32>, vector<2xi32>
19+
// CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[PROMOTE_RHS]] [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<4xi32>
20+
// CHECK: return %[[RESULT]] : vector<6xi32>
21+
%0 = vector.shuffle %lhs, %rhs [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<2xi32>
22+
return %0 : vector<6xi32>
23+
}
24+
25+
// -----
26+
27+
// CHECK-LABEL: func.func @shuffle_smaller_lhs_concat
28+
// CHECK-SAME: %[[LHS:.*]]: vector<3xf64>, %[[RHS:.*]]: vector<5xf64>
29+
// Concatenation mask with a narrower LHS (3 lanes vs. 5): after LHS is widened
// to 5 lanes, the RHS indices 3..7 are expected to become 5..9 while the LHS
// indices 0..2 are kept.
func.func @shuffle_smaller_lhs_concat(%lhs: vector<3xf64>, %rhs: vector<5xf64>) -> vector<8xf64> {
30+
// CHECK: %[[PROMOTE_LHS:.*]] = vector.shuffle %[[LHS]], %[[LHS]] [0, 1, 2, -1, -1] : vector<3xf64>, vector<3xf64>
31+
// CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_LHS]], %[[RHS]] [0, 1, 2, 5, 6, 7, 8, 9] : vector<5xf64>, vector<5xf64>
32+
// CHECK: return %[[RESULT]] : vector<8xf64>
33+
%0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3, 4, 5, 6, 7] : vector<3xf64>, vector<5xf64>
34+
return %0 : vector<8xf64>
35+
}
36+
37+
// -----
38+
39+
// CHECK-LABEL: func.func @shuffle_smaller_rhs_concat
40+
// CHECK-SAME: %[[LHS:.*]]: vector<4xi16>, %[[RHS:.*]]: vector<2xi16>
41+
// Concatenation mask with a narrower RHS (2 lanes vs. 4): RHS is expected to
// be widened to 4 lanes with poison (-1) padding and the mask to be kept as is.
func.func @shuffle_smaller_rhs_concat(%lhs: vector<4xi16>, %rhs: vector<2xi16>) -> vector<6xi16> {
42+
// CHECK: %[[PROMOTE_RHS:.*]] = vector.shuffle %[[RHS]], %[[RHS]] [0, 1, -1, -1] : vector<2xi16>, vector<2xi16>
43+
// CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[PROMOTE_RHS]] [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<4xi16>
44+
// CHECK: return %[[RESULT]] : vector<6xi16>
45+
%0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<2xi16>
46+
return %0 : vector<6xi16>
47+
}
48+
49+
// -----
50+
51+
// Test that shuffles with same size inputs are not modified.
52+
53+
// CHECK-LABEL: func.func @negative_shuffle_same_input_sizes
54+
// CHECK-SAME: %[[LHS:.*]]: vector<4xf32>, %[[RHS:.*]]: vector<4xf32>
55+
// Negative case: both inputs already have 4 lanes, so no widening shuffle of
// an input with itself may appear and the original shuffle must survive.
func.func @negative_shuffle_same_input_sizes(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<6xf32> {
56+
// CHECK-NOT: vector.shuffle %[[LHS]], %[[LHS]]
57+
// CHECK-NOT: vector.shuffle %[[RHS]], %[[RHS]]
58+
// CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32>
59+
// CHECK: return %[[RESULT]] : vector<6xf32>
60+
%0 = vector.shuffle %lhs, %rhs [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32>
61+
return %0 : vector<6xf32>
62+
}
63+
64+
// -----
65+
66+
// Test that multi-dimensional shuffles are not modified.
67+
68+
// CHECK-LABEL: func.func @negative_shuffle_2d_vectors
69+
// CHECK-SAME: %[[LHS:.*]]: vector<2x4xf32>, %[[RHS:.*]]: vector<3x4xf32>
70+
// Negative case: the inputs are 2-D, so even though their leading dimensions
// differ (2 vs. 3) the shuffle is expected to be left exactly as written.
func.func @negative_shuffle_2d_vectors(%lhs: vector<2x4xf32>, %rhs: vector<3x4xf32>) -> vector<4x4xf32> {
70+
// CHECK-NOT: vector.shuffle %[[LHS]], %[[LHS]]
71+
// CHECK-NOT: vector.shuffle %[[RHS]], %[[RHS]]
72+
// CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32>
73+
// CHECK: return %[[RESULT]] : vector<4x4xf32>
74+
%0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32>
75+
return %0 : vector<4x4xf32>
76+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,22 @@ struct TestEliminateVectorMasks
994994
VscaleRange{vscaleMin, vscaleMax});
995995
}
996996
};
997+
998+
struct TestVectorShuffleLowering
999+
: public PassWrapper<TestVectorShuffleLowering,
1000+
OperationPass<func::FuncOp>> {
1001+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorShuffleLowering)
1002+
1003+
StringRef getArgument() const final { return "test-vector-shuffle-lowering"; }
1004+
StringRef getDescription() const final {
1005+
return "Test lowering patterns for vector.shuffle with mixed-size inputs";
1006+
}
1007+
void runOnOperation() override {
1008+
RewritePatternSet patterns(&getContext());
1009+
populateVectorShuffleLoweringPatterns(patterns);
1010+
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
1011+
}
1012+
};
9971013
} // namespace
9981014

9991015
namespace mlir {
@@ -1023,6 +1039,8 @@ void registerTestVectorLowerings() {
10231039

10241040
PassRegistration<TestVectorScanLowering>();
10251041

1042+
PassRegistration<TestVectorShuffleLowering>();
1043+
10261044
PassRegistration<TestVectorDistribution>();
10271045

10281046
PassRegistration<TestVectorExtractStridedSliceLowering>();

0 commit comments

Comments
 (0)