diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h index 50e50a91389e2..27c5d5ca08cd6 100644 --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -822,12 +822,52 @@ template struct bind_ty { } }; +/// Check whether the value has the given Class and matches the nested +/// pattern. Capture it into the provided variable if successful. +template struct bind_and_match_ty { + Class *&VR; + MatchTy Match; + + bind_and_match_ty(Class *&V, const MatchTy &Match) : VR(V), Match(Match) {} + + template bool match(ITy *V) const { + auto *CV = dyn_cast(V); + if (CV && Match.match(V)) { + VR = CV; + return true; + } + return false; + } +}; + /// Match a value, capturing it if we match. inline bind_ty m_Value(Value *&V) { return V; } inline bind_ty m_Value(const Value *&V) { return V; } +/// Match against the nested pattern, and capture the value if we match. +template +inline bind_and_match_ty m_Value(Value *&V, + const MatchTy &Match) { + return {V, Match}; +} + +/// Match against the nested pattern, and capture the value if we match. +template +inline bind_and_match_ty m_Value(const Value *&V, + const MatchTy &Match) { + return {V, Match}; +} + /// Match an instruction, capturing it if we match. inline bind_ty m_Instruction(Instruction *&I) { return I; } + +/// Match against the nested pattern, and capture the instruction if we match. +template +inline bind_and_match_ty +m_Instruction(Instruction *&I, const MatchTy &Match) { + return {I, Match}; +} + /// Match a unary operator, capturing it if we match. inline bind_ty m_UnOp(UnaryOperator *&I) { return I; } /// Match a binary operator, capturing it if we match. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 981c5271fb3f6..959444aef1fc3 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1355,9 +1355,9 @@ Instruction *InstCombinerImpl:: // right-shift of X and a "select". Value *X, *Select; Instruction *LowBitsToSkip, *Extract; - if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_CombineAnd( - m_LShr(m_Value(X), m_Instruction(LowBitsToSkip)), - m_Instruction(Extract))), + if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_Instruction( + Extract, m_LShr(m_Value(X), + m_Instruction(LowBitsToSkip)))), m_Value(Select)))) return nullptr; @@ -1763,13 +1763,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { Constant *C; // (add X, (sext/zext (icmp eq X, C))) // -> (select (icmp eq X, C), (add C, (sext/zext 1)), X) - auto CondMatcher = m_CombineAnd( - m_Value(Cond), - m_SpecificICmp(ICmpInst::ICMP_EQ, m_Deferred(A), m_ImmConstant(C))); + auto CondMatcher = + m_Value(Cond, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Deferred(A), + m_ImmConstant(C))); if (match(&I, - m_c_Add(m_Value(A), - m_CombineAnd(m_Value(Ext), m_ZExtOrSExt(CondMatcher)))) && + m_c_Add(m_Value(A), m_Value(Ext, m_ZExtOrSExt(CondMatcher)))) && Ext->hasOneUse()) { Value *Add = isa(Ext) ? InstCombiner::AddOne(C) : InstCombiner::SubOne(C); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 3beda6bc5ba38..b231c04319106 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -2025,10 +2025,9 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I, if (CountUses && !Op->hasOneUse()) return false; - if (match(Op, m_c_BinOp(FlippedOpcode, - m_CombineAnd(m_Value(X), - m_Not(m_c_BinOp(Opcode, m_A, m_B))), - m_C))) + if (match(Op, + m_c_BinOp(FlippedOpcode, + m_Value(X, m_Not(m_c_BinOp(Opcode, m_A, m_B))), m_C))) return !CountUses || X->hasOneUse(); return false; @@ -2079,10 +2078,10 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I, // result is more undefined than a source: // (~(A & B) | C) & ~(C & (A ^ B)) --> (A ^ B ^ C) | ~(A | C) is invalid. if (Opcode == Instruction::Or && Op0->hasOneUse() && - match(Op1, m_OneUse(m_Not(m_CombineAnd( - m_Value(Y), - m_c_BinOp(Opcode, m_Specific(C), - m_c_Xor(m_Specific(A), m_Specific(B)))))))) { + match(Op1, + m_OneUse(m_Not(m_Value( + Y, m_c_BinOp(Opcode, m_Specific(C), + m_c_Xor(m_Specific(A), m_Specific(B)))))))) { // X = ~(A | B) // Y = (C | (A ^ B) Value *Or = cast(X)->getOperand(0); @@ -2098,12 +2097,11 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I, if (match(Op0, m_OneUse(m_c_BinOp(FlippedOpcode, m_BinOp(FlippedOpcode, m_Value(B), m_Value(C)), - m_CombineAnd(m_Value(X), m_Not(m_Value(A)))))) || - match(Op0, m_OneUse(m_c_BinOp( - FlippedOpcode, - m_c_BinOp(FlippedOpcode, m_Value(C), - m_CombineAnd(m_Value(X), m_Not(m_Value(A)))), - m_Value(B))))) { + m_Value(X, m_Not(m_Value(A)))))) || + match(Op0, m_OneUse(m_c_BinOp(FlippedOpcode, + m_c_BinOp(FlippedOpcode, m_Value(C), + m_Value(X, m_Not(m_Value(A)))), + m_Value(B))))) { // X = ~A // (~A & B & C) | ~(A | B | C) --> ~(A | (B ^ C)) // (~A | B | C) & ~(A & B & C) --> (~A | (B ^ C)) @@ -2434,8 +2432,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { // (-(X & 1)) & Y --> (X & 1) == 0 ? 0 : Y Value *Neg; if (match(&I, - m_c_And(m_CombineAnd(m_Value(Neg), - m_OneUse(m_Neg(m_And(m_Value(), m_One())))), + m_c_And(m_Value(Neg, m_OneUse(m_Neg(m_And(m_Value(), m_One())))), m_Value(Y)))) { Value *Cmp = Builder.CreateIsNull(Neg); return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Y); @@ -3728,9 +3725,8 @@ static Value *foldOrUnsignedUMulOverflowICmp(BinaryOperator &I, const APInt *C1, *C2; if (match(&I, m_c_Or(m_ExtractValue<1>( - m_CombineAnd(m_Intrinsic( - m_Value(X), m_APInt(C1)), - m_Value(WOV))), + m_Value(WOV, m_Intrinsic( + m_Value(X), m_APInt(C1)))), m_OneUse(m_SpecificCmp(ICmpInst::ICMP_UGT, m_ExtractValue<0>(m_Deferred(WOV)), m_APInt(C2))))) && @@ -3988,12 +3984,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { // ~(B & ?) | (A ^ B) --> ~((B & ?) & A) Instruction *And; if ((Op0->hasOneUse() || Op1->hasOneUse()) && - match(Op0, m_Not(m_CombineAnd(m_Instruction(And), - m_c_And(m_Specific(A), m_Value()))))) + match(Op0, + m_Not(m_Instruction(And, m_c_And(m_Specific(A), m_Value()))))) return BinaryOperator::CreateNot(Builder.CreateAnd(And, B)); if ((Op0->hasOneUse() || Op1->hasOneUse()) && - match(Op0, m_Not(m_CombineAnd(m_Instruction(And), - m_c_And(m_Specific(B), m_Value()))))) + match(Op0, + m_Not(m_Instruction(And, m_c_And(m_Specific(B), m_Value()))))) return BinaryOperator::CreateNot(Builder.CreateAnd(And, A)); // (~A | C) | (A ^ B) --> ~(A & B) | C @@ -4125,16 +4121,13 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { // treating any non-zero result as overflow. In that case, we overflow if both // umul.with.overflow operands are != 0, as in that case the result can only // be 0, iff the multiplication overflows. - if (match(&I, - m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_Value(UMulWithOv)), - m_Value(Ov)), - m_CombineAnd( - m_SpecificICmp(ICmpInst::ICMP_NE, - m_CombineAnd(m_ExtractValue<0>( - m_Deferred(UMulWithOv)), - m_Value(Mul)), - m_ZeroInt()), - m_Value(MulIsNotZero)))) && + if (match(&I, m_c_Or(m_Value(Ov, m_ExtractValue<1>(m_Value(UMulWithOv))), + m_Value(MulIsNotZero, + m_SpecificICmp( + ICmpInst::ICMP_NE, + m_Value(Mul, m_ExtractValue<0>( + m_Deferred(UMulWithOv))), + m_ZeroInt())))) && (Ov->hasOneUse() || (MulIsNotZero->hasOneUse() && Mul->hasOneUse()))) { Value *A, *B; if (match(UMulWithOv, m_Intrinsic( @@ -4151,9 +4144,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { const WithOverflowInst *WO; const Value *WOV; const APInt *C1, *C2; - if (match(&I, m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_CombineAnd( - m_WithOverflowInst(WO), m_Value(WOV))), - m_Value(Ov)), + if (match(&I, m_c_Or(m_Value(Ov, m_ExtractValue<1>( + m_Value(WOV, m_WithOverflowInst(WO)))), m_OneUse(m_ICmp(Pred, m_ExtractValue<0>(m_Deferred(WOV)), m_APInt(C2))))) && (WO->getBinaryOp() == Instruction::Add || @@ -4501,8 +4493,7 @@ static Instruction *visitMaskedMerge(BinaryOperator &I, Value *M; if (!match(&I, m_c_Xor(m_Value(B), m_OneUse(m_c_And( - m_CombineAnd(m_c_Xor(m_Deferred(B), m_Value(X)), - m_Value(D)), + m_Value(D, m_c_Xor(m_Deferred(B), m_Value(X))), m_Value(M)))))) return nullptr; @@ -5206,8 +5197,7 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { // (X ^ C) ^ Y --> (X ^ Y) ^ C // Just like we do in other places, we completely avoid the fold // for constantexprs, at least to avoid endless combine loop. - if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_CombineAnd(m_Value(X), - m_Unless(m_ConstantExpr())), + if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_Value(X, m_Unless(m_ConstantExpr())), m_ImmConstant(C1))), m_Value(Y)))) return BinaryOperator::CreateXor(Builder.CreateXor(X, Y), C1);