diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index 47417f53b6e40..78a92c86b91e4 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -372,6 +372,36 @@ inline bind_ty<LLT> m_Type(LLT &Ty) { return Ty; }
 inline bind_ty<CmpInst::Predicate> m_Pred(CmpInst::Predicate &P) { return P; }
 inline operand_type_match m_Pred() { return operand_type_match(); }
 
+template <typename BindTy> struct deferred_helper {
+  static bool match(const MachineRegisterInfo &MRI, BindTy &VR, BindTy &V) {
+    return VR == V;
+  }
+};
+
+template <> struct deferred_helper<LLT> {
+  static bool match(const MachineRegisterInfo &MRI, LLT VT, Register R) {
+    return VT == MRI.getType(R);
+  }
+};
+
+template <typename Class> struct deferred_ty {
+  Class &VR;
+
+  deferred_ty(Class &V) : VR(V) {}
+
+  template <typename ITy> bool match(const MachineRegisterInfo &MRI, ITy &&V) {
+    return deferred_helper<Class>::match(MRI, VR, V);
+  }
+};
+
+/// Similar to m_SpecificReg/Type, but the specific value to match originated
+/// from an earlier sub-pattern in the same mi_match expression. For example,
+/// we cannot match `(add X, X)` with `m_GAdd(m_Reg(X), m_SpecificReg(X))`
+/// because `X` is not initialized at the time it's passed to `m_SpecificReg`.
+/// Instead, we can use `m_GAdd(m_Reg(X), m_DeferredReg(X))`.
+inline deferred_ty<Register> m_DeferredReg(Register &R) { return R; }
+inline deferred_ty<LLT> m_DeferredType(LLT &Ty) { return Ty; }
+
 struct ImplicitDefMatch {
   bool match(const MachineRegisterInfo &MRI, Register Reg) {
     MachineInstr *TmpMI;
@@ -401,8 +431,13 @@ struct BinaryOp_match {
     if (TmpMI->getOpcode() == Opcode && TmpMI->getNumOperands() == 3) {
       return (L.match(MRI, TmpMI->getOperand(1).getReg()) &&
               R.match(MRI, TmpMI->getOperand(2).getReg())) ||
-             (Commutable && (R.match(MRI, TmpMI->getOperand(1).getReg()) &&
-                             L.match(MRI, TmpMI->getOperand(2).getReg())));
+             // NOTE: When trying the alternative operand ordering
+             // with a commutative operation, it is imperative to always run
+             // the LHS sub-pattern (i.e. `L`) before the RHS sub-pattern
+             // (i.e. `R`). Otherwise, m_DeferredReg/Type will not work as
+             // expected.
+             (Commutable && (L.match(MRI, TmpMI->getOperand(2).getReg()) &&
+                             R.match(MRI, TmpMI->getOperand(1).getReg())));
     }
   }
   return false;
@@ -426,8 +461,13 @@ struct BinaryOpc_match {
         TmpMI->getNumOperands() == 3) {
       return (L.match(MRI, TmpMI->getOperand(1).getReg()) &&
              R.match(MRI, TmpMI->getOperand(2).getReg())) ||
-            (Commutable && (R.match(MRI, TmpMI->getOperand(1).getReg()) &&
-                            L.match(MRI, TmpMI->getOperand(2).getReg())));
+            // NOTE: When trying the alternative operand ordering
+            // with a commutative operation, it is imperative to always run
+            // the LHS sub-pattern (i.e. `L`) before the RHS sub-pattern
+            // (i.e. `R`). Otherwise, m_DeferredReg/Type will not work as
+            // expected.
+            (Commutable && (L.match(MRI, TmpMI->getOperand(2).getReg()) &&
+                            R.match(MRI, TmpMI->getOperand(1).getReg())));
     }
   }
   return false;
@@ -674,6 +714,10 @@ struct CompareOp_match {
     Register RHS = TmpMI->getOperand(3).getReg();
     if (L.match(MRI, LHS) && R.match(MRI, RHS))
       return true;
+    // NOTE: When trying the alternative operand ordering
+    // with a commutative operation, it is imperative to always run
+    // the LHS sub-pattern (i.e. `L`) before the RHS sub-pattern
+    // (i.e. `R`). Otherwise, m_DeferredReg/Type will not work as expected.
     if (Commutable && L.match(MRI, RHS) && R.match(MRI, LHS) &&
         P.match(MRI, CmpInst::getSwappedPredicate(TmpPred)))
       return true;
diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
index fc76d4055722e..40cd055c1c3f8 100644
--- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -920,6 +920,36 @@ TEST_F(AArch64GISelMITest, MatchSpecificReg) {
   EXPECT_TRUE(mi_match(Add.getReg(0), *MRI, m_GAdd(m_SpecificReg(Reg), m_Reg())));
 }
 
+TEST_F(AArch64GISelMITest, DeferredMatching) {
+  setUp();
+  if (!TM)
+    GTEST_SKIP();
+  auto s64 = LLT::scalar(64);
+  auto s32 = LLT::scalar(32);
+
+  auto Cst1 = B.buildConstant(s64, 42);
+  auto Cst2 = B.buildConstant(s64, 314);
+  auto Add = B.buildAdd(s64, Cst1, Cst2);
+  auto Sub = B.buildSub(s64, Add, Cst1);
+
+  auto TruncAdd = B.buildTrunc(s32, Add);
+  auto TruncSub = B.buildTrunc(s32, Sub);
+  auto NarrowAdd = B.buildAdd(s32, TruncAdd, TruncSub);
+
+  Register X;
+  EXPECT_TRUE(mi_match(Sub.getReg(0), *MRI,
+                       m_GSub(m_GAdd(m_Reg(X), m_Reg()), m_DeferredReg(X))));
+  LLT Ty;
+  EXPECT_TRUE(
+      mi_match(NarrowAdd.getReg(0), *MRI,
+               m_GAdd(m_GTrunc(m_Type(Ty)), m_GTrunc(m_DeferredType(Ty)))));
+
+  // Test commutative.
+  auto Add2 = B.buildAdd(s64, Sub, Cst1);
+  EXPECT_TRUE(mi_match(Add2.getReg(0), *MRI,
+                       m_GAdd(m_Reg(X), m_GSub(m_Reg(), m_DeferredReg(X)))));
+}
+
 } // namespace
 
 int main(int argc, char **argv) {
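To illustrate the intended use, here is a minimal sketch of a hypothetical combine built on the matchers added above. It folds `(sub (add X, Y), X)` into `Y`; the helper name `matchSubOfAdd` and the `Replacement` out-parameter are invented for illustration, while `mi_match`, `m_GSub`, `m_GAdd`, `m_Reg`, and `m_DeferredReg` come from MIPatternMatch.h as shown in the diff:

#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"

using namespace llvm;
using namespace MIPatternMatch;

// Hypothetical combine helper (illustration only): match
// (sub (add X, Y), X) and report Y as the replacement value.
// m_SpecificReg(X) cannot express the "same X" constraint here, because it
// reads X when the pattern object is constructed, before m_Reg(X) has bound
// anything; m_DeferredReg(X) defers the read to match time.
static bool matchSubOfAdd(const MachineInstr &MI,
                          const MachineRegisterInfo &MRI,
                          Register &Replacement) {
  Register X;
  return mi_match(MI.getOperand(0).getReg(), MRI,
                  m_GSub(m_GAdd(m_Reg(X), m_Reg(Replacement)),
                         m_DeferredReg(X)));
}

This works because sub-patterns run left to right: the `m_GAdd(m_Reg(X), ...)` sub-pattern executes before `m_DeferredReg(X)`, so `X` is bound before it is read. The NOTEs in the diff extend that guarantee to the commutative retry, which previously ran `R` before `L`.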