diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 999ad1adff20b..4e86f30100dc6 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -3645,12 +3645,48 @@ static bool hasAffectedValue(Value *V, SmallPtrSetImpl &Affected, return false; } +// Checks for following pattern: +// ``` +// %any1 = select i1 %any0, float 1.000000e+00, float 0.000000e+00 +// ``` +// which then gets folded into: +// ``` +// %any1 = uitofp i1 %any0 to float +// ``` +// (also works with double) +static std::optional +mabyeFoldIntoCast(Value *CondVal, ConstantFP *TrueVal, ConstantFP *FalseVal, + Type *SelType, llvm::StringRef out) { + if (TrueVal->getValueAPF().convertToDouble() != 1.0) { + return std::optional(); + } + + if (FalseVal->getValueAPF().convertToDouble() != 0.0) { + return std::optional(); + } + + return CastInst::Create(llvm::Instruction::UIToFP, CondVal, SelType, out); +} + Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); Type *SelType = SI.getType(); + if (ConstantFP *True = dyn_cast(TrueVal)) { + if (ConstantFP *False = dyn_cast(FalseVal)) { + if (SelType->isFloatTy() || SelType->isDoubleTy()) { + std::optional folded = + mabyeFoldIntoCast(CondVal, True, False, SelType, SI.getName()); + + if (folded.has_value()) { + return folded.value(); + } + } + } + } + if (Value *V = simplifySelectInst(CondVal, TrueVal, FalseVal, SQ.getWithInstruction(&SI))) return replaceInstUsesWith(SI, V); diff --git a/llvm/test/Transforms/InstCombine/2024-11-07-FoldSelectIntoCast.ll b/llvm/test/Transforms/InstCombine/2024-11-07-FoldSelectIntoCast.ll new file mode 100644 index 0000000000000..98af70d018862 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/2024-11-07-FoldSelectIntoCast.ll @@ -0,0 +1,8 @@ +; RUN: opt < %s -passes=instcombine -S | FileCheck %s + +define noundef float @ifelse(i1 noundef zeroext %x) unnamed_addr { +start: +; CHECK: %.= uitofp i1 %x to float + %. = select i1 %x, float 1.000000e+00, float 0.000000e+00 + ret float %. +} \ No newline at end of file