diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h index 78a92c86b91e4..72483fbea5805 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h @@ -33,6 +33,12 @@ template return P.match(MRI, &MI); } +template +[[nodiscard]] bool mi_match(const MachineInstr &MI, + const MachineRegisterInfo &MRI, Pattern &&P) { + return P.match(MRI, &MI); +} + // TODO: Extend for N use. template struct OneUse_match { SubPatternT SubPat; @@ -337,6 +343,19 @@ template <> struct bind_helper { } }; +template <> struct bind_helper { + static bool bind(const MachineRegisterInfo &MRI, const MachineInstr *&MI, + Register Reg) { + MI = MRI.getVRegDef(Reg); + return MI; + } + static bool bind(const MachineRegisterInfo &MRI, const MachineInstr *&MI, + const MachineInstr *Inst) { + MI = Inst; + return MI; + } +}; + template <> struct bind_helper { static bool bind(const MachineRegisterInfo &MRI, LLT &Ty, Register Reg) { Ty = MRI.getType(Reg); @@ -368,6 +387,9 @@ template struct bind_ty { inline bind_ty m_Reg(Register &R) { return R; } inline bind_ty m_MInstr(MachineInstr *&MI) { return MI; } +inline bind_ty m_MInstr(const MachineInstr *&MI) { + return MI; +} inline bind_ty m_Type(LLT &Ty) { return Ty; } inline bind_ty m_Pred(CmpInst::Predicate &P) { return P; } inline operand_type_match m_Pred() { return operand_type_match(); } @@ -418,7 +440,7 @@ inline bind_ty m_GFCst(const ConstantFP *&C) { return C; } // General helper for all the binary generic MI such as G_ADD/G_SUB etc template + bool Commutable = false, unsigned Flags = MachineInstr::NoFlags> struct BinaryOp_match { LHS_P L; RHS_P R; @@ -426,18 +448,20 @@ struct BinaryOp_match { BinaryOp_match(const LHS_P &LHS, const RHS_P &RHS) : L(LHS), R(RHS) {} template bool match(const MachineRegisterInfo &MRI, OpTy &&Op) { - MachineInstr *TmpMI; + const MachineInstr *TmpMI; if (mi_match(Op, MRI, m_MInstr(TmpMI))) { if (TmpMI->getOpcode() == Opcode && TmpMI->getNumOperands() == 3) { - return (L.match(MRI, TmpMI->getOperand(1).getReg()) && - R.match(MRI, TmpMI->getOperand(2).getReg())) || - // NOTE: When trying the alternative operand ordering - // with a commutative operation, it is imperative to always run - // the LHS sub-pattern (i.e. `L`) before the RHS sub-pattern - // (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as - // expected. - (Commutable && (L.match(MRI, TmpMI->getOperand(2).getReg()) && - R.match(MRI, TmpMI->getOperand(1).getReg()))); + if ((!L.match(MRI, TmpMI->getOperand(1).getReg()) || + !R.match(MRI, TmpMI->getOperand(2).getReg())) && + // NOTE: When trying the alternative operand ordering + // with a commutative operation, it is imperative to always run + // the LHS sub-pattern (i.e. `L`) before the RHS sub-pattern + // (i.e. `R`). Otherwise, m_DeferredReg/Type will not work as + // expected. + (!Commutable || !L.match(MRI, TmpMI->getOperand(2).getReg()) || + !R.match(MRI, TmpMI->getOperand(1).getReg()))) + return false; + return (TmpMI->getFlags() & Flags) == Flags; } } return false; @@ -464,7 +488,7 @@ struct BinaryOpc_match { // NOTE: When trying the alternative operand ordering // with a commutative operation, it is imperative to always run // the LHS sub-pattern (i.e. `L`) before the RHS sub-pattern - // (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as + // (i.e. `R`). Otherwise, m_DeferredReg/Type will not work as // expected. (Commutable && (L.match(MRI, TmpMI->getOperand(2).getReg()) && R.match(MRI, TmpMI->getOperand(1).getReg()))); @@ -559,6 +583,19 @@ inline BinaryOp_match m_GOr(const LHS &L, return BinaryOp_match(L, R); } +template +inline BinaryOp_match +m_GDisjointOr(const LHS &L, const RHS &R) { + return BinaryOp_match(L, R); +} + +template +inline auto m_GAddLike(const LHS &L, const RHS &R) { + return m_any_of(m_GAdd(L, R), m_GDisjointOr(L, R)); +} + template inline BinaryOp_match m_GShl(const LHS &L, const RHS &R) { @@ -717,7 +754,7 @@ struct CompareOp_match { // NOTE: When trying the alternative operand ordering // with a commutative operation, it is imperative to always run // the LHS sub-pattern (i.e. `L`) before the RHS sub-pattern - // (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as expected. + // (i.e. `R`). Otherwise, m_DeferredReg/Type will not work as expected. if (Commutable && L.match(MRI, RHS) && R.match(MRI, LHS) && P.match(MRI, CmpInst::getSwappedPredicate(TmpPred))) return true; diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index ee109277dfbba..6c61e3a613f6f 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -1030,9 +1030,7 @@ def add_and_or_is_add : PatFrags<(ops node:$lhs, node:$rhs), return CurDAG->isADDLike(SDValue(N,0)); }]> { let GISelPredicateCode = [{ - return MI.getOpcode() == TargetOpcode::G_ADD || - (MI.getOpcode() == TargetOpcode::G_OR && - MI.getFlag(MachineInstr::MIFlag::Disjoint)); + return mi_match(MI, MRI, m_GAddLike(m_Reg(), m_Reg())); }]; } diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp index 40cd055c1c3f8..25eb67e981588 100644 --- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp @@ -950,6 +950,29 @@ TEST_F(AArch64GISelMITest, DeferredMatching) { m_GAdd(m_Reg(X), m_GSub(m_Reg(), m_DeferredReg(X))))); } +TEST_F(AArch64GISelMITest, AddLike) { + setUp(); + if (!TM) + GTEST_SKIP(); + auto s64 = LLT::scalar(64); + + auto Cst1 = B.buildConstant(s64, 42); + auto Cst2 = B.buildConstant(s64, 314); + + auto Or1 = B.buildOr(s64, Cst1, Cst2, MachineInstr::Disjoint); + auto Or2 = B.buildOr(s64, Cst1, Cst2); + auto Add = B.buildAdd(s64, Cst1, Cst2); + auto Sub = B.buildSub(s64, Cst1, Cst2); + + EXPECT_TRUE(mi_match(Or1.getReg(0), *MRI, m_GDisjointOr(m_Reg(), m_Reg()))); + EXPECT_FALSE(mi_match(Or2.getReg(0), *MRI, m_GDisjointOr(m_Reg(), m_Reg()))); + + EXPECT_TRUE(mi_match(Add.getReg(0), *MRI, m_GAddLike(m_Reg(), m_Reg()))); + EXPECT_FALSE(mi_match(Sub.getReg(0), *MRI, m_GAddLike(m_Reg(), m_Reg()))); + EXPECT_TRUE(mi_match(Or1.getReg(0), *MRI, m_GAddLike(m_Reg(), m_Reg()))); + EXPECT_FALSE(mi_match(Or2.getReg(0), *MRI, m_GAddLike(m_Reg(), m_Reg()))); +} + } // namespace int main(int argc, char **argv) {