Skip to content

Commit 647bfbd

Browse files
committed
[VectorCombine] Fold permute of intrinsics into intrinsic of permutes: shuffle(intrinsic, poison/undef) -> intrinsic(shuffle)
Implements #170002
1 parent f42e58f commit 647bfbd

File tree

3 files changed

+241
-3
lines changed

3 files changed

+241
-3
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ class VectorCombine {
139139
bool foldShuffleOfSelects(Instruction &I);
140140
bool foldShuffleOfCastops(Instruction &I);
141141
bool foldShuffleOfShuffles(Instruction &I);
142+
bool foldPermuteOfIntrinsic(Instruction &I);
142143
bool foldShuffleOfIntrinsics(Instruction &I);
143144
bool foldShuffleToIdentity(Instruction &I);
144145
bool foldShuffleFromReductions(Instruction &I);
@@ -2960,6 +2961,93 @@ bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
29602961
return true;
29612962
}
29622963

2964+
/// Try to convert
2965+
/// "shuffle (intrinsic), (poison/undef)" into "intrinsic (shuffle)".
2966+
bool VectorCombine::foldPermuteOfIntrinsic(Instruction &I) {
2967+
Value *V0, *V1;
2968+
ArrayRef<int> Mask;
2969+
if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_Value(V1), m_Mask(Mask))))
2970+
return false;
2971+
2972+
// Check for permute
2973+
if (!match(V1, m_Poison()) && !match(V1, m_Undef())) {
2974+
LLVM_DEBUG(dbgs() << "not a permute\n");
2975+
return false;
2976+
}
2977+
2978+
auto *II0 = dyn_cast<IntrinsicInst>(V0);
2979+
if (!II0)
2980+
return false;
2981+
2982+
auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2983+
auto *IntrinsicSrcTy = dyn_cast<FixedVectorType>(II0->getType());
2984+
if (!ShuffleDstTy || !IntrinsicSrcTy)
2985+
return false;
2986+
2987+
// Validate it's a pure permute, mask should only reference the first vector
2988+
unsigned NumSrcElts = IntrinsicSrcTy->getNumElements();
2989+
for (int Idx : Mask) {
2990+
if (Idx > 0 && Idx >= (int)NumSrcElts)
2991+
return false;
2992+
}
2993+
2994+
Intrinsic::ID IID = II0->getIntrinsicID();
2995+
if (!isTriviallyVectorizable(IID))
2996+
return false;
2997+
2998+
// Cost analysis
2999+
InstructionCost OldCost =
3000+
TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind) +
3001+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleDstTy,
3002+
IntrinsicSrcTy, Mask, CostKind);
3003+
3004+
SmallVector<Type *> NewArgsTy;
3005+
InstructionCost NewCost = 0;
3006+
3007+
for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3008+
if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
3009+
NewArgsTy.push_back(II0->getArgOperand(I)->getType());
3010+
} else {
3011+
auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
3012+
auto *ArgTy = FixedVectorType::get(VecTy->getElementType(),
3013+
ShuffleDstTy->getNumElements());
3014+
NewArgsTy.push_back(ArgTy);
3015+
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
3016+
ArgTy, VecTy, Mask, CostKind);
3017+
}
3018+
}
3019+
IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
3020+
NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);
3021+
3022+
LLVM_DEBUG(dbgs() << "Found a permute of intrinsic: " << I << "\n OldCost: "
3023+
<< OldCost << " vs NewCost: " << NewCost << "\n");
3024+
3025+
if (NewCost > OldCost)
3026+
return false;
3027+
3028+
// Transform
3029+
SmallVector<Value *> NewArgs;
3030+
for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3031+
if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
3032+
NewArgs.push_back(II0->getArgOperand(I));
3033+
} else {
3034+
Value *Shuf = Builder.CreateShuffleVector(
3035+
II0->getArgOperand(I),
3036+
PoisonValue::get(II0->getArgOperand(I)->getType()), Mask);
3037+
NewArgs.push_back(Shuf);
3038+
Worklist.pushValue(Shuf);
3039+
}
3040+
}
3041+
3042+
Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);
3043+
3044+
if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic))
3045+
NewInst->copyIRFlags(II0);
3046+
3047+
replaceValue(I, *NewIntrinsic);
3048+
return true;
3049+
}
3050+
29633051
using InstLane = std::pair<Use *, int>;
29643052

29653053
static InstLane lookThroughShuffles(Use *U, int Lane) {
@@ -4718,6 +4806,8 @@ bool VectorCombine::run() {
47184806
return true;
47194807
if (foldShuffleOfShuffles(I))
47204808
return true;
4809+
if (foldPermuteOfIntrinsic(I))
4810+
return true;
47214811
if (foldShuffleOfIntrinsics(I))
47224812
return true;
47234813
if (foldSelectShuffle(I))
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -passes=vector-combine -S -mtriple=aarch64 %s | FileCheck %s
3+
4+
; This file tests the foldPermuteOfIntrinsic optimization which transforms:
5+
; shuffle(intrinsic(args), poison/undef) -> intrinsic(shuffle(args))
6+
; when the shuffle is a permute (operates on single vector) and cost model
7+
; determines the transformation is beneficial.
8+
9+
;; ============================================================================
10+
;; Positive Tests - Should Optimize
11+
;; ============================================================================
12+
13+
define <4 x i32> @extract_lower_sadd_sat(<8 x i32> %v1, <8 x i32> %v2) {
14+
; CHECK-LABEL: @extract_lower_sadd_sat(
15+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[V1:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
16+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[V2:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
17+
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> [[TMP1]], <4 x i32> [[TMP2]])
18+
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
19+
;
20+
%sat = call <8 x i32> @llvm.sadd.sat.v8i32(<8 x i32> %v1, <8 x i32> %v2)
21+
%result = shufflevector <8 x i32> %sat, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
22+
ret <4 x i32> %result
23+
}
24+
25+
define <4 x i32> @extract_lower_uadd_sat(<8 x i32> %v1, <8 x i32> %v2) {
26+
; CHECK-LABEL: @extract_lower_uadd_sat(
27+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[V1:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
28+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[V2:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
29+
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32> [[TMP1]], <4 x i32> [[TMP2]])
30+
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
31+
;
32+
%sat = call <8 x i32> @llvm.uadd.sat.v8i32(<8 x i32> %v1, <8 x i32> %v2)
33+
%result = shufflevector <8 x i32> %sat, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
34+
ret <4 x i32> %result
35+
}
36+
37+
define <4 x float> @extract_lower_fma(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
38+
; CHECK-LABEL: @extract_lower_fma(
39+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
40+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x float> [[B:%.*]], <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
41+
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <8 x float> [[C:%.*]], <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
42+
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[TMP1]], <4 x float> [[TMP2]], <4 x float> [[TMP3]])
43+
; CHECK-NEXT: ret <4 x float> [[RESULT]]
44+
;
45+
%fma = call <8 x float> @llvm.fma.v8f32(<8 x float> %a, <8 x float> %b, <8 x float> %c)
46+
%result = shufflevector <8 x float> %fma, <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
47+
ret <4 x float> %result
48+
}
49+
50+
define <4 x i32> @extract_lower_abs_should_not_shuffle_scalar(<8 x i32> %v) {
51+
; CHECK-LABEL: @extract_lower_abs_should_not_shuffle_scalar(
52+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[V:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
53+
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[TMP1]], i1 false)
54+
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
55+
;
56+
%abs = call <8 x i32> @llvm.abs.v8i32(<8 x i32> %v, i1 false)
57+
%result = shufflevector <8 x i32> %abs, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
58+
ret <4 x i32> %result
59+
}
60+
61+
define <2 x i64> @extract_lower_i64(<4 x i64> %v1, <4 x i64> %v2) {
62+
; CHECK-LABEL: @extract_lower_i64(
63+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i64> [[V1:%.*]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
64+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i64> [[V2:%.*]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
65+
; CHECK-NEXT: [[RESULT:%.*]] = call <2 x i64> @llvm.sadd.sat.v2i64(<2 x i64> [[TMP1]], <2 x i64> [[TMP2]])
66+
; CHECK-NEXT: ret <2 x i64> [[RESULT]]
67+
;
68+
%sat = call <4 x i64> @llvm.sadd.sat.v4i64(<4 x i64> %v1, <4 x i64> %v2)
69+
%result = shufflevector <4 x i64> %sat, <4 x i64> poison, <2 x i32> <i32 0, i32 1>
70+
ret <2 x i64> %result
71+
}
72+
73+
define <8 x i16> @extract_lower_i16(<16 x i16> %v1, <16 x i16> %v2) {
74+
; CHECK-LABEL: @extract_lower_i16(
75+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <16 x i16> [[V1:%.*]], <16 x i16> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
76+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <16 x i16> [[V2:%.*]], <16 x i16> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
77+
; CHECK-NEXT: [[RESULT:%.*]] = call <8 x i16> @llvm.sadd.sat.v8i16(<8 x i16> [[TMP1]], <8 x i16> [[TMP2]])
78+
; CHECK-NEXT: ret <8 x i16> [[RESULT]]
79+
;
80+
%sat = call <16 x i16> @llvm.sadd.sat.v16i16(<16 x i16> %v1, <16 x i16> %v2)
81+
%result = shufflevector <16 x i16> %sat, <16 x i16> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
82+
ret <8 x i16> %result
83+
}
84+
85+
define <4 x i32> @extract_lower_with_undef(<8 x i32> %v1, <8 x i32> %v2) {
86+
; CHECK-LABEL: @extract_lower_with_undef(
87+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[V1:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
88+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[V2:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
89+
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> [[TMP1]], <4 x i32> [[TMP2]])
90+
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
91+
;
92+
%sat = call <8 x i32> @llvm.sadd.sat.v8i32(<8 x i32> %v1, <8 x i32> %v2)
93+
%result = shufflevector <8 x i32> %sat, <8 x i32> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
94+
ret <4 x i32> %result
95+
}
96+
97+
;; ============================================================================
98+
;; Negative Tests - Should NOT Optimize
99+
;; ============================================================================
100+
101+
define <4 x i32> @same_size_permute(<4 x i32> %v1, <4 x i32> %v2) {
102+
; CHECK-LABEL: @same_size_permute(
103+
; CHECK-NEXT: [[SAT:%.*]] = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> [[V1:%.*]], <4 x i32> [[V2:%.*]])
104+
; CHECK-NEXT: [[RESULT:%.*]] = shufflevector <4 x i32> [[SAT]], <4 x i32> poison, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
105+
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
106+
;
107+
%sat = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> %v1, <4 x i32> %v2)
108+
%result = shufflevector <4 x i32> %sat, <4 x i32> poison, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
109+
ret <4 x i32> %result
110+
}
111+
112+
define <4 x i32> @not_a_permute_uses_second_operand(<4 x i32> %v1, <4 x i32> %v2, <4 x i32> %other) {
113+
; CHECK-LABEL: @not_a_permute_uses_second_operand(
114+
; CHECK-NEXT: [[SAT:%.*]] = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> [[V1:%.*]], <4 x i32> [[V2:%.*]])
115+
; CHECK-NEXT: [[RESULT:%.*]] = shufflevector <4 x i32> [[SAT]], <4 x i32> [[OTHER:%.*]], <4 x i32> <i32 0, i32 4, i32 1, i32 5>
116+
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
117+
;
118+
%sat = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> %v1, <4 x i32> %v2)
119+
%result = shufflevector <4 x i32> %sat, <4 x i32> %other, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
120+
ret <4 x i32> %result
121+
}
122+
123+
define <4 x i32> @not_an_intrinsic(<8 x i32> %v1, <8 x i32> %v2) {
124+
; CHECK-LABEL: @not_an_intrinsic(
125+
; CHECK-NEXT: [[ADD:%.*]] = add <8 x i32> [[V1:%.*]], [[V2:%.*]]
126+
; CHECK-NEXT: [[RESULT:%.*]] = shufflevector <8 x i32> [[ADD]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
127+
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
128+
;
129+
%add = add <8 x i32> %v1, %v2
130+
%result = shufflevector <8 x i32> %add, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
131+
ret <4 x i32> %result
132+
}
133+
134+
declare <8 x i32> @llvm.sadd.sat.v8i32(<8 x i32>, <8 x i32>)
135+
declare <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32>, <4 x i32>)
136+
declare <4 x i64> @llvm.sadd.sat.v4i64(<4 x i64>, <4 x i64>)
137+
declare <2 x i64> @llvm.sadd.sat.v2i64(<2 x i64>, <2 x i64>)
138+
declare <16 x i16> @llvm.sadd.sat.v16i16(<16 x i16>, <16 x i16>)
139+
declare <8 x i16> @llvm.sadd.sat.v8i16(<8 x i16>, <8 x i16>)
140+
141+
declare <8 x i32> @llvm.uadd.sat.v8i32(<8 x i32>, <8 x i32>)
142+
declare <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32>, <4 x i32>)
143+
144+
declare <8 x i32> @llvm.abs.v8i32(<8 x i32>, i1 immarg)
145+
declare <4 x i32> @llvm.abs.v4i32(<4 x i32>, i1 immarg)
146+
147+
declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>)
148+
declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)

llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,9 @@ define <8 x i8> @abs_different(<8 x i8> %a) {
204204

205205
define <4 x i32> @poison_intrinsic(<2 x i16> %l256) {
206206
; CHECK-LABEL: @poison_intrinsic(
207-
; CHECK-NEXT: [[L266:%.*]] = call <2 x i16> @llvm.abs.v2i16(<2 x i16> [[L256:%.*]], i1 false)
208-
; CHECK-NEXT: [[L267:%.*]] = shufflevector <2 x i16> [[L266]], <2 x i16> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
209-
; CHECK-NEXT: [[L271:%.*]] = zext <4 x i16> [[L267]] to <4 x i32>
207+
; CHECK-NEXT: [[L267:%.*]] = shufflevector <2 x i16> [[L266:%.*]], <2 x i16> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
208+
; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[L267]], i1 false)
209+
; CHECK-NEXT: [[L271:%.*]] = zext <4 x i16> [[TMP2]] to <4 x i32>
210210
; CHECK-NEXT: ret <4 x i32> [[L271]]
211211
;
212212
%l266 = call <2 x i16> @llvm.abs.v2i16(<2 x i16> %l256, i1 false)

0 commit comments

Comments
 (0)