From 955f620e0dabbc287a54266499a5bd6c9ad0196c Mon Sep 17 00:00:00 2001 From: rbajpai Date: Tue, 10 Dec 2024 22:48:07 +0530 Subject: [PATCH] [InstCombine] Fix constant swap case of fcmp + fadd + sel xfrm The fcmp + fadd + sel => fcmp + sel + fadd xfrm performs an incorrect transformation when the select's branch values are swapped, i.e. when the constant is the true value and the fadd is the false value; the new select must then pick between 0 and X in swapped order as well. This change fixes that. --- .../InstCombine/InstCombineSelect.cpp | 44 +++++++++++-------- .../InstCombine/fcmp-fadd-select.ll | 16 +++---- 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 3d251d662bd53..e7a8e947705f8 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -3769,22 +3769,9 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI, if (!SIFOp || !SIFOp->hasNoSignedZeros() || !SIFOp->hasNoNaNs()) return nullptr; - // select((fcmp Pred, X, 0), (fadd X, C), C) - // => fadd((select (fcmp Pred, X, 0), X, 0), C) - // - // Pred := OGT, OGE, OLT, OLE, UGT, UGE, ULT, and ULE - Instruction *FAdd; - Constant *C; - Value *X, *Z; - CmpPredicate Pred; - - // Note: OneUse check for `Cmp` is necessary because it makes sure that other - // InstCombine folds don't undo this transformation and cause an infinite - // loop. Furthermore, it could also increase the operation count. - if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))), - m_OneUse(m_Instruction(FAdd)), m_Constant(C))) || - match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))), - m_Constant(C), m_OneUse(m_Instruction(FAdd))))) { + auto TryFoldIntoAddConstant = + [&Builder, &SI](CmpInst::Predicate Pred, Value *X, Value *Z, + Instruction *FAdd, Constant *C, bool Swapped) -> Value * { // Only these relational predicates can be transformed into maxnum/minnum // intrinsic. 
if (!CmpInst::isRelational(Pred) || !match(Z, m_AnyZeroFP())) @@ -3793,7 +3780,8 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI, if (!match(FAdd, m_FAdd(m_Specific(X), m_Specific(C)))) return nullptr; - Value *NewSelect = Builder.CreateSelect(SI.getCondition(), X, Z, "", &SI); + Value *NewSelect = Builder.CreateSelect(SI.getCondition(), Swapped ? Z : X, + Swapped ? X : Z, "", &SI); NewSelect->takeName(&SI); Value *NewFAdd = Builder.CreateFAdd(NewSelect, C); @@ -3808,7 +3796,27 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI, cast<Instruction>(NewSelect)->setFastMathFlags(NewFMF); return NewFAdd; - } + }; + + // select((fcmp Pred, X, 0), (fadd X, C), C) + // => fadd((select (fcmp Pred, X, 0), X, 0), C) + // + // Pred := OGT, OGE, OLT, OLE, UGT, UGE, ULT, and ULE + Instruction *FAdd; + Constant *C; + Value *X, *Z; + CmpPredicate Pred; + + // Note: OneUse check for `Cmp` is necessary because it makes sure that other + // InstCombine folds don't undo this transformation and cause an infinite + // loop. Furthermore, it could also increase the operation count. 
+ if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))), + m_OneUse(m_Instruction(FAdd)), m_Constant(C)))) + return TryFoldIntoAddConstant(Pred, X, Z, FAdd, C, /*Swapped=*/false); + + if (match(&SI, m_Select(m_OneUse(m_FCmp(Pred, m_Value(X), m_Value(Z))), + m_Constant(C), m_OneUse(m_Instruction(FAdd))))) + return TryFoldIntoAddConstant(Pred, X, Z, FAdd, C, /*Swapped=*/true); return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll b/llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll index 0d0af91608e7a..15fad55db8df1 100644 --- a/llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll +++ b/llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll @@ -19,7 +19,7 @@ define float @test_fcmp_ogt_fadd_select_constant(float %in) { define float @test_fcmp_ogt_fadd_select_constant_swapped(float %in) { ; CHECK-LABEL: define float @test_fcmp_ogt_fadd_select_constant_swapped( ; CHECK-SAME: float [[IN:%.*]]) { -; CHECK-NEXT: [[SEL_NEW:%.*]] = call nsz float @llvm.maxnum.f32(float [[IN]], float 0.000000e+00) +; CHECK-NEXT: [[SEL_NEW:%.*]] = call nsz float @llvm.minnum.f32(float [[IN]], float 0.000000e+00) ; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00 ; CHECK-NEXT: ret float [[ADD_NEW]] ; @@ -87,7 +87,7 @@ define float @test_fcmp_olt_fadd_select_constant(float %in) { define float @test_fcmp_olt_fadd_select_constant_swapped(float %in) { ; CHECK-LABEL: define float @test_fcmp_olt_fadd_select_constant_swapped( ; CHECK-SAME: float [[IN:%.*]]) { -; CHECK-NEXT: [[SEL_NEW:%.*]] = call nsz float @llvm.minnum.f32(float [[IN]], float 0.000000e+00) +; CHECK-NEXT: [[SEL_NEW:%.*]] = call nsz float @llvm.maxnum.f32(float [[IN]], float 0.000000e+00) ; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00 ; CHECK-NEXT: ret float [[ADD_NEW]] ; @@ -155,7 +155,7 @@ define float @test_fcmp_oge_fadd_select_constant(float %in) { define float @test_fcmp_oge_fadd_select_constant_swapped(float %in) { ; 
CHECK-LABEL: define float @test_fcmp_oge_fadd_select_constant_swapped( ; CHECK-SAME: float [[IN:%.*]]) { -; CHECK-NEXT: [[SEL_NEW:%.*]] = call nsz float @llvm.maxnum.f32(float [[IN]], float 0.000000e+00) +; CHECK-NEXT: [[SEL_NEW:%.*]] = call nsz float @llvm.minnum.f32(float [[IN]], float 0.000000e+00) ; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00 ; CHECK-NEXT: ret float [[ADD_NEW]] ; @@ -223,7 +223,7 @@ define float @test_fcmp_ole_fadd_select_constant(float %in) { define float @test_fcmp_ole_fadd_select_constant_swapped(float %in) { ; CHECK-LABEL: define float @test_fcmp_ole_fadd_select_constant_swapped( ; CHECK-SAME: float [[IN:%.*]]) { -; CHECK-NEXT: [[SEL_NEW:%.*]] = call nsz float @llvm.minnum.f32(float [[IN]], float 0.000000e+00) +; CHECK-NEXT: [[SEL_NEW:%.*]] = call nsz float @llvm.maxnum.f32(float [[IN]], float 0.000000e+00) ; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00 ; CHECK-NEXT: ret float [[ADD_NEW]] ; @@ -293,7 +293,7 @@ define float @test_fcmp_ugt_fadd_select_constant_swapped(float %in) { ; CHECK-LABEL: define float @test_fcmp_ugt_fadd_select_constant_swapped( ; CHECK-SAME: float [[IN:%.*]]) { ; CHECK-NEXT: [[CMP1_INV:%.*]] = fcmp ole float [[IN]], 0.000000e+00 -; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float 0.000000e+00, float [[IN]] +; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float [[IN]], float 0.000000e+00 ; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00 ; CHECK-NEXT: ret float [[ADD_NEW]] ; @@ -366,7 +366,7 @@ define float @test_fcmp_uge_fadd_select_constant_swapped(float %in) { ; CHECK-LABEL: define float @test_fcmp_uge_fadd_select_constant_swapped( ; CHECK-SAME: float [[IN:%.*]]) { ; CHECK-NEXT: [[CMP1_INV:%.*]] = fcmp olt float [[IN]], 0.000000e+00 -; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float 0.000000e+00, float [[IN]] +; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float [[IN]], float 0.000000e+00 
; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00 ; CHECK-NEXT: ret float [[ADD_NEW]] ; @@ -439,7 +439,7 @@ define float @test_fcmp_ult_fadd_select_constant_swapped(float %in) { ; CHECK-LABEL: define float @test_fcmp_ult_fadd_select_constant_swapped( ; CHECK-SAME: float [[IN:%.*]]) { ; CHECK-NEXT: [[CMP1_INV:%.*]] = fcmp oge float [[IN]], 0.000000e+00 -; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float 0.000000e+00, float [[IN]] +; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float [[IN]], float 0.000000e+00 ; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00 ; CHECK-NEXT: ret float [[ADD_NEW]] ; @@ -512,7 +512,7 @@ define float @test_fcmp_ule_fadd_select_constant_swapped(float %in) { ; CHECK-LABEL: define float @test_fcmp_ule_fadd_select_constant_swapped( ; CHECK-SAME: float [[IN:%.*]]) { ; CHECK-NEXT: [[CMP1_INV:%.*]] = fcmp ogt float [[IN]], 0.000000e+00 -; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float 0.000000e+00, float [[IN]] +; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1_INV]], float [[IN]], float 0.000000e+00 ; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd nnan nsz float [[SEL_NEW]], 1.000000e+00 ; CHECK-NEXT: ret float [[ADD_NEW]] ;