From 7e22e93fcda75d4f03624a68f00685d68bb6d32b Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Fri, 3 Jan 2025 01:33:41 +0530 Subject: [PATCH 1/4] [InstCombine] Pre-Commit Tests --- .../Transforms/InstCombine/select_frexp.ll | 129 ++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 llvm/test/Transforms/InstCombine/select_frexp.ll diff --git a/llvm/test/Transforms/InstCombine/select_frexp.ll b/llvm/test/Transforms/InstCombine/select_frexp.ll new file mode 100644 index 0000000000000..b3f05f4db42dd --- /dev/null +++ b/llvm/test/Transforms/InstCombine/select_frexp.ll @@ -0,0 +1,129 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -passes=instcombine -S < %s | FileCheck %s + +declare { float, i32 } @llvm.frexp.f32.i32(float) +declare void @use(float) + +; Basic test case - constant in true position +define float @test_select_frexp_basic(float %x, i1 %cond) { +; CHECK-LABEL: define float @test_select_frexp_basic( +; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]] +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 +; CHECK-NEXT: ret float [[FREXP_0]] +; + %sel = select i1 %cond, float 1.000000e+00, float %x + %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) + %frexp.0 = extractvalue { float, i32 } %frexp, 0 + ret float %frexp.0 +} + +; Test with constant in false position +define float @test_select_frexp_const_false(float %x, i1 %cond) { +; CHECK-LABEL: define float @test_select_frexp_const_false( +; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float [[X]], float 1.000000e+00 +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 +; CHECK-NEXT: ret float [[FREXP_0]] +; + %sel = select i1 %cond, float %x, float 1.000000e+00 + %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) + %frexp.0 = extractvalue { float, i32 } %frexp, 0 + ret float %frexp.0 +} + +; Multi-use test +define float @test_select_frexp_multi_use(float %x, i1 %cond) { +; CHECK-LABEL: define float @test_select_frexp_multi_use( +; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]] +; CHECK-NEXT: call void @use(float [[SEL]]) +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 +; CHECK-NEXT: ret float [[FREXP_0]] +; + %sel = select i1 %cond, float 1.000000e+00, float %x + call void @use(float %sel) + %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) + %frexp.0 = extractvalue { float, i32 } %frexp, 0 + ret float %frexp.0 +} + +; Vector test - splat constant +define <2 x float> @test_select_frexp_vec_splat(<2 x float> %x, <2 x i1> %cond) { +; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_splat( +; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x float> splat (float 1.000000e+00), <2 x float> [[X]] +; CHECK-NEXT: [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP]], 0 +; CHECK-NEXT: ret <2 x float> [[FREXP_0]] +; + %sel = select <2 x i1> %cond, <2 x float> , <2 x float> %x + %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel) + %frexp.0 = extractvalue { <2 x float>, <2 x i32> } %frexp, 0 + ret <2 x float> %frexp.0 +} + +; Vector test with poison +define <2 x float> @test_select_frexp_vec_poison(<2 x float> %x, <2 x i1> %cond) { +; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_poison( +; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x float> , <2 x float> [[X]] +; CHECK-NEXT: [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP]], 0 +; CHECK-NEXT: ret <2 x float> [[FREXP_0]] +; + %sel = select <2 x i1> %cond, <2 x float> , <2 x float> %x + %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel) + %frexp.0 = extractvalue { <2 x float>, <2 x i32> } %frexp, 0 + ret <2 x float> %frexp.0 +} + +; Vector test - non-splat (should not fold) +define <2 x float> @test_select_frexp_vec_nonsplat(<2 x float> %x, <2 x i1> %cond) { +; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_nonsplat( +; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x float> , <2 x float> [[X]] +; CHECK-NEXT: [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP]], 0 +; CHECK-NEXT: ret <2 x float> [[FREXP_0]] +; + %sel = select <2 x i1> %cond, <2 x float> , <2 x float> %x + %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel) + %frexp.0 = extractvalue { <2 x float>, <2 x i32> } %frexp, 0 + ret <2 x float> %frexp.0 +} + +; Negative test - both operands non-constant +define float @test_select_frexp_no_const(float %x, float %y, i1 %cond) { +; CHECK-LABEL: define float @test_select_frexp_no_const( +; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float [[X]], float [[Y]] +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 +; CHECK-NEXT: ret float [[FREXP_0]] +; + %sel = select i1 %cond, float %x, float %y + %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) + %frexp.0 = extractvalue { float, i32 } %frexp, 0 + ret float %frexp.0 +} + +; Negative test - extracting exp instead of mantissa +define i32 @test_select_frexp_extract_exp(float %x, i1 %cond) { +; CHECK-LABEL: define i32 @test_select_frexp_extract_exp( +; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]] +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP_1:%.*]] = extractvalue { float, i32 } [[FREXP]], 1 +; CHECK-NEXT: ret i32 [[FREXP_1]] +; + %sel = select i1 %cond, float 1.000000e+00, float %x + %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) + %frexp.1 = extractvalue { float, i32 } %frexp, 1 + ret i32 %frexp.1 +} + +declare { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float>) From b738a1dcfc577b27002b53ec7008f0e26753cdb6 Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Fri, 3 Jan 2025 18:03:45 +0530 Subject: [PATCH 2/4] [InstCombine] InstCombine should fold frexp of select to select of frexp --- .../InstCombine/InstructionCombining.cpp | 67 ++++++++++++++++++- .../Transforms/InstCombine/select_frexp.ll | 17 ++--- 2 files changed, 75 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index a64c188575e6c..54919ea2f7386 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -4069,6 +4069,52 @@ InstCombinerImpl::foldExtractOfOverflowIntrinsic(ExtractValueInst &EV) { return nullptr; } +static Value *foldFrexpOfSelect(ExtractValueInst &EV, CallInst *FrexpCall, + SelectInst *SelectInst, + InstCombiner::BuilderTy &Builder) { + // Helper to fold frexp of select to select of frexp. + Value *Cond = SelectInst->getCondition(); + Value *TrueVal = SelectInst->getTrueValue(); + Value *FalseVal = SelectInst->getFalseValue(); + ConstantFP *ConstOp = nullptr; + Value *VarOp = nullptr; + bool ConstIsTrue = false; + + if (auto *TrueConst = dyn_cast(TrueVal)) { + ConstOp = TrueConst; + VarOp = FalseVal; + ConstIsTrue = true; + } else if (auto *FalseConst = dyn_cast(FalseVal)) { + ConstOp = FalseConst; + VarOp = TrueVal; + ConstIsTrue = false; + } + + if (!ConstOp || !VarOp) + return nullptr; + + CallInst *NewFrexp = + Builder.CreateCall(FrexpCall->getCalledFunction(), {VarOp}, "frexp"); + + Value *NewEV = Builder.CreateExtractValue(NewFrexp, 0, "mantissa"); + + APFloat ConstVal = ConstOp->getValueAPF(); + int Exp = 0; + APFloat Mantissa = ConstVal; + + if (ConstVal.isFiniteNonZero()) { + Mantissa = frexp(ConstVal, Exp, APFloat::rmNearestTiesToEven); + } + + Constant *ConstantMantissa = ConstantFP::get(ConstOp->getType(), Mantissa); + + Value *NewSel = Builder.CreateSelect( + Cond, ConstIsTrue ? ConstantMantissa : NewEV, + ConstIsTrue ? NewEV : ConstantMantissa, "select.frexp"); + + return NewSel; +} + Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { Value *Agg = EV.getAggregateOperand(); @@ -4078,7 +4124,26 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { if (Value *V = simplifyExtractValueInst(Agg, EV.getIndices(), SQ.getWithInstruction(&EV))) return replaceInstUsesWith(EV, V); - + if (EV.getNumIndices() == 1 && EV.getIndices()[0] == 0) { + if (auto *FrexpCall = dyn_cast(Agg)) { + if (Function *F = FrexpCall->getCalledFunction()) { + if (F->getIntrinsicID() == Intrinsic::frexp) { + if (auto *SelInst = + dyn_cast(FrexpCall->getArgOperand(0))) { + if (isa(SelInst->getTrueValue()) || + isa(SelInst->getFalseValue())) { + Builder.SetInsertPoint(&EV); + + if (Value *Result = + foldFrexpOfSelect(EV, FrexpCall, SelInst, Builder)) { + return replaceInstUsesWith(EV, Result); + } + } + } + } + } + } + } if (InsertValueInst *IV = dyn_cast(Agg)) { // We're extracting from an insertvalue instruction, compare the indices const unsigned *exti, *exte, *insi, *inse; diff --git a/llvm/test/Transforms/InstCombine/select_frexp.ll b/llvm/test/Transforms/InstCombine/select_frexp.ll index b3f05f4db42dd..652d4de27b759 100644 --- a/llvm/test/Transforms/InstCombine/select_frexp.ll +++ b/llvm/test/Transforms/InstCombine/select_frexp.ll @@ -8,10 +8,10 @@ declare void @use(float) define float @test_select_frexp_basic(float %x, i1 %cond) { ; CHECK-LABEL: define float @test_select_frexp_basic( ; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]] -; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]]) ; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 -; CHECK-NEXT: ret float [[FREXP_0]] +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select i1 [[COND]], float 5.000000e-01, float [[FREXP_0]] +; CHECK-NEXT: ret float [[SELECT_FREXP]] ; %sel = select i1 %cond, float 1.000000e+00, float %x %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) @@ -23,10 +23,10 @@ define float @test_select_frexp_basic(float %x, i1 %cond) { define float @test_select_frexp_const_false(float %x, i1 %cond) { ; CHECK-LABEL: define float @test_select_frexp_const_false( ; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float [[X]], float 1.000000e+00 -; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]]) ; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 -; CHECK-NEXT: ret float [[FREXP_0]] +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select i1 [[COND]], float [[FREXP_0]], float 5.000000e-01 +; CHECK-NEXT: ret float [[SELECT_FREXP]] ; %sel = select i1 %cond, float %x, float 1.000000e+00 %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) @@ -40,9 +40,10 @@ define float @test_select_frexp_multi_use(float %x, i1 %cond) { ; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]] ; CHECK-NEXT: call void @use(float [[SEL]]) -; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]]) ; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 -; CHECK-NEXT: ret float [[FREXP_0]] +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select i1 [[COND]], float 5.000000e-01, float [[FREXP_0]] +; CHECK-NEXT: ret float [[SELECT_FREXP]] ; %sel = select i1 %cond, float 1.000000e+00, float %x call void @use(float %sel) From 1e656c8957a5e3a608c694060c38d94327105821 Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Mon, 6 Jan 2025 23:08:37 +0530 Subject: [PATCH 3/4] [InstCombine] Refactor and Preserve fast math flags --- .../InstCombine/InstructionCombining.cpp | 58 +++++++++---------- .../Transforms/InstCombine/select_frexp.ll | 37 +++++++++++- 2 files changed, 62 insertions(+), 33 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 54919ea2f7386..bc07c7c047efb 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -33,6 +33,7 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -4069,52 +4070,57 @@ InstCombinerImpl::foldExtractOfOverflowIntrinsic(ExtractValueInst &EV) { return nullptr; } -static Value *foldFrexpOfSelect(ExtractValueInst &EV, CallInst *FrexpCall, +static Value *foldFrexpOfSelect(ExtractValueInst &EV, IntrinsicInst *FrexpCall, SelectInst *SelectInst, InstCombiner::BuilderTy &Builder) { // Helper to fold frexp of select to select of frexp. Value *Cond = SelectInst->getCondition(); Value *TrueVal = SelectInst->getTrueValue(); Value *FalseVal = SelectInst->getFalseValue(); - ConstantFP *ConstOp = nullptr; + + const APFloat *ConstVal = nullptr; Value *VarOp = nullptr; bool ConstIsTrue = false; - if (auto *TrueConst = dyn_cast(TrueVal)) { - ConstOp = TrueConst; + if (match(TrueVal, m_APFloat(ConstVal))) { VarOp = FalseVal; ConstIsTrue = true; - } else if (auto *FalseConst = dyn_cast(FalseVal)) { - ConstOp = FalseConst; + } else if (match(FalseVal, m_APFloat(ConstVal))) { VarOp = TrueVal; ConstIsTrue = false; + } else { + return nullptr; } - if (!ConstOp || !VarOp) - return nullptr; + Builder.SetInsertPoint(&EV); CallInst *NewFrexp = Builder.CreateCall(FrexpCall->getCalledFunction(), {VarOp}, "frexp"); + NewFrexp->copyIRFlags(FrexpCall); Value *NewEV = Builder.CreateExtractValue(NewFrexp, 0, "mantissa"); - APFloat ConstVal = ConstOp->getValueAPF(); - int Exp = 0; - APFloat Mantissa = ConstVal; + int Exp; + APFloat Mantissa = frexp(*ConstVal, Exp, APFloat::rmNearestTiesToEven); - if (ConstVal.isFiniteNonZero()) { - Mantissa = frexp(ConstVal, Exp, APFloat::rmNearestTiesToEven); + Constant *ConstantMantissa; + if (auto *VecTy = dyn_cast(TrueVal->getType())) { + SmallVector Elems( + VecTy->getElementCount().getFixedValue(), + ConstantFP::get(VecTy->getElementType(), Mantissa)); + ConstantMantissa = ConstantVector::get(Elems); + } else { + ConstantMantissa = ConstantFP::get(TrueVal->getType(), Mantissa); } - Constant *ConstantMantissa = ConstantFP::get(ConstOp->getType(), Mantissa); - Value *NewSel = Builder.CreateSelect( Cond, ConstIsTrue ? ConstantMantissa : NewEV, ConstIsTrue ? NewEV : ConstantMantissa, "select.frexp"); + if (auto *NewSelInst = dyn_cast(NewSel)) + NewSelInst->copyFastMathFlags(SelectInst); return NewSel; } - Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { Value *Agg = EV.getAggregateOperand(); @@ -4125,20 +4131,12 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { SQ.getWithInstruction(&EV))) return replaceInstUsesWith(EV, V); if (EV.getNumIndices() == 1 && EV.getIndices()[0] == 0) { - if (auto *FrexpCall = dyn_cast(Agg)) { - if (Function *F = FrexpCall->getCalledFunction()) { - if (F->getIntrinsicID() == Intrinsic::frexp) { - if (auto *SelInst = - dyn_cast(FrexpCall->getArgOperand(0))) { - if (isa(SelInst->getTrueValue()) || - isa(SelInst->getFalseValue())) { - Builder.SetInsertPoint(&EV); - - if (Value *Result = - foldFrexpOfSelect(EV, FrexpCall, SelInst, Builder)) { - return replaceInstUsesWith(EV, Result); - } - } + if (auto *FrexpCall = dyn_cast(Agg)) { + if (FrexpCall->getIntrinsicID() == Intrinsic::frexp) { + if (auto *SelInst = dyn_cast(FrexpCall->getArgOperand(0))) { + if (Value *Result = + foldFrexpOfSelect(EV, FrexpCall, SelInst, Builder)) { + return replaceInstUsesWith(EV, Result); } } } diff --git a/llvm/test/Transforms/InstCombine/select_frexp.ll b/llvm/test/Transforms/InstCombine/select_frexp.ll index 652d4de27b759..d729e7c700514 100644 --- a/llvm/test/Transforms/InstCombine/select_frexp.ll +++ b/llvm/test/Transforms/InstCombine/select_frexp.ll @@ -56,10 +56,10 @@ define float @test_select_frexp_multi_use(float %x, i1 %cond) { define <2 x float> @test_select_frexp_vec_splat(<2 x float> %x, <2 x i1> %cond) { ; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_splat( ; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x float> splat (float 1.000000e+00), <2 x float> [[X]] -; CHECK-NEXT: [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[SEL]]) +; CHECK-NEXT: [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[X]]) ; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP]], 0 -; CHECK-NEXT: ret <2 x float> [[FREXP_0]] +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select <2 x i1> [[COND]], <2 x float> splat (float 5.000000e-01), <2 x float> [[FREXP_0]] +; CHECK-NEXT: ret <2 x float> [[SELECT_FREXP]] ; %sel = select <2 x i1> %cond, <2 x float> , <2 x float> %x %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel) @@ -127,4 +127,35 @@ define i32 @test_select_frexp_extract_exp(float %x, i1 %cond) { ret i32 %frexp.1 } +; Test with fast math flags +define float @test_select_frexp_fast_math_select(float %x, i1 %cond) { +; CHECK-LABEL: define float @test_select_frexp_fast_math_select( +; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[FREXP1:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]]) +; CHECK-NEXT: [[MANTISSA:%.*]] = extractvalue { float, i32 } [[FREXP1]], 0 +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select nnan ninf nsz i1 [[COND]], float 5.000000e-01, float [[MANTISSA]] +; CHECK-NEXT: ret float [[SELECT_FREXP]] +; + %sel = select nnan ninf nsz i1 %cond, float 1.000000e+00, float %x + %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel) + %frexp.0 = extractvalue { float, i32 } %frexp, 0 + ret float %frexp.0 +} + + +; Test vector case with fast math flags +define <2 x float> @test_select_frexp_vec_fast_math(<2 x float> %x, <2 x i1> %cond) { +; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_fast_math( +; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) { +; CHECK-NEXT: [[FREXP1:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[X]]) +; CHECK-NEXT: [[MANTISSA:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP1]], 0 +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select nnan ninf nsz <2 x i1> [[COND]], <2 x float> splat (float 5.000000e-01), <2 x float> [[MANTISSA]] +; CHECK-NEXT: ret <2 x float> [[SELECT_FREXP]] +; + %sel = select nnan ninf nsz <2 x i1> %cond, <2 x float> , <2 x float> %x + %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel) + %frexp.0 = extractvalue { <2 x float>, <2 x i32> } %frexp, 0 + ret <2 x float> %frexp.0 +} + declare { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float>) From ad6c5019f66b621c96c0ed4158d0bb4875ebd3c5 Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Thu, 30 Jan 2025 22:05:47 +0530 Subject: [PATCH 4/4] [InstCombine] Refactor PatternMatch and add scalable Vector tests --- .../InstCombine/InstructionCombining.cpp | 40 +++++++------------ .../Transforms/InstCombine/select_frexp.ll | 36 +++++++++++++++-- 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index bc07c7c047efb..5621511570b58 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -4074,6 +4074,9 @@ static Value *foldFrexpOfSelect(ExtractValueInst &EV, IntrinsicInst *FrexpCall, SelectInst *SelectInst, InstCombiner::BuilderTy &Builder) { // Helper to fold frexp of select to select of frexp. + + if (!SelectInst->hasOneUse() || !FrexpCall->hasOneUse()) + return nullptr; Value *Cond = SelectInst->getCondition(); Value *TrueVal = SelectInst->getTrueValue(); Value *FalseVal = SelectInst->getFalseValue(); @@ -4103,22 +4106,11 @@ static Value *foldFrexpOfSelect(ExtractValueInst &EV, IntrinsicInst *FrexpCall, int Exp; APFloat Mantissa = frexp(*ConstVal, Exp, APFloat::rmNearestTiesToEven); - Constant *ConstantMantissa; - if (auto *VecTy = dyn_cast(TrueVal->getType())) { - SmallVector Elems( - VecTy->getElementCount().getFixedValue(), - ConstantFP::get(VecTy->getElementType(), Mantissa)); - ConstantMantissa = ConstantVector::get(Elems); - } else { - ConstantMantissa = ConstantFP::get(TrueVal->getType(), Mantissa); - } + Constant *ConstantMantissa = ConstantFP::get(TrueVal->getType(), Mantissa); - Value *NewSel = Builder.CreateSelect( + Value *NewSel = Builder.CreateSelectFMF( Cond, ConstIsTrue ? ConstantMantissa : NewEV, - ConstIsTrue ? NewEV : ConstantMantissa, "select.frexp"); - if (auto *NewSelInst = dyn_cast(NewSel)) - NewSelInst->copyFastMathFlags(SelectInst); - + ConstIsTrue ? NewEV : ConstantMantissa, SelectInst, "select.frexp"); return NewSel; } Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { @@ -4130,17 +4122,15 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { if (Value *V = simplifyExtractValueInst(Agg, EV.getIndices(), SQ.getWithInstruction(&EV))) return replaceInstUsesWith(EV, V); - if (EV.getNumIndices() == 1 && EV.getIndices()[0] == 0) { - if (auto *FrexpCall = dyn_cast(Agg)) { - if (FrexpCall->getIntrinsicID() == Intrinsic::frexp) { - if (auto *SelInst = dyn_cast(FrexpCall->getArgOperand(0))) { - if (Value *Result = - foldFrexpOfSelect(EV, FrexpCall, SelInst, Builder)) { - return replaceInstUsesWith(EV, Result); - } - } - } - } + + Value *Cond, *TrueVal, *FalseVal; + if (match(&EV, m_ExtractValue<0>(m_Intrinsic(m_Select( + m_Value(Cond), m_Value(TrueVal), m_Value(FalseVal)))))) { + auto *SelInst = + cast(cast(Agg)->getArgOperand(0)); + if (Value *Result = + foldFrexpOfSelect(EV, cast(Agg), SelInst, Builder)) + return replaceInstUsesWith(EV, Result); } if (InsertValueInst *IV = dyn_cast(Agg)) { // We're extracting from an insertvalue instruction, compare the indices diff --git a/llvm/test/Transforms/InstCombine/select_frexp.ll b/llvm/test/Transforms/InstCombine/select_frexp.ll index d729e7c700514..d025aedda7170 100644 --- a/llvm/test/Transforms/InstCombine/select_frexp.ll +++ b/llvm/test/Transforms/InstCombine/select_frexp.ll @@ -40,10 +40,9 @@ define float @test_select_frexp_multi_use(float %x, i1 %cond) { ; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) { ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]] ; CHECK-NEXT: call void @use(float [[SEL]]) -; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]]) +; CHECK-NEXT: [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]]) ; CHECK-NEXT: [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0 -; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select i1 [[COND]], float 5.000000e-01, float [[FREXP_0]] -; CHECK-NEXT: ret float [[SELECT_FREXP]] +; CHECK-NEXT: ret float [[FREXP_0]] ; %sel = select i1 %cond, float 1.000000e+00, float %x call void @use(float %sel) @@ -158,4 +157,35 @@ define <2 x float> @test_select_frexp_vec_fast_math(<2 x float> %x, <2 x i1> %co ret <2 x float> %frexp.0 } +; Test with scalable vectors with constant at True Position +define @test_select_frexp_scalable_vec0( %x, %cond) { +; CHECK-LABEL: define @test_select_frexp_scalable_vec0( +; CHECK-SAME: [[X:%.*]], [[COND:%.*]]) { +; CHECK-NEXT: [[FREXP1:%.*]] = call { , } @llvm.frexp.nxv2f32.nxv2i32( [[X]]) +; CHECK-NEXT: [[MANTISSA:%.*]] = extractvalue { , } [[FREXP1]], 0 +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select [[COND]], splat (float 5.000000e-01), [[MANTISSA]] +; CHECK-NEXT: ret [[SELECT_FREXP]] +; + %sel = select %cond, splat (float 1.000000e+00), %x + %frexp = call { , } @llvm.frexp.nxv2f32.nxv2i32( %sel) + %frexp.0 = extractvalue { , } %frexp, 0 + ret %frexp.0 +} + +; Test with scalable vectors with constant at False Position +define @test_select_frexp_scalable_vec1( %x, %cond) { +; CHECK-LABEL: define @test_select_frexp_scalable_vec1( +; CHECK-SAME: [[X:%.*]], [[COND:%.*]]) { +; CHECK-NEXT: [[FREXP1:%.*]] = call { , } @llvm.frexp.nxv2f32.nxv2i32( [[X]]) +; CHECK-NEXT: [[MANTISSA:%.*]] = extractvalue { , } [[FREXP1]], 0 +; CHECK-NEXT: [[SELECT_FREXP:%.*]] = select [[COND]], [[MANTISSA]], splat (float 5.000000e-01) +; CHECK-NEXT: ret [[SELECT_FREXP]] +; + %sel = select %cond, %x, splat (float 1.000000e+00) + %frexp = call { , } @llvm.frexp.nxv2f32.nxv2i32( %sel) + %frexp.0 = extractvalue { , } %frexp, 0 + ret %frexp.0 +} + declare { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float>) +declare { , } @llvm.frexp.nxv2f32.nxv2i32()