[InstCombine] Decode the select bit-test condition once in foldSelectBitTest

foldSelectICmpAnd and foldSelectICmpAndBinOp each decoded the select
condition on their own. The new foldSelectBitTest helper does that decoding
once and hands the result to both folds as (V, AndMask, CreateAnd). It
accepts three condition shapes:

  * icmp eq/ne (and X, Pow2), 0
  * any other icmp that decomposeBitTestICmp can turn into an equality bit
    test, with the mask narrowed by known bits and required to be a single
    bit afterwards
  * trunc to i1 (mask 1; the 'and' is only materialized when the trunc is
    not nuw)

For an ICMP_NE test the helper swaps TrueVal/FalseVal, so both folds only
see the "bit is clear selects TrueVal" orientation. That replaces the old
std::swap(SelTC, SelFC) in foldSelectICmpAnd and the predicate-dependent
NeedXor computation in foldSelectICmpAndBinOp.

foldSelectICmpAndBinOp previously rejected a decomposed mask that was not a
power of two without consulting known bits. It now receives the narrowed
mask, which is why select_icmp_bittest_range folds to lshr/and/or.

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index c3163f70b847e..4bba2f406b4c1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -119,63 +119,15 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
 ///   (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC
 /// With some variations depending if FC is larger than TC, or the shift
 /// isn't needed, or the bit widths don't match.
-static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal,
-                                InstCombiner::BuilderTy &Builder,
-                                const SimplifyQuery &SQ) {
+static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal, Value *TrueVal,
+                                Value *FalseVal, Value *V, const APInt &AndMask,
+                                bool CreateAnd,
+                                InstCombiner::BuilderTy &Builder) {
   const APInt *SelTC, *SelFC;
-  if (!match(Sel.getTrueValue(), m_APInt(SelTC)) ||
-      !match(Sel.getFalseValue(), m_APInt(SelFC)))
+  if (!match(TrueVal, m_APInt(SelTC)) || !match(FalseVal, m_APInt(SelFC)))
     return nullptr;
 
-  // If this is a vector select, we need a vector compare.
   Type *SelType = Sel.getType();
-  if (SelType->isVectorTy() != CondVal->getType()->isVectorTy())
-    return nullptr;
-
-  Value *V;
-  APInt AndMask;
-  bool CreateAnd = false;
-  CmpPredicate Pred;
-  Value *CmpLHS, *CmpRHS;
-
-  if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
-    if (ICmpInst::isEquality(Pred)) {
-      if (!match(CmpRHS, m_Zero()))
-        return nullptr;
-
-      V = CmpLHS;
-      const APInt *AndRHS;
-      if (!match(V, m_And(m_Value(), m_Power2(AndRHS))))
-        return nullptr;
-
-      AndMask = *AndRHS;
-    } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
-      assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
-      AndMask = Res->Mask;
-      V = Res->X;
-      KnownBits Known =
-          computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
-      AndMask &= Known.getMaxValue();
-      if (!AndMask.isPowerOf2())
-        return nullptr;
-
-      Pred = Res->Pred;
-      CreateAnd = true;
-    } else {
-      return nullptr;
-    }
-
-  } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
-    V = Trunc->getOperand(0);
-    AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
-    Pred = ICmpInst::ICMP_NE;
-    CreateAnd = !Trunc->hasNoUnsignedWrap();
-  } else {
-    return nullptr;
-  }
-  if (Pred == ICmpInst::ICMP_NE)
-    std::swap(SelTC, SelFC);
-
   // In general, when both constants are non-zero, we would need an offset to
   // replace the select. This would require more instructions than we started
   // with. But there's one special-case that we handle here because it can
@@ -762,60 +714,26 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
 /// 2. The select operands are reversed
 /// 3. The magnitude of C2 and C1 are flipped
 static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
-                                     Value *FalseVal,
+                                     Value *FalseVal, Value *V,
+                                     const APInt &AndMask, bool CreateAnd,
                                      InstCombiner::BuilderTy &Builder) {
-  // Only handle integer compares. Also, if this is a vector select, we need a
-  // vector compare.
-  if (!TrueVal->getType()->isIntOrIntVectorTy() ||
-      TrueVal->getType()->isVectorTy() != CondVal->getType()->isVectorTy())
-    return nullptr;
-
-  unsigned C1Log;
-  bool NeedAnd = false;
-  CmpPredicate Pred;
-  Value *CmpLHS, *CmpRHS;
-
-  if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
-    if (ICmpInst::isEquality(Pred)) {
-      if (!match(CmpRHS, m_Zero()))
-        return nullptr;
-
-      const APInt *C1;
-      if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
-        return nullptr;
-
-      C1Log = C1->logBase2();
-    } else {
-      auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
-      if (!Res || !Res->Mask.isPowerOf2())
-        return nullptr;
-
-      CmpLHS = Res->X;
-      Pred = Res->Pred;
-      C1Log = Res->Mask.logBase2();
-      NeedAnd = true;
-    }
-  } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
-    CmpLHS = Trunc->getOperand(0);
-    C1Log = 0;
-    Pred = ICmpInst::ICMP_NE;
-    NeedAnd = !Trunc->hasNoUnsignedWrap();
-  } else {
+  // Only handle integer compares.
+  if (!TrueVal->getType()->isIntOrIntVectorTy())
     return nullptr;
-  }
 
-  Value *Y, *V = CmpLHS;
+  unsigned C1Log = AndMask.logBase2();
+  Value *Y;
   BinaryOperator *BinOp;
   const APInt *C2;
   bool NeedXor;
   if (match(FalseVal, m_BinOp(m_Specific(TrueVal), m_Power2(C2)))) {
     Y = TrueVal;
     BinOp = cast<BinaryOperator>(FalseVal);
-    NeedXor = Pred == ICmpInst::ICMP_NE;
+    NeedXor = false;
   } else if (match(TrueVal, m_BinOp(m_Specific(FalseVal), m_Power2(C2)))) {
     Y = FalseVal;
     BinOp = cast<BinaryOperator>(TrueVal);
-    NeedXor = Pred == ICmpInst::ICMP_EQ;
+    NeedXor = true;
   } else {
     return nullptr;
   }
@@ -834,14 +752,13 @@ static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
                        V->getType()->getScalarSizeInBits();
 
   // Make sure we don't create more instructions than we save.
-  if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) >
+  if ((NeedShift + NeedXor + NeedZExtTrunc + CreateAnd) >
       (CondVal->hasOneUse() + BinOp->hasOneUse()))
     return nullptr;
 
-  if (NeedAnd) {
+  if (CreateAnd) {
     // Insert the AND instruction on the input to the truncate.
-    APInt C1 = APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), C1Log);
-    V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), C1));
+    V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask));
   }
 
   if (C2Log > C1Log) {
@@ -3797,6 +3714,70 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
   return nullptr;
 }
 
+static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
+                                Value *FalseVal,
+                                InstCombiner::BuilderTy &Builder,
+                                const SimplifyQuery &SQ) {
+  // If this is a vector select, we need a vector compare.
+  Type *SelType = Sel.getType();
+  if (SelType->isVectorTy() != CondVal->getType()->isVectorTy())
+    return nullptr;
+
+  Value *V;
+  APInt AndMask;
+  bool CreateAnd = false;
+  CmpPredicate Pred;
+  Value *CmpLHS, *CmpRHS;
+
+  if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
+    if (ICmpInst::isEquality(Pred)) {
+      if (!match(CmpRHS, m_Zero()))
+        return nullptr;
+
+      V = CmpLHS;
+      const APInt *AndRHS;
+      if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
+        return nullptr;
+
+      AndMask = *AndRHS;
+    } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
+      assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
+      AndMask = Res->Mask;
+      V = Res->X;
+      KnownBits Known =
+          computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
+      AndMask &= Known.getMaxValue();
+      if (!AndMask.isPowerOf2())
+        return nullptr;
+
+      Pred = Res->Pred;
+      CreateAnd = true;
+    } else {
+      return nullptr;
+    }
+  } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
+    V = Trunc->getOperand(0);
+    AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
+    Pred = ICmpInst::ICMP_NE;
+    CreateAnd = !Trunc->hasNoUnsignedWrap();
+  } else {
+    return nullptr;
+  }
+
+  if (Pred == ICmpInst::ICMP_NE)
+    std::swap(TrueVal, FalseVal);
+
+  if (Value *X = foldSelectICmpAnd(Sel, CondVal, TrueVal, FalseVal, V, AndMask,
+                                   CreateAnd, Builder))
+    return X;
+
+  if (Value *X = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, V, AndMask,
+                                        CreateAnd, Builder))
+    return X;
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -3969,10 +3950,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
       return Result;
 
-  if (Value *V = foldSelectICmpAnd(SI, CondVal, Builder, SQ))
-    return replaceInstUsesWith(SI, V);
-
-  if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder))
+  if (Value *V = foldSelectBitTest(SI, CondVal, TrueVal, FalseVal, Builder, SQ))
     return replaceInstUsesWith(SI, V);
 
   if (Instruction *Add = foldAddSubSelect(SI, Builder))
diff --git a/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll b/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
index a424247b676e4..771fad66e961e 100644
--- a/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
+++ b/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
@@ -1832,9 +1832,9 @@ define i8 @neg_select_trunc_or_2(i8 %x, i8 %y) {
 
 define i8 @select_icmp_bittest_range(i8 range(i8 0, 64) %a, i8 %y) {
 ; CHECK-LABEL: @select_icmp_bittest_range(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp samesign ult i8 [[A:%.*]], 32
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[CMP]], i8 [[Y]], i8 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i8 [[A:%.*]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 2
+; CHECK-NEXT:    [[RES:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
 ; CHECK-NEXT:    ret i8 [[RES]]
 ;
   %cmp = icmp ult i8 %a, 32