Skip to content

Commit 94a1020

Browse files
committed
[mlir][Vector] Add patterns to lower vector.shuffle
This PR adds patterns to lower `vector.shuffle` with inputs with different vector sizes more efficiently. The current LLVM lowering for these cases degenerates to a sequence of `vector.extract` and `vector.insert` operations. With this PR, the smaller input is promoted to larger vector size by introducing an extra `vector.shuffle`.
1 parent 82acc31 commit 94a1020

File tree

5 files changed

+205
-0
lines changed

5 files changed

+205
-0
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,9 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
293293
int64_t targetRank = 1,
294294
PatternBenefit benefit = 1);
295295

296+
void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns,
297+
PatternBenefit benefit = 1);
298+
296299
/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
297300
/// n > 1.
298301
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
1010
LowerVectorMultiReduction.cpp
1111
LowerVectorScan.cpp
1212
LowerVectorShapeCast.cpp
13+
LowerVectorShuffle.cpp
1314
LowerVectorStep.cpp
1415
LowerVectorToElements.cpp
1516
LowerVectorToFromElementsToShuffleTree.cpp
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
//===- LowerVectorShuffle.cpp - Lower 'vector.shuffle' operation ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the lowering of complex `vector.shuffle` operation to a
10+
// set of simpler operations supported by LLVM/SPIR-V.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
16+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
19+
#define DEBUG_TYPE "vector-shuffle-lowering"
20+
21+
using namespace mlir;
22+
using namespace mlir::vector;
23+
24+
namespace {
25+
26+
/// Lowers a `vector.shuffle` operation with mix-size inputs to a new
27+
/// `vector.shuffle` which promotes the smaller input to the larger vector size
28+
/// and an updated version of the original `vector.shuffle`.
29+
///
30+
/// Example:
31+
///
32+
/// %0 = vector.shuffle %v1, %v2 [0, 2, 1, 3] : vector<2xf32>, vector<4xf32>
33+
///
34+
/// is lowered to:
35+
///
36+
/// %0 = vector.shuffle %v1, %v1 [0, 1, -1, -1] :
37+
/// vector<2xf32>, vector<2xf32>
38+
/// %1 = vector.shuffle %0, %v2 [0, 4, 1, 5] :
39+
/// vector<4xf32>, vector<4xf32>
40+
///
41+
struct MixSizeInputShuffleOpRewrite final
42+
: OpRewritePattern<vector::ShuffleOp> {
43+
using OpRewritePattern::OpRewritePattern;
44+
45+
LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp,
46+
PatternRewriter &rewriter) const override {
47+
auto v1Type = shuffleOp.getV1VectorType();
48+
auto v2Type = shuffleOp.getV2VectorType();
49+
50+
// Only support 1-D shuffle for now.
51+
if (v1Type.getRank() != 1 || v2Type.getRank() != 1)
52+
return failure();
53+
54+
// No mix-size inputs.
55+
int64_t v1OrigNumElems = v1Type.getNumElements();
56+
int64_t v2OrigNumElems = v2Type.getNumElements();
57+
if (v1OrigNumElems == v2OrigNumElems)
58+
return failure();
59+
60+
// Determine which input needs promotion.
61+
bool promoteV1 = v1OrigNumElems < v2OrigNumElems;
62+
Value inputToPromote = promoteV1 ? shuffleOp.getV1() : shuffleOp.getV2();
63+
VectorType promotedType = promoteV1 ? v2Type : v1Type;
64+
int64_t origNumElems = promoteV1 ? v1OrigNumElems : v2OrigNumElems;
65+
int64_t promotedNumElems = promoteV1 ? v2OrigNumElems : v1OrigNumElems;
66+
67+
// Create a shuffle with a mask that preserves existing elements and fills
68+
// up with poison.
69+
SmallVector<int64_t> promoteMask(promotedNumElems, ShuffleOp::kPoisonIndex);
70+
for (int64_t i = 0; i < origNumElems; ++i)
71+
promoteMask[i] = i;
72+
73+
Value promotedInput = rewriter.create<vector::ShuffleOp>(
74+
shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote,
75+
promoteMask);
76+
77+
// Create the final shuffle with the promoted inputs.
78+
Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1();
79+
Value promotedV2 = promoteV1 ? shuffleOp.getV2() : promotedInput;
80+
81+
SmallVector<int64_t> newMask;
82+
if (!promoteV1) {
83+
newMask = to_vector(shuffleOp.getMask());
84+
} else {
85+
// Adjust V2 indices to account for the new V1 size.
86+
for (auto idx : shuffleOp.getMask()) {
87+
int64_t newIdx = idx;
88+
if (idx >= v1OrigNumElems) {
89+
newIdx += promotedNumElems - v1OrigNumElems;
90+
}
91+
newMask.push_back(newIdx);
92+
}
93+
}
94+
95+
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
96+
shuffleOp, shuffleOp.getResultVectorType(), promotedV1, promotedV2,
97+
newMask);
98+
return success();
99+
}
100+
};
101+
} // namespace
102+
103+
void mlir::vector::populateVectorShuffleLoweringPatterns(
104+
RewritePatternSet &patterns, PatternBenefit benefit) {
105+
patterns.add<MixSizeInputShuffleOpRewrite>(patterns.getContext(), benefit);
106+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// RUN: mlir-opt %s --test-vector-shuffle-lowering --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @shuffle_v1_smaller_arbitrary
4+
// CHECK-SAME: %[[V1:.*]]: vector<2xf32>, %[[V2:.*]]: vector<4xf32>
5+
func.func @shuffle_v1_smaller_arbitrary(%v1: vector<2xf32>, %v2: vector<4xf32>) -> vector<5xf32> {
6+
// CHECK: %[[PROMOTE_V1:.*]] = vector.shuffle %[[V1]], %[[V1]] [0, 1, -1, -1] : vector<2xf32>, vector<2xf32>
7+
// CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_V1]], %[[V2]] [1, 5, 0, 6, 7] : vector<4xf32>, vector<4xf32>
8+
// CHECK: return %[[RESULT]] : vector<5xf32>
9+
%0 = vector.shuffle %v1, %v2 [1, 3, 0, 4, 5] : vector<2xf32>, vector<4xf32>
10+
return %0 : vector<5xf32>
11+
}
12+
13+
// -----
14+
15+
// CHECK-LABEL: func.func @shuffle_v2_smaller_arbitrary
16+
// CHECK-SAME: %[[V1:.*]]: vector<4xi32>, %[[V2:.*]]: vector<2xi32>
17+
func.func @shuffle_v2_smaller_arbitrary(%v1: vector<4xi32>, %v2: vector<2xi32>) -> vector<6xi32> {
18+
// CHECK: %[[PROMOTE_V2:.*]] = vector.shuffle %[[V2]], %[[V2]] [0, 1, -1, -1] : vector<2xi32>, vector<2xi32>
19+
// CHECK: %[[RESULT:.*]] = vector.shuffle %[[V1]], %[[PROMOTE_V2]] [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<4xi32>
20+
// CHECK: return %[[RESULT]] : vector<6xi32>
21+
%0 = vector.shuffle %v1, %v2 [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<2xi32>
22+
return %0 : vector<6xi32>
23+
}
24+
25+
// -----
26+
27+
// CHECK-LABEL: func.func @shuffle_v1_smaller_concat
28+
// CHECK-SAME: %[[V1:.*]]: vector<3xf64>, %[[V2:.*]]: vector<5xf64>
29+
func.func @shuffle_v1_smaller_concat(%v1: vector<3xf64>, %v2: vector<5xf64>) -> vector<8xf64> {
30+
// CHECK: %[[PROMOTE_V1:.*]] = vector.shuffle %[[V1]], %[[V1]] [0, 1, 2, -1, -1] : vector<3xf64>, vector<3xf64>
31+
// CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_V1]], %[[V2]] [0, 1, 2, 5, 6, 7, 8, 9] : vector<5xf64>, vector<5xf64>
32+
// CHECK: return %[[RESULT]] : vector<8xf64>
33+
%0 = vector.shuffle %v1, %v2 [0, 1, 2, 3, 4, 5, 6, 7] : vector<3xf64>, vector<5xf64>
34+
return %0 : vector<8xf64>
35+
}
36+
37+
// -----
38+
39+
// CHECK-LABEL: func.func @shuffle_v2_smaller_concat
40+
// CHECK-SAME: %[[V1:.*]]: vector<4xi16>, %[[V2:.*]]: vector<2xi16>
41+
func.func @shuffle_v2_smaller_concat(%v1: vector<4xi16>, %v2: vector<2xi16>) -> vector<6xi16> {
42+
// CHECK: %[[PROMOTE_V2:.*]] = vector.shuffle %[[V2]], %[[V2]] [0, 1, -1, -1] : vector<2xi16>, vector<2xi16>
43+
// CHECK: %[[RESULT:.*]] = vector.shuffle %[[V1]], %[[PROMOTE_V2]] [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<4xi16>
44+
// CHECK: return %[[RESULT]] : vector<6xi16>
45+
%0 = vector.shuffle %v1, %v2 [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<2xi16>
46+
return %0 : vector<6xi16>
47+
}
48+
49+
// -----
50+
51+
// Test that shuffles with same size inputs are not modified.
52+
53+
// CHECK-LABEL: func.func @shuffle_same_input_sizes
54+
// CHECK-SAME: %[[V1:.*]]: vector<4xf32>, %[[V2:.*]]: vector<4xf32>
55+
func.func @shuffle_same_input_sizes(%v1: vector<4xf32>, %v2: vector<4xf32>) -> vector<6xf32> {
56+
// CHECK-NOT: vector.shuffle %[[V1]], %[[V1]]
57+
// CHECK-NOT: vector.shuffle %[[V2]], %[[V2]]
58+
// CHECK: %[[RESULT:.*]] = vector.shuffle %[[V1]], %[[V2]] [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32>
59+
// CHECK: return %[[RESULT]] : vector<6xf32>
60+
%0 = vector.shuffle %v1, %v2 [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32>
61+
return %0 : vector<6xf32>
62+
}
63+
64+
// -----
65+
66+
// Test that multi-dimensional shuffles are not modified.
67+
68+
// CHECK-LABEL: func.func @shuffle_2d_vectors_no_change
69+
// CHECK-SAME: %[[V1:.*]]: vector<2x4xf32>, %[[V2:.*]]: vector<3x4xf32>
70+
func.func @shuffle_2d_vectors_no_change(%v1: vector<2x4xf32>, %v2: vector<3x4xf32>) -> vector<4x4xf32> {
71+
// CHECK-NOT: vector.shuffle %[[V1]], %[[V1]]
72+
// CHECK-NOT: vector.shuffle %[[V2]], %[[V2]]
73+
// CHECK: %[[RESULT:.*]] = vector.shuffle %[[V1]], %[[V2]] [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32>
74+
// CHECK: return %[[RESULT]] : vector<4x4xf32>
75+
%0 = vector.shuffle %v1, %v2 [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32>
76+
return %0 : vector<4x4xf32>
77+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,22 @@ struct TestEliminateVectorMasks
994994
VscaleRange{vscaleMin, vscaleMax});
995995
}
996996
};
997+
998+
struct TestVectorShuffleLowering
999+
: public PassWrapper<TestVectorShuffleLowering,
1000+
OperationPass<func::FuncOp>> {
1001+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorShuffleLowering)
1002+
1003+
StringRef getArgument() const final { return "test-vector-shuffle-lowering"; }
1004+
StringRef getDescription() const final {
1005+
return "Test lowering patterns for vector.shuffle with mixed-size inputs";
1006+
}
1007+
void runOnOperation() override {
1008+
RewritePatternSet patterns(&getContext());
1009+
populateVectorShuffleLoweringPatterns(patterns);
1010+
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
1011+
}
1012+
};
9971013
} // namespace
9981014

9991015
namespace mlir {
@@ -1023,6 +1039,8 @@ void registerTestVectorLowerings() {
10231039

10241040
PassRegistration<TestVectorScanLowering>();
10251041

1042+
PassRegistration<TestVectorShuffleLowering>();
1043+
10261044
PassRegistration<TestVectorDistribution>();
10271045

10281046
PassRegistration<TestVectorExtractStridedSliceLowering>();

0 commit comments

Comments
 (0)