diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 64a0f4641250c..2b7a438a9ef01 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -3990,31 +3990,30 @@ ConstantFoldStructCall(StringRef Name, Intrinsic::ID IntrinsicID, return ConstantStruct::get(StTy, SinResult, CosResult); } case Intrinsic::vector_deinterleave2: { - auto *Vec = dyn_cast(Operands[0]); - if (!Vec) - return nullptr; - + auto *Vec = Operands[0]; auto *VecTy = cast(Vec->getType()); - unsigned NumElements = VecTy->getElementCount().getKnownMinValue() / 2; - if (isa(Vec)) { - auto *HalfVecTy = VectorType::getHalfElementsVectorType(VecTy); - return ConstantStruct::get(StTy, ConstantAggregateZero::get(HalfVecTy), - ConstantAggregateZero::get(HalfVecTy)); + + if (auto *EltC = Vec->getSplatValue()) { + ElementCount HalfEC = VecTy->getElementCount().divideCoefficientBy(2); + auto *HalfVec = ConstantVector::getSplat(HalfEC, EltC); + return ConstantStruct::get(StTy, HalfVec, HalfVec); } - if (isa(Vec->getType())) { - SmallVector Res0(NumElements), Res1(NumElements); - for (unsigned I = 0; I < NumElements; ++I) { - Constant *Elt0 = Vec->getAggregateElement(2 * I); - Constant *Elt1 = Vec->getAggregateElement(2 * I + 1); - if (!Elt0 || !Elt1) - return nullptr; - Res0[I] = Elt0; - Res1[I] = Elt1; - } - return ConstantStruct::get(StTy, ConstantVector::get(Res0), - ConstantVector::get(Res1)); + + if (!isa(Vec->getType())) + return nullptr; + + unsigned NumElements = VecTy->getElementCount().getFixedValue() / 2; + SmallVector Res0(NumElements), Res1(NumElements); + for (unsigned I = 0; I < NumElements; ++I) { + Constant *Elt0 = Vec->getAggregateElement(2 * I); + Constant *Elt1 = Vec->getAggregateElement(2 * I + 1); + if (!Elt0 || !Elt1) + return nullptr; + Res0[I] = Elt0; + Res1[I] = Elt1; } - return nullptr; + return ConstantStruct::get(StTy, ConstantVector::get(Res0), + ConstantVector::get(Res1)); } default: // TODO: Constant folding of vector intrinsics that fall through here does diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll b/llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll index 9dbe3d4e50ee1..14543f339db5d 100644 --- a/llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll +++ b/llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll @@ -66,3 +66,19 @@ define {, } @fold_scalable_vector_deinterlea %1 = call {, } @llvm.vector.deinterleave2.v4i32.v8i32( zeroinitializer) ret {, } %1 } + +define {, } @fold_scalable_vector_deinterleave2_splat() { +; CHECK-LABEL: define { , } @fold_scalable_vector_deinterleave2_splat() { +; CHECK-NEXT: ret { , } { splat (i32 1), splat (i32 1) } +; + %1 = call {, } @llvm.vector.deinterleave2.v4i32.v8i32( splat (i32 1)) + ret {, } %1 +} + +define {, } @fold_scalable_vector_deinterleave2_splatfp() { +; CHECK-LABEL: define { , } @fold_scalable_vector_deinterleave2_splatfp() { +; CHECK-NEXT: ret { , } { splat (float 1.000000e+00), splat (float 1.000000e+00) } +; + %1 = call {, } @llvm.vector.deinterleave2.v4f32.v8f32( splat (float 1.0)) + ret {, } %1 +}