diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 252781e54ab06..00b0f05f18f03 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1966,30 +1966,6 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, return new ICmpInst(NewPred, X, SubOne(cast(Cmp.getOperand(1)))); } - // If we are testing the intersection of 2 select-of-nonzero-constants with no - // common bits set, it's the same as checking if exactly one select condition - // is set: - // ((A ? TC : FC) & (B ? TC : FC)) == 0 --> xor A, B - // ((A ? TC : FC) & (B ? TC : FC)) != 0 --> not(xor A, B) - // TODO: Generalize for non-constant values. - // TODO: Handle signed/unsigned predicates. - // TODO: Handle other bitwise logic connectors. - // TODO: Extend to handle a non-zero compare constant. - if (C.isZero() && (Pred == CmpInst::ICMP_EQ || And->hasOneUse())) { - assert(Cmp.isEquality() && "Not expecting non-equality predicates"); - Value *A, *B; - const APInt *TC, *FC; - if (match(X, m_Select(m_Value(A), m_APInt(TC), m_APInt(FC))) && - match(Y, - m_Select(m_Value(B), m_SpecificInt(*TC), m_SpecificInt(*FC))) && - !TC->isZero() && !FC->isZero() && !TC->intersects(*FC)) { - Value *R = Builder.CreateXor(A, B); - if (Pred == CmpInst::ICMP_NE) - R = Builder.CreateNot(R); - return replaceInstUsesWith(Cmp, R); - } - } - // ((zext i1 X) & Y) == 0 --> !((trunc Y) & X) // ((zext i1 X) & Y) != 0 --> ((trunc Y) & X) // ((zext i1 X) & Y) == 1 --> ((trunc Y) & X) @@ -3110,6 +3086,44 @@ static Value *createLogicFromTable(const std::bitset<4> &Table, Value *Op0, return nullptr; } +Instruction *InstCombinerImpl::foldICmpBinOpWithConstantViaTruthTable( + ICmpInst &Cmp, BinaryOperator *BO, const APInt &C) { + Value *A, *B; + Constant *C1, *C2, *C3, *C4; + if (!(match(BO->getOperand(0), + m_Select(m_Value(A), m_Constant(C1), m_Constant(C2)))) || + !match(BO->getOperand(1), + m_Select(m_Value(B), m_Constant(C3), m_Constant(C4))) || + Cmp.getType() != A->getType()) + return nullptr; + + std::bitset<4> Table; + auto ComputeTable = [&](bool First, bool Second) -> std::optional { + Constant *L = First ? C1 : C2; + Constant *R = Second ? C3 : C4; + if (auto *Res = ConstantFoldBinaryOpOperands(BO->getOpcode(), L, R, DL)) { + auto *Val = Res->getType()->isVectorTy() ? Res->getSplatValue() : Res; + if (auto *CI = dyn_cast_or_null(Val)) + return ICmpInst::compare(CI->getValue(), C, Cmp.getPredicate()); + } + return std::nullopt; + }; + + for (unsigned I = 0; I < 4; ++I) { + bool First = (I >> 1) & 1; + bool Second = I & 1; + if (auto Res = ComputeTable(First, Second)) + Table[I] = *Res; + else + return nullptr; + } + + // Synthesize optimal logic. + if (auto *Cond = createLogicFromTable(Table, A, B, Builder, BO->hasOneUse())) + return replaceInstUsesWith(Cmp, Cond); + return nullptr; +} + /// Fold icmp (add X, Y), C. Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, @@ -4014,7 +4028,13 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, } // TODO: These folds could be refactored to be part of the above calls. - return foldICmpBinOpEqualityWithConstant(Cmp, BO, C); + if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, C)) + return I; + + // Fall back to handling `icmp pred (select A ? C1 : C2) binop (select B ? C3 + // : C4), C5` pattern, by computing a truth table of the four constant + // variants. + return foldICmpBinOpWithConstantViaTruthTable(Cmp, BO, C); } static Instruction * diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 324738ef8c88e..8b657b3f8555c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -736,6 +736,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final Instruction *foldICmpShlConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, const APInt &C2); + Instruction *foldICmpBinOpWithConstantViaTruthTable(ICmpInst &Cmp, + BinaryOperator *BO, + const APInt &C); Instruction *foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, BinaryOperator *BO, const APInt &C); diff --git a/llvm/test/Transforms/InstCombine/icmp-binop.ll b/llvm/test/Transforms/InstCombine/icmp-binop.ll index 356489716fff9..161b290151e5a 100644 --- a/llvm/test/Transforms/InstCombine/icmp-binop.ll +++ b/llvm/test/Transforms/InstCombine/icmp-binop.ll @@ -359,3 +359,131 @@ define i1 @test_icmp_sgt_and_negpow2_invalid_c(i32 %add) { %cmp = icmp sgt i32 %and, 48 ret i1 %cmp } + +define i1 @icmp_eq_or_of_selects_with_constant(i1 %a, i1 %b) { +; CHECK-LABEL: @icmp_eq_or_of_selects_with_constant( +; CHECK-NEXT: [[CMP:%.*]] = and i1 [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: ret i1 [[CMP]] +; + %s1 = select i1 %a, i64 65536, i64 0 + %s2 = select i1 %b, i64 256, i64 0 + %or = or i64 %s1, %s2 + %cmp = icmp eq i64 %or, 65792 + ret i1 %cmp +} + +define i1 @icmp_slt_and_of_selects_with_constant(i1 %a, i1 %b) { +; CHECK-LABEL: @icmp_slt_and_of_selects_with_constant( +; CHECK-NEXT: [[TMP1:%.*]] = or i1 [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[TMP1]], true +; CHECK-NEXT: ret i1 [[CMP]] +; + %s1 = select i1 %a, i8 1, i8 254 + %s2 = select i1 %b, i8 1, i8 253 + %and = and i8 %s1, %s2 + %cmp = icmp slt i8 %and, 254 + ret i1 %cmp +} + +define i1 @icmp_sge_add_of_selects_with_constant(i1 %a, i1 %b) { +; CHECK-LABEL: @icmp_sge_add_of_selects_with_constant( +; CHECK-NEXT: ret i1 true +; + %s1 = select i1 %a, i8 248, i8 7 + %s2 = select i1 %b, i8 16, i8 0 + %add = add i8 %s1, %s2 + %cmp = icmp sge i8 %add, 247 + ret i1 %cmp +} + +define i1 @icmp_eq_or_of_selects_with_constant_multiuse_foldable(i1 %a, i1 %b) { +; CHECK-LABEL: @icmp_eq_or_of_selects_with_constant_multiuse_foldable( +; CHECK-NEXT: [[S1:%.*]] = select i1 [[A:%.*]], i64 65536, i64 0 +; CHECK-NEXT: [[S2:%.*]] = select i1 [[B:%.*]], i64 256, i64 0 +; CHECK-NEXT: [[OR:%.*]] = or disjoint i64 [[S1]], [[S2]] +; CHECK-NEXT: call void @use64(i64 [[OR]]) +; CHECK-NEXT: [[CMP:%.*]] = and i1 [[A]], [[B]] +; CHECK-NEXT: ret i1 [[CMP]] +; + %s1 = select i1 %a, i64 65536, i64 0 + %s2 = select i1 %b, i64 256, i64 0 + %or = or i64 %s1, %s2 + call void @use64(i64 %or) + %cmp = icmp eq i64 %or, 65792 + ret i1 %cmp +} + +define <2 x i1> @icmp_eq_or_of_selects_with_constant_vectorized(<2 x i1> %a, <2 x i1> %b) { +; CHECK-LABEL: @icmp_eq_or_of_selects_with_constant_vectorized( +; CHECK-NEXT: [[CMP:%.*]] = and <2 x i1> [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: ret <2 x i1> [[CMP]] +; + %s1 = select <2 x i1> %a, <2 x i64> , <2 x i64> zeroinitializer + %s2 = select <2 x i1> %b, <2 x i64> , <2 x i64> zeroinitializer + %or = or <2 x i64> %s1, %s2 + %cmp = icmp eq <2 x i64> %or, + ret <2 x i1> %cmp +} + +define <2 x i1> @icmp_eq_or_of_selects_with_scalar_cond_constant_vectorized(i1 %a, i1 %b) { +; CHECK-LABEL: @icmp_eq_or_of_selects_with_scalar_cond_constant_vectorized( +; CHECK-NEXT: [[S1:%.*]] = select i1 [[A:%.*]], <2 x i64> splat (i64 65536), <2 x i64> zeroinitializer +; CHECK-NEXT: [[S2:%.*]] = select i1 [[B:%.*]], <2 x i64> splat (i64 256), <2 x i64> zeroinitializer +; CHECK-NEXT: [[OR:%.*]] = or disjoint <2 x i64> [[S1]], [[S2]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i64> [[OR]], splat (i64 65792) +; CHECK-NEXT: ret <2 x i1> [[CMP]] +; + %s1 = select i1 %a, <2 x i64> , <2 x i64> zeroinitializer + %s2 = select i1 %b, <2 x i64> , <2 x i64> zeroinitializer + %or = or <2 x i64> %s1, %s2 + %cmp = icmp eq <2 x i64> %or, + ret <2 x i1> %cmp +} + +; Negative tests. +define i1 @icmp_eq_or_of_selects_with_constant_and_arg(i1 %a, i1 %b, i64 %arg) { +; CHECK-LABEL: @icmp_eq_or_of_selects_with_constant_and_arg( +; CHECK-NEXT: [[S1:%.*]] = select i1 [[A:%.*]], i64 65536, i64 [[ARG:%.*]] +; CHECK-NEXT: [[S2:%.*]] = select i1 [[B:%.*]], i64 256, i64 0 +; CHECK-NEXT: [[OR:%.*]] = or i64 [[S1]], [[S2]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[OR]], 65792 +; CHECK-NEXT: ret i1 [[CMP]] +; + %s1 = select i1 %a, i64 65536, i64 %arg + %s2 = select i1 %b, i64 256, i64 0 + %or = or i64 %s1, %s2 + %cmp = icmp eq i64 %or, 65792 + ret i1 %cmp +} + +define i1 @icmp_eq_or_of_selects_with_constant_multiuse(i1 %a, i1 %b) { +; CHECK-LABEL: @icmp_eq_or_of_selects_with_constant_multiuse( +; CHECK-NEXT: [[S1:%.*]] = select i1 [[A:%.*]], i64 0, i64 65536 +; CHECK-NEXT: [[S2:%.*]] = select i1 [[B:%.*]], i64 0, i64 256 +; CHECK-NEXT: [[OR:%.*]] = or disjoint i64 [[S1]], [[S2]] +; CHECK-NEXT: call void @use64(i64 [[OR]]) +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[OR]], 65792 +; CHECK-NEXT: ret i1 [[CMP]] +; + %s1 = select i1 %a, i64 0, i64 65536 + %s2 = select i1 %b, i64 0, i64 256 + %or = or i64 %s1, %s2 + call void @use64(i64 %or) + %cmp = icmp eq i64 %or, 65792 + ret i1 %cmp +} + +define <2 x i1> @icmp_eq_or_of_selects_with_constant_vectorized_nonsplat(<2 x i1> %a, <2 x i1> %b) { +; CHECK-LABEL: @icmp_eq_or_of_selects_with_constant_vectorized_nonsplat( +; CHECK-NEXT: [[S1:%.*]] = select <2 x i1> [[A:%.*]], <2 x i64> splat (i64 65536), <2 x i64> zeroinitializer +; CHECK-NEXT: [[S2:%.*]] = select <2 x i1> [[B:%.*]], <2 x i64> , <2 x i64> zeroinitializer +; CHECK-NEXT: [[OR:%.*]] = or disjoint <2 x i64> [[S1]], [[S2]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i64> [[OR]], +; CHECK-NEXT: ret <2 x i1> [[CMP]] +; + %s1 = select <2 x i1> %a, <2 x i64> , <2 x i64> zeroinitializer + %s2 = select <2 x i1> %b, <2 x i64> , <2 x i64> zeroinitializer + %or = or <2 x i64> %s1, %s2 + %cmp = icmp eq <2 x i64> %or, + ret <2 x i1> %cmp +} diff --git a/llvm/test/Transforms/InstCombine/icmp-select.ll b/llvm/test/Transforms/InstCombine/icmp-select.ll index 1aae91302dab1..a038731abbc48 100644 --- a/llvm/test/Transforms/InstCombine/icmp-select.ll +++ b/llvm/test/Transforms/InstCombine/icmp-select.ll @@ -328,10 +328,7 @@ define i1 @select_constants_and_icmp_eq0_common_bit(i1 %x, i1 %y) { define i1 @select_constants_and_icmp_eq0_no_common_op1(i1 %x, i1 %y) { ; CHECK-LABEL: @select_constants_and_icmp_eq0_no_common_op1( -; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 16, i8 3 -; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 24, i8 3 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], 0 +; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i1 [[CMP]] ; %s1 = select i1 %x, i8 16, i8 3 @@ -345,10 +342,7 @@ define i1 @select_constants_and_icmp_eq0_no_common_op1(i1 %x, i1 %y) { define i1 @select_constants_and_icmp_eq0_no_common_op2(i1 %x, i1 %y) { ; CHECK-LABEL: @select_constants_and_icmp_eq0_no_common_op2( -; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 16, i8 3 -; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 16, i8 7 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], 0 +; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i1 [[CMP]] ; %s1 = select i1 %x, i8 16, i8 3 @@ -387,14 +381,9 @@ define i1 @select_constants_and_icmp_eq0_zero_fval(i1 %x, i1 %y) { ret i1 %cmp } -; TODO: x & y - define i1 @select_constants_and_icmp_eq_tval(i1 %x, i1 %y) { ; CHECK-LABEL: @select_constants_and_icmp_eq_tval( -; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 6, i8 1 -; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 6, i8 1 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], 6 +; CHECK-NEXT: [[CMP:%.*]] = and i1 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i1 [[CMP]] ; %s1 = select i1 %x, i8 6, i8 1 @@ -404,14 +393,10 @@ define i1 @select_constants_and_icmp_eq_tval(i1 %x, i1 %y) { ret i1 %cmp } -; TODO: ~(x | y) - define i1 @select_constants_and_icmp_eq_fval(i1 %x, i1 %y) { ; CHECK-LABEL: @select_constants_and_icmp_eq_fval( -; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 12, i8 3 -; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 12, i8 3 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], 3 +; CHECK-NEXT: [[TMP1:%.*]] = or i1 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[TMP1]], true ; CHECK-NEXT: ret i1 [[CMP]] ; %s1 = select i1 %x, i8 12, i8 3 @@ -512,10 +497,8 @@ define i1 @select_constants_and_icmp_ne0_common_bit(i1 %x, i1 %y) { define i1 @select_constants_and_icmp_ne0_no_common_op1(i1 %x, i1 %y) { ; CHECK-LABEL: @select_constants_and_icmp_ne0_no_common_op1( -; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 16, i8 3 -; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 24, i8 3 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8 [[AND]], 0 +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[TMP1]], true ; CHECK-NEXT: ret i1 [[CMP]] ; %s1 = select i1 %x, i8 16, i8 3 @@ -529,10 +512,8 @@ define i1 @select_constants_and_icmp_ne0_no_common_op1(i1 %x, i1 %y) { define i1 @select_constants_and_icmp_ne0_no_common_op2(i1 %x, i1 %y) { ; CHECK-LABEL: @select_constants_and_icmp_ne0_no_common_op2( -; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 16, i8 3 -; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 16, i8 7 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8 [[AND]], 0 +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[TMP1]], true ; CHECK-NEXT: ret i1 [[CMP]] ; %s1 = select i1 %x, i8 16, i8 3 @@ -571,14 +552,10 @@ define i1 @select_constants_and_icmp_ne0_zero_fval(i1 %x, i1 %y) { ret i1 %cmp } -; TODO: ~(x & y) - define i1 @select_constants_and_icmp_ne_tval(i1 %x, i1 %y) { ; CHECK-LABEL: @select_constants_and_icmp_ne_tval( -; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 6, i8 1 -; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 6, i8 1 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8 [[AND]], 6 +; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[TMP1]], true ; CHECK-NEXT: ret i1 [[CMP]] ; %s1 = select i1 %x, i8 6, i8 1 @@ -588,14 +565,9 @@ define i1 @select_constants_and_icmp_ne_tval(i1 %x, i1 %y) { ret i1 %cmp } -; TODO: (x | y) - define i1 @select_constants_and_icmp_ne_fval(i1 %x, i1 %y) { ; CHECK-LABEL: @select_constants_and_icmp_ne_fval( -; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 12, i8 3 -; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 12, i8 3 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8 [[AND]], 3 +; CHECK-NEXT: [[CMP:%.*]] = or i1 [[X:%.*]], [[Y:%.*]] ; CHECK-NEXT: ret i1 [[CMP]] ; %s1 = select i1 %x, i8 12, i8 3