Skip to content

Commit 1207064

Browse files
committed
[VectorCombine] Handle shuffle of selects
(shuffle (select c1,t1,f1), (select c2,t2,f2), m) -> (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m)). A SelectInst on vectors operates element-wise: `V'select[i] = Condition[i] ? V'True[i] : V'False[i]`. A ShuffleVector applied to two selects therefore picks, for each mask index, one of those per-element results: `V'[i] = V'select[mask[i]]`. That is why a ShuffleVector of two SelectInsts is equivalent to first applying the same ShuffleVector mask to the Condition/True/False operands and then performing a single SelectInst on the shuffled results. This patch implements the transform described above. Proof: https://alive2.llvm.org/ce/z/97wfHp Fixed: #120775
1 parent 2a4dfdf commit 1207064

File tree

3 files changed

+373
-101
lines changed

3 files changed

+373
-101
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class VectorCombine {
119119
bool foldConcatOfBoolMasks(Instruction &I);
120120
bool foldPermuteOfBinops(Instruction &I);
121121
bool foldShuffleOfBinops(Instruction &I);
122+
bool foldShuffleOfSelects(Instruction &I);
122123
bool foldShuffleOfCastops(Instruction &I);
123124
bool foldShuffleOfShuffles(Instruction &I);
124125
bool foldShuffleOfIntrinsics(Instruction &I);
@@ -1899,6 +1900,56 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
18991900
return true;
19001901
}
19011902

1903+
/// Try to convert,
1904+
/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
1905+
/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
1906+
bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
1907+
ArrayRef<int> Mask;
1908+
Value *C1, *T1, *F1, *C2, *T2, *F2;
1909+
if (!match(&I, m_Shuffle(
1910+
m_OneUse(m_Select(m_Value(C1), m_Value(T1), m_Value(F1))),
1911+
m_OneUse(m_Select(m_Value(C2), m_Value(T2), m_Value(F2))),
1912+
m_Mask(Mask))))
1913+
return false;
1914+
1915+
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
1916+
auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
1917+
auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
1918+
if (!C1VecTy || !C2VecTy)
1919+
return false;
1920+
1921+
auto SK = TargetTransformInfo::SK_PermuteTwoSrc;
1922+
auto SelOp = Instruction::Select;
1923+
InstructionCost OldCost = TTI.getCmpSelInstrCost(
1924+
SelOp, T1->getType(), C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
1925+
OldCost += TTI.getCmpSelInstrCost(SelOp, T2->getType(), C2VecTy,
1926+
CmpInst::BAD_ICMP_PREDICATE, CostKind);
1927+
OldCost += TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr,
1928+
{I.getOperand(0), I.getOperand(1)}, &I);
1929+
1930+
auto *C1C2VecTy = cast<FixedVectorType>(
1931+
toVectorTy(Type::getInt1Ty(I.getContext()), DstVecTy->getNumElements()));
1932+
InstructionCost NewCost =
1933+
TTI.getShuffleCost(SK, C1C2VecTy, Mask, CostKind, 0, nullptr, {C1, C2});
1934+
NewCost +=
1935+
TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr, {T1, T2});
1936+
NewCost +=
1937+
TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr, {F1, F2});
1938+
NewCost += TTI.getCmpSelInstrCost(SelOp, DstVecTy, DstVecTy,
1939+
CmpInst::BAD_ICMP_PREDICATE, CostKind);
1940+
1941+
if (NewCost > OldCost)
1942+
return false;
1943+
1944+
Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
1945+
Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
1946+
Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
1947+
Value *NewSel = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);
1948+
1949+
replaceValue(I, *NewSel);
1950+
return true;
1951+
}
1952+
19021953
/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
19031954
/// into "castop (shuffle)".
19041955
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
@@ -3352,6 +3403,7 @@ bool VectorCombine::run() {
33523403
case Instruction::ShuffleVector:
33533404
MadeChange |= foldPermuteOfBinops(I);
33543405
MadeChange |= foldShuffleOfBinops(I);
3406+
MadeChange |= foldShuffleOfSelects(I);
33553407
MadeChange |= foldShuffleOfCastops(I);
33563408
MadeChange |= foldShuffleOfShuffles(I);
33573409
MadeChange |= foldShuffleOfIntrinsics(I);

llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -451,18 +451,18 @@ define <8 x i8> @icmpsel(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
451451

452452
define <8 x i8> @icmpsel_diffentcond(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
453453
; CHECK-LABEL: @icmpsel_diffentcond(
454-
; CHECK-NEXT: [[AB:%.*]] = shufflevector <8 x i8> [[A:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
455-
; CHECK-NEXT: [[AT:%.*]] = shufflevector <8 x i8> [[A]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
456-
; CHECK-NEXT: [[BB:%.*]] = shufflevector <8 x i8> [[B:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
457-
; CHECK-NEXT: [[BT:%.*]] = shufflevector <8 x i8> [[B]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
458454
; CHECK-NEXT: [[CB:%.*]] = shufflevector <8 x i8> [[C:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
459455
; CHECK-NEXT: [[CT:%.*]] = shufflevector <8 x i8> [[C]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
460456
; CHECK-NEXT: [[DB:%.*]] = shufflevector <8 x i8> [[D:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
461457
; CHECK-NEXT: [[DT:%.*]] = shufflevector <8 x i8> [[D]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
462-
; CHECK-NEXT: [[ABT1:%.*]] = icmp slt <4 x i8> [[AT]], [[BT]]
463-
; CHECK-NEXT: [[ABB1:%.*]] = icmp ult <4 x i8> [[AB]], [[BB]]
464-
; CHECK-NEXT: [[ABT:%.*]] = select <4 x i1> [[ABT1]], <4 x i8> [[CT]], <4 x i8> [[DT]]
465-
; CHECK-NEXT: [[ABB:%.*]] = select <4 x i1> [[ABB1]], <4 x i8> [[CB]], <4 x i8> [[DB]]
458+
; CHECK-NEXT: [[CB1:%.*]] = shufflevector <8 x i8> [[C1:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
459+
; CHECK-NEXT: [[CT1:%.*]] = shufflevector <8 x i8> [[C1]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
460+
; CHECK-NEXT: [[DB1:%.*]] = shufflevector <8 x i8> [[D1:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
461+
; CHECK-NEXT: [[DT1:%.*]] = shufflevector <8 x i8> [[D1]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
462+
; CHECK-NEXT: [[ABT1:%.*]] = icmp slt <4 x i8> [[CT]], [[DT]]
463+
; CHECK-NEXT: [[ABB1:%.*]] = icmp ult <4 x i8> [[CB]], [[DB]]
464+
; CHECK-NEXT: [[ABT:%.*]] = select <4 x i1> [[ABT1]], <4 x i8> [[CT1]], <4 x i8> [[DT1]]
465+
; CHECK-NEXT: [[ABB:%.*]] = select <4 x i1> [[ABB1]], <4 x i8> [[CB1]], <4 x i8> [[DB1]]
466466
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x i8> [[ABT]], <4 x i8> [[ABB]], <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
467467
; CHECK-NEXT: ret <8 x i8> [[R]]
468468
;
@@ -996,10 +996,10 @@ define <4 x i64> @bitcast_smax_v8i32_v4i32(<4 x i64> %a, <4 x i64> %b) {
996996
; CHECK-LABEL: @bitcast_smax_v8i32_v4i32(
997997
; CHECK-NEXT: [[A_BC0:%.*]] = bitcast <4 x i64> [[A:%.*]] to <8 x i32>
998998
; CHECK-NEXT: [[B_BC0:%.*]] = bitcast <4 x i64> [[B:%.*]] to <8 x i32>
999-
; CHECK-NEXT: [[CMP:%.*]] = icmp slt <8 x i32> [[A_BC0]], [[B_BC0]]
1000-
; CHECK-NEXT: [[A_BC1:%.*]] = bitcast <4 x i64> [[A]] to <8 x i32>
1001-
; CHECK-NEXT: [[B_BC1:%.*]] = bitcast <4 x i64> [[B]] to <8 x i32>
1002-
; CHECK-NEXT: [[CONCAT:%.*]] = select <8 x i1> [[CMP]], <8 x i32> [[B_BC1]], <8 x i32> [[A_BC1]]
999+
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <8 x i32> [[A_BC0]], [[B_BC0]]
1000+
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <4 x i64> [[A]] to <8 x i32>
1001+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i64> [[B]] to <8 x i32>
1002+
; CHECK-NEXT: [[CONCAT:%.*]] = select <8 x i1> [[TMP1]], <8 x i32> [[TMP3]], <8 x i32> [[TMP5]]
10031003
; CHECK-NEXT: [[RES:%.*]] = bitcast <8 x i32> [[CONCAT]] to <4 x i64>
10041004
; CHECK-NEXT: ret <4 x i64> [[RES]]
10051005
;

0 commit comments

Comments
 (0)