Skip to content

Commit 90ea49a

Browse files
authored
[ConstantFolding] Generalize constant folding for vector_deinterleave2 to deinterleave3-8. (#168640)
1 parent f2c9c7d commit 90ea49a

File tree

2 files changed

+225
-16
lines changed

2 files changed

+225
-16
lines changed

llvm/lib/Analysis/ConstantFolding.cpp

100755 → 100644
Lines changed: 33 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1666,6 +1666,12 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
16661666
case Intrinsic::vector_interleave7:
16671667
case Intrinsic::vector_interleave8:
16681668
case Intrinsic::vector_deinterleave2:
1669+
case Intrinsic::vector_deinterleave3:
1670+
case Intrinsic::vector_deinterleave4:
1671+
case Intrinsic::vector_deinterleave5:
1672+
case Intrinsic::vector_deinterleave6:
1673+
case Intrinsic::vector_deinterleave7:
1674+
case Intrinsic::vector_deinterleave8:
16691675
// Target intrinsics
16701676
case Intrinsic::amdgcn_perm:
16711677
case Intrinsic::amdgcn_wave_reduce_umin:
@@ -4425,31 +4431,42 @@ ConstantFoldStructCall(StringRef Name, Intrinsic::ID IntrinsicID,
44254431
return nullptr;
44264432
return ConstantStruct::get(StTy, SinResult, CosResult);
44274433
}
4428-
case Intrinsic::vector_deinterleave2: {
4434+
case Intrinsic::vector_deinterleave2:
4435+
case Intrinsic::vector_deinterleave3:
4436+
case Intrinsic::vector_deinterleave4:
4437+
case Intrinsic::vector_deinterleave5:
4438+
case Intrinsic::vector_deinterleave6:
4439+
case Intrinsic::vector_deinterleave7:
4440+
case Intrinsic::vector_deinterleave8: {
4441+
unsigned NumResults = StTy->getNumElements();
44294442
auto *Vec = Operands[0];
44304443
auto *VecTy = cast<VectorType>(Vec->getType());
44314444

4445+
ElementCount ResultEC =
4446+
VecTy->getElementCount().divideCoefficientBy(NumResults);
4447+
44324448
if (auto *EltC = Vec->getSplatValue()) {
4433-
ElementCount HalfEC = VecTy->getElementCount().divideCoefficientBy(2);
4434-
auto *HalfVec = ConstantVector::getSplat(HalfEC, EltC);
4435-
return ConstantStruct::get(StTy, HalfVec, HalfVec);
4449+
auto *ResultVec = ConstantVector::getSplat(ResultEC, EltC);
4450+
SmallVector<Constant *, 8> Results(NumResults, ResultVec);
4451+
return ConstantStruct::get(StTy, Results);
44364452
}
44374453

4438-
if (!isa<FixedVectorType>(Vec->getType()))
4454+
if (!ResultEC.isFixed())
44394455
return nullptr;
44404456

4441-
unsigned NumElements = VecTy->getElementCount().getFixedValue() / 2;
4442-
SmallVector<Constant *, 4> Res0(NumElements), Res1(NumElements);
4443-
for (unsigned I = 0; I < NumElements; ++I) {
4444-
Constant *Elt0 = Vec->getAggregateElement(2 * I);
4445-
Constant *Elt1 = Vec->getAggregateElement(2 * I + 1);
4446-
if (!Elt0 || !Elt1)
4447-
return nullptr;
4448-
Res0[I] = Elt0;
4449-
Res1[I] = Elt1;
4457+
unsigned NumElements = ResultEC.getFixedValue();
4458+
SmallVector<Constant *, 8> Results(NumResults);
4459+
SmallVector<Constant *> Elements(NumElements);
4460+
for (unsigned I = 0; I != NumResults; ++I) {
4461+
for (unsigned J = 0; J != NumElements; ++J) {
4462+
Constant *Elt = Vec->getAggregateElement(J * NumResults + I);
4463+
if (!Elt)
4464+
return nullptr;
4465+
Elements[J] = Elt;
4466+
}
4467+
Results[I] = ConstantVector::get(Elements);
44504468
}
4451-
return ConstantStruct::get(StTy, ConstantVector::get(Res0),
4452-
ConstantVector::get(Res1));
4469+
return ConstantStruct::get(StTy, Results);
44534470
}
44544471
default:
44554472
// TODO: Constant folding of vector intrinsics that fall through here does

0 commit comments

Comments
 (0)