Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 48 additions & 4 deletions llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,36 @@ inline bind_ty<LLT> m_Type(LLT &Ty) { return Ty; }
inline bind_ty<CmpInst::Predicate> m_Pred(CmpInst::Predicate &P) { return P; }
inline operand_type_match m_Pred() { return operand_type_match(); }

template <typename BindTy> struct deferred_helper {
static bool match(const MachineRegisterInfo &MRI, BindTy &VR, BindTy &V) {
return VR == V;
}
};

template <> struct deferred_helper<LLT> {
static bool match(const MachineRegisterInfo &MRI, LLT VT, Register R) {
return VT == MRI.getType(R);
}
};

template <typename Class> struct deferred_ty {
Class &VR;

deferred_ty(Class &V) : VR(V) {}

template <typename ITy> bool match(const MachineRegisterInfo &MRI, ITy &&V) {
return deferred_helper<Class>::match(MRI, VR, V);
}
};

/// Similar to m_SpecificReg/Type, but the specific value to match originated
/// from an earlier sub-pattern in the same mi_match expression. For example,
/// we cannot match `(add X, X)` with `m_GAdd(m_Reg(X), m_SpecificReg(X))`
/// because `X` is not initialized at the time it's passed to `m_SpecificReg`.
/// Instead, we can use `m_GAdd(m_Reg(x), m_DeferredReg(X))`.
inline deferred_ty<Register> m_DeferredReg(Register &R) { return R; }
inline deferred_ty<LLT> m_DeferredType(LLT &Ty) { return Ty; }

struct ImplicitDefMatch {
bool match(const MachineRegisterInfo &MRI, Register Reg) {
MachineInstr *TmpMI;
Expand Down Expand Up @@ -401,8 +431,13 @@ struct BinaryOp_match {
if (TmpMI->getOpcode() == Opcode && TmpMI->getNumOperands() == 3) {
return (L.match(MRI, TmpMI->getOperand(1).getReg()) &&
R.match(MRI, TmpMI->getOperand(2).getReg())) ||
(Commutable && (R.match(MRI, TmpMI->getOperand(1).getReg()) &&
L.match(MRI, TmpMI->getOperand(2).getReg())));
// NOTE: When trying the alternative operand ordering
// with a commutative operation, it is imperative to always run
// the LHS sub-pattern (i.e. `L`) before the RHS sub-pattern
// (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as
// expected.
(Commutable && (L.match(MRI, TmpMI->getOperand(2).getReg()) &&
R.match(MRI, TmpMI->getOperand(1).getReg())));
}
}
return false;
Expand All @@ -426,8 +461,13 @@ struct BinaryOpc_match {
TmpMI->getNumOperands() == 3) {
return (L.match(MRI, TmpMI->getOperand(1).getReg()) &&
R.match(MRI, TmpMI->getOperand(2).getReg())) ||
(Commutable && (R.match(MRI, TmpMI->getOperand(1).getReg()) &&
L.match(MRI, TmpMI->getOperand(2).getReg())));
// NOTE: When trying the alternative operand ordering
// with a commutative operation, it is imperative to always run
// the LHS sub-pattern (i.e. `L`) before the RHS sub-pattern
// (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as
// expected.
(Commutable && (L.match(MRI, TmpMI->getOperand(2).getReg()) &&
R.match(MRI, TmpMI->getOperand(1).getReg())));
}
}
return false;
Expand Down Expand Up @@ -674,6 +714,10 @@ struct CompareOp_match {
Register RHS = TmpMI->getOperand(3).getReg();
if (L.match(MRI, LHS) && R.match(MRI, RHS))
return true;
// NOTE: When trying the alternative operand ordering
// with a commutative operation, it is imperative to always run
// the LHS sub-pattern (i.e. `L`) before the RHS sub-pattern
// (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as expected.
if (Commutable && L.match(MRI, RHS) && R.match(MRI, LHS) &&
P.match(MRI, CmpInst::getSwappedPredicate(TmpPred)))
return true;
Expand Down
30 changes: 30 additions & 0 deletions llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,36 @@ TEST_F(AArch64GISelMITest, MatchSpecificReg) {
EXPECT_TRUE(mi_match(Add.getReg(0), *MRI, m_GAdd(m_SpecificReg(Reg), m_Reg())));
}

TEST_F(AArch64GISelMITest, DeferredMatching) {
setUp();
if (!TM)
GTEST_SKIP();
auto s64 = LLT::scalar(64);
auto s32 = LLT::scalar(32);

auto Cst1 = B.buildConstant(s64, 42);
auto Cst2 = B.buildConstant(s64, 314);
auto Add = B.buildAdd(s64, Cst1, Cst2);
auto Sub = B.buildSub(s64, Add, Cst1);

auto TruncAdd = B.buildTrunc(s32, Add);
auto TruncSub = B.buildTrunc(s32, Sub);
auto NarrowAdd = B.buildAdd(s32, TruncAdd, TruncSub);

Register X;
EXPECT_TRUE(mi_match(Sub.getReg(0), *MRI,
m_GSub(m_GAdd(m_Reg(X), m_Reg()), m_DeferredReg(X))));
LLT Ty;
EXPECT_TRUE(
mi_match(NarrowAdd.getReg(0), *MRI,
m_GAdd(m_GTrunc(m_Type(Ty)), m_GTrunc(m_DeferredType(Ty)))));

// Test commutative.
auto Add2 = B.buildAdd(s64, Sub, Cst1);
EXPECT_TRUE(mi_match(Add2.getReg(0), *MRI,
m_GAdd(m_Reg(X), m_GSub(m_Reg(), m_DeferredReg(X)))));
}

} // namespace

int main(int argc, char **argv) {
Expand Down
Loading