From c6ae85cd495841fd8fde69aa915828423ac5c3a5 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Fri, 13 Jun 2025 12:02:53 -0700 Subject: [PATCH] [ConstantFolding] Fold deinterleave2 of any splat vector not just zeroinitializer While there remove an unnecessary dyn_cast from Constant to Constant. Reverse a branch condition into an early out to reduce nesting. --- llvm/lib/Analysis/ConstantFolding.cpp | 43 +++++++++---------- .../InstSimplify/ConstProp/vector-calls.ll | 16 +++++++ 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 64a0f4641250c..2b7a438a9ef01 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -3990,31 +3990,30 @@ ConstantFoldStructCall(StringRef Name, Intrinsic::ID IntrinsicID, return ConstantStruct::get(StTy, SinResult, CosResult); } case Intrinsic::vector_deinterleave2: { - auto *Vec = dyn_cast(Operands[0]); - if (!Vec) - return nullptr; - + auto *Vec = Operands[0]; auto *VecTy = cast(Vec->getType()); - unsigned NumElements = VecTy->getElementCount().getKnownMinValue() / 2; - if (isa(Vec)) { - auto *HalfVecTy = VectorType::getHalfElementsVectorType(VecTy); - return ConstantStruct::get(StTy, ConstantAggregateZero::get(HalfVecTy), - ConstantAggregateZero::get(HalfVecTy)); + + if (auto *EltC = Vec->getSplatValue()) { + ElementCount HalfEC = VecTy->getElementCount().divideCoefficientBy(2); + auto *HalfVec = ConstantVector::getSplat(HalfEC, EltC); + return ConstantStruct::get(StTy, HalfVec, HalfVec); } - if (isa(Vec->getType())) { - SmallVector Res0(NumElements), Res1(NumElements); - for (unsigned I = 0; I < NumElements; ++I) { - Constant *Elt0 = Vec->getAggregateElement(2 * I); - Constant *Elt1 = Vec->getAggregateElement(2 * I + 1); - if (!Elt0 || !Elt1) - return nullptr; - Res0[I] = Elt0; - Res1[I] = Elt1; - } - return ConstantStruct::get(StTy, ConstantVector::get(Res0), - ConstantVector::get(Res1)); + + if (!isa(Vec->getType())) + return nullptr; + + unsigned NumElements = VecTy->getElementCount().getFixedValue() / 2; + SmallVector Res0(NumElements), Res1(NumElements); + for (unsigned I = 0; I < NumElements; ++I) { + Constant *Elt0 = Vec->getAggregateElement(2 * I); + Constant *Elt1 = Vec->getAggregateElement(2 * I + 1); + if (!Elt0 || !Elt1) + return nullptr; + Res0[I] = Elt0; + Res1[I] = Elt1; } - return nullptr; + return ConstantStruct::get(StTy, ConstantVector::get(Res0), + ConstantVector::get(Res1)); } default: // TODO: Constant folding of vector intrinsics that fall through here does diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll b/llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll index 9dbe3d4e50ee1..14543f339db5d 100644 --- a/llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll +++ b/llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll @@ -66,3 +66,19 @@ define {, } @fold_scalable_vector_deinterlea %1 = call {, } @llvm.vector.deinterleave2.v4i32.v8i32( zeroinitializer) ret {, } %1 } + +define {, } @fold_scalable_vector_deinterleave2_splat() { +; CHECK-LABEL: define { , } @fold_scalable_vector_deinterleave2_splat() { +; CHECK-NEXT: ret { , } { splat (i32 1), splat (i32 1) } +; + %1 = call {, } @llvm.vector.deinterleave2.v4i32.v8i32( splat (i32 1)) + ret {, } %1 +} + +define {, } @fold_scalable_vector_deinterleave2_splatfp() { +; CHECK-LABEL: define { , } @fold_scalable_vector_deinterleave2_splatfp() { +; CHECK-NEXT: ret { , } { splat (float 1.000000e+00), splat (float 1.000000e+00) } +; + %1 = call {, } @llvm.vector.deinterleave2.v4f32.v8f32( splat (float 1.0)) + ret {, } %1 +}