Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 48 additions & 19 deletions llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,63 +192,92 @@ m_GFCstOrSplat(std::optional<FPValueAndVReg> &FPValReg) {

/// Matcher for a specific constant value.
struct SpecificConstantMatch {
int64_t RequestedVal;
SpecificConstantMatch(int64_t RequestedVal) : RequestedVal(RequestedVal) {}
APInt RequestedVal;
SpecificConstantMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const ref

bool match(const MachineRegisterInfo &MRI, Register Reg) {
int64_t MatchedVal;
return mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal;
APInt MatchedVal;
if (mi_match(Reg, MRI, m_ICst(MatchedVal))) {
if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth())
RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth());
else
MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth());

return APInt::isSameValue(MatchedVal, RequestedVal);
}
return false;
}
};

/// Matches a constant equal to \p RequestedValue.
inline SpecificConstantMatch m_SpecificICst(APInt RequestedValue) {
return SpecificConstantMatch(std::move(RequestedValue));
}

inline SpecificConstantMatch m_SpecificICst(int64_t RequestedValue) {
return SpecificConstantMatch(RequestedValue);
return SpecificConstantMatch(APInt(64, RequestedValue, /* isSigned */ true));
}

/// Matcher for a specific constant splat.
struct SpecificConstantSplatMatch {
int64_t RequestedVal;
SpecificConstantSplatMatch(int64_t RequestedVal)
: RequestedVal(RequestedVal) {}
APInt RequestedVal;
SpecificConstantSplatMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {}
bool match(const MachineRegisterInfo &MRI, Register Reg) {
return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
/* AllowUndef */ false);
}
};

/// Matches a constant splat of \p RequestedValue.
inline SpecificConstantSplatMatch m_SpecificICstSplat(APInt RequestedValue) {
return SpecificConstantSplatMatch(std::move(RequestedValue));
}

inline SpecificConstantSplatMatch m_SpecificICstSplat(int64_t RequestedValue) {
return SpecificConstantSplatMatch(RequestedValue);
return SpecificConstantSplatMatch(
APInt(64, RequestedValue, /* isSigned */ true));
}

/// Matcher for a specific constant or constant splat.
struct SpecificConstantOrSplatMatch {
int64_t RequestedVal;
SpecificConstantOrSplatMatch(int64_t RequestedVal)
APInt RequestedVal;
SpecificConstantOrSplatMatch(APInt RequestedVal)
: RequestedVal(RequestedVal) {}
bool match(const MachineRegisterInfo &MRI, Register Reg) {
int64_t MatchedVal;
if (mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal)
return true;
APInt MatchedVal;
if (mi_match(Reg, MRI, m_ICst(MatchedVal))) {
if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth())
RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth());
else
MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth());

if (APInt::isSameValue(MatchedVal, RequestedVal))
return true;
}
return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
/* AllowUndef */ false);
}
};

/// Matches a \p RequestedValue constant or a constant splat of \p
/// RequestedValue.
inline SpecificConstantOrSplatMatch
m_SpecificICstOrSplat(APInt RequestedValue) {
return SpecificConstantOrSplatMatch(std::move(RequestedValue));
}

inline SpecificConstantOrSplatMatch
m_SpecificICstOrSplat(int64_t RequestedValue) {
return SpecificConstantOrSplatMatch(RequestedValue);
return SpecificConstantOrSplatMatch(
APInt(64, RequestedValue, /* isSigned */ true));
}

///{
/// Convenience matchers for specific integer values.
inline SpecificConstantMatch m_ZeroInt() { return SpecificConstantMatch(0); }
inline SpecificConstantMatch m_ZeroInt() {
return SpecificConstantMatch(APInt(64, 0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to use the isNullValue in APInt (really for all of these, should follow along with the IR pattern matcher structure)

}
inline SpecificConstantMatch m_AllOnesInt() {
return SpecificConstantMatch(-1);
return SpecificConstantMatch(APInt(64, -1, /* isSigned */ true));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use getAllOnes(64) too, although I am unsure about treating it like a isSigned. It might not be necessary here because we would not see a larger value?

}
///}

/// Matcher for a specific register.
struct SpecificRegisterMatch {
Expand Down
12 changes: 12 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,12 +459,24 @@ LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg,
const MachineRegisterInfo &MRI,
int64_t SplatValue, bool AllowUndef);

/// Return true if the specified register is defined by G_BUILD_VECTOR or
/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg,
const MachineRegisterInfo &MRI,
APInt SplatValue, bool AllowUndef);

/// Return true if the specified instruction is a G_BUILD_VECTOR or
/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
int64_t SplatValue, bool AllowUndef);

/// Return true if the specified instruction is a G_BUILD_VECTOR or
/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
APInt SplatValue, bool AllowUndef);

/// Return true if the specified instruction is a G_BUILD_VECTOR or
/// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef.
LLVM_ABI bool isBuildVectorAllZeros(const MachineInstr &MI,
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1401,13 +1401,35 @@ bool llvm::isBuildVectorConstantSplat(const Register Reg,
return false;
}

bool llvm::isBuildVectorConstantSplat(const Register Reg,
const MachineRegisterInfo &MRI,
APInt SplatValue, bool AllowUndef) {
if (auto SplatValAndReg = getAnyConstantSplat(Reg, MRI, AllowUndef)) {
if (SplatValAndReg->Value.getBitWidth() < SplatValue.getBitWidth())
return APInt::isSameValue(
SplatValAndReg->Value.sext(SplatValue.getBitWidth()), SplatValue);
return APInt::isSameValue(
SplatValAndReg->Value,
SplatValue.sext(SplatValAndReg->Value.getBitWidth()));
}

return false;
}

bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
int64_t SplatValue, bool AllowUndef) {
return isBuildVectorConstantSplat(MI.getOperand(0).getReg(), MRI, SplatValue,
AllowUndef);
}

bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
APInt SplatValue, bool AllowUndef) {
return isBuildVectorConstantSplat(MI.getOperand(0).getReg(), MRI, SplatValue,
AllowUndef);
}

std::optional<APInt>
llvm::getIConstantSplatVal(const Register Reg, const MachineRegisterInfo &MRI) {
if (auto SplatValAndReg =
Expand Down
40 changes: 40 additions & 0 deletions llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,17 +634,25 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstant) {
auto FortyTwo = B.buildConstant(LLT::scalar(64), 42);
EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(42)));
EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(123)));
EXPECT_TRUE(
mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 42))));
EXPECT_FALSE(
mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 123))));

// Test that this works inside of a more complex pattern.
LLT s64 = LLT::scalar(64);
auto MIBAdd = B.buildAdd(s64, Copies[0], FortyTwo);
EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(42)));
EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 42))));

// Wrong constant.
EXPECT_FALSE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(123)));
EXPECT_FALSE(
mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 123))));

// No constant on the LHS.
EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(42)));
EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(APInt(64, 42))));
}

TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
Expand All @@ -664,6 +672,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstSplat(43)));
EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(42)));

EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI,
m_SpecificICstSplat(APInt(64, 42))));
EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI,
m_SpecificICstSplat(APInt(64, 43))));
EXPECT_FALSE(
mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(APInt(64, 42))));

MachineInstrBuilder NonConstantSplat =
B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});

Expand All @@ -673,8 +688,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(43)));
EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(42)));

EXPECT_TRUE(
mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42))));
EXPECT_FALSE(
mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 43))));
EXPECT_FALSE(
mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(APInt(64, 42))));

MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
EXPECT_FALSE(mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(42)));
EXPECT_FALSE(
mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42))));
}

TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
Expand All @@ -695,6 +719,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstOrSplat(43)));
EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(42)));

EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI,
m_SpecificICstOrSplat(APInt(64, 42))));
EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI,
m_SpecificICstOrSplat(APInt(64, 43))));
EXPECT_TRUE(
mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));

MachineInstrBuilder NonConstantSplat =
B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});

Expand All @@ -704,8 +735,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(43)));
EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(42)));

EXPECT_TRUE(
mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
EXPECT_FALSE(
mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 43))));
EXPECT_FALSE(
mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));

MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
EXPECT_TRUE(mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(42)));
EXPECT_TRUE(
mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
}

TEST_F(AArch64GISelMITest, MatchZeroInt) {
Expand Down