Skip to content

Commit 961a4aa

Browse files
authored
[GlobalISel] Add constant matcher for APInt (#151357)
Changed m_SpecificICst, m_SpecificICstSplat and m_SpecificICstorSplat to match against APInt as well.
1 parent d6dc433 commit 961a4aa

File tree

4 files changed

+123
-18
lines changed

4 files changed

+123
-18
lines changed

llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -192,23 +192,36 @@ m_GFCstOrSplat(std::optional<FPValueAndVReg> &FPValReg) {
192192

193193
/// Matcher for a specific constant value.
194194
struct SpecificConstantMatch {
195-
int64_t RequestedVal;
196-
SpecificConstantMatch(int64_t RequestedVal) : RequestedVal(RequestedVal) {}
195+
APInt RequestedVal;
196+
SpecificConstantMatch(const APInt RequestedVal)
197+
: RequestedVal(RequestedVal) {}
197198
bool match(const MachineRegisterInfo &MRI, Register Reg) {
198-
int64_t MatchedVal;
199-
return mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal;
199+
APInt MatchedVal;
200+
if (mi_match(Reg, MRI, m_ICst(MatchedVal))) {
201+
if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth())
202+
RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth());
203+
else
204+
MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth());
205+
206+
return APInt::isSameValue(MatchedVal, RequestedVal);
207+
}
208+
return false;
200209
}
201210
};
202211

203212
/// Matches a constant equal to \p RequestedValue.
213+
inline SpecificConstantMatch m_SpecificICst(APInt RequestedValue) {
214+
return SpecificConstantMatch(std::move(RequestedValue));
215+
}
216+
204217
inline SpecificConstantMatch m_SpecificICst(int64_t RequestedValue) {
205-
return SpecificConstantMatch(RequestedValue);
218+
return SpecificConstantMatch(APInt(64, RequestedValue, /* isSigned */ true));
206219
}
207220

208221
/// Matcher for a specific constant splat.
209222
struct SpecificConstantSplatMatch {
210-
int64_t RequestedVal;
211-
SpecificConstantSplatMatch(int64_t RequestedVal)
223+
APInt RequestedVal;
224+
SpecificConstantSplatMatch(const APInt RequestedVal)
212225
: RequestedVal(RequestedVal) {}
213226
bool match(const MachineRegisterInfo &MRI, Register Reg) {
214227
return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
@@ -217,38 +230,56 @@ struct SpecificConstantSplatMatch {
217230
};
218231

219232
/// Matches a constant splat of \p RequestedValue.
233+
inline SpecificConstantSplatMatch m_SpecificICstSplat(APInt RequestedValue) {
234+
return SpecificConstantSplatMatch(std::move(RequestedValue));
235+
}
236+
220237
inline SpecificConstantSplatMatch m_SpecificICstSplat(int64_t RequestedValue) {
221-
return SpecificConstantSplatMatch(RequestedValue);
238+
return SpecificConstantSplatMatch(
239+
APInt(64, RequestedValue, /* isSigned */ true));
222240
}
223241

224242
/// Matcher for a specific constant or constant splat.
225243
struct SpecificConstantOrSplatMatch {
226-
int64_t RequestedVal;
227-
SpecificConstantOrSplatMatch(int64_t RequestedVal)
244+
APInt RequestedVal;
245+
SpecificConstantOrSplatMatch(const APInt RequestedVal)
228246
: RequestedVal(RequestedVal) {}
229247
bool match(const MachineRegisterInfo &MRI, Register Reg) {
230-
int64_t MatchedVal;
231-
if (mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal)
232-
return true;
248+
APInt MatchedVal;
249+
if (mi_match(Reg, MRI, m_ICst(MatchedVal))) {
250+
if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth())
251+
RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth());
252+
else
253+
MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth());
254+
255+
if (APInt::isSameValue(MatchedVal, RequestedVal))
256+
return true;
257+
}
233258
return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
234259
/* AllowUndef */ false);
235260
}
236261
};
237262

238263
/// Matches a \p RequestedValue constant or a constant splat of \p
239264
/// RequestedValue.
265+
inline SpecificConstantOrSplatMatch
266+
m_SpecificICstOrSplat(APInt RequestedValue) {
267+
return SpecificConstantOrSplatMatch(std::move(RequestedValue));
268+
}
269+
240270
inline SpecificConstantOrSplatMatch
241271
m_SpecificICstOrSplat(int64_t RequestedValue) {
242-
return SpecificConstantOrSplatMatch(RequestedValue);
272+
return SpecificConstantOrSplatMatch(
273+
APInt(64, RequestedValue, /* isSigned */ true));
243274
}
244275

245-
///{
246276
/// Convenience matchers for specific integer values.
247-
inline SpecificConstantMatch m_ZeroInt() { return SpecificConstantMatch(0); }
277+
inline SpecificConstantMatch m_ZeroInt() {
278+
return SpecificConstantMatch(APInt::getZero(64));
279+
}
248280
inline SpecificConstantMatch m_AllOnesInt() {
249-
return SpecificConstantMatch(-1);
281+
return SpecificConstantMatch(APInt::getAllOnes(64));
250282
}
251-
///}
252283

253284
/// Matcher for a specific register.
254285
struct SpecificRegisterMatch {

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,12 +459,24 @@ LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg,
459459
const MachineRegisterInfo &MRI,
460460
int64_t SplatValue, bool AllowUndef);
461461

462+
/// Return true if the specified register is defined by G_BUILD_VECTOR or
463+
/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
464+
LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg,
465+
const MachineRegisterInfo &MRI,
466+
APInt SplatValue, bool AllowUndef);
467+
462468
/// Return true if the specified instruction is a G_BUILD_VECTOR or
463469
/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
464470
LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI,
465471
const MachineRegisterInfo &MRI,
466472
int64_t SplatValue, bool AllowUndef);
467473

474+
/// Return true if the specified instruction is a G_BUILD_VECTOR or
475+
/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
476+
LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI,
477+
const MachineRegisterInfo &MRI,
478+
APInt SplatValue, bool AllowUndef);
479+
468480
/// Return true if the specified instruction is a G_BUILD_VECTOR or
469481
/// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef.
470482
LLVM_ABI bool isBuildVectorAllZeros(const MachineInstr &MI,

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,13 +1401,35 @@ bool llvm::isBuildVectorConstantSplat(const Register Reg,
14011401
return false;
14021402
}
14031403

1404+
bool llvm::isBuildVectorConstantSplat(const Register Reg,
1405+
const MachineRegisterInfo &MRI,
1406+
APInt SplatValue, bool AllowUndef) {
1407+
if (auto SplatValAndReg = getAnyConstantSplat(Reg, MRI, AllowUndef)) {
1408+
if (SplatValAndReg->Value.getBitWidth() < SplatValue.getBitWidth())
1409+
return APInt::isSameValue(
1410+
SplatValAndReg->Value.sext(SplatValue.getBitWidth()), SplatValue);
1411+
return APInt::isSameValue(
1412+
SplatValAndReg->Value,
1413+
SplatValue.sext(SplatValAndReg->Value.getBitWidth()));
1414+
}
1415+
1416+
return false;
1417+
}
1418+
14041419
bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
14051420
const MachineRegisterInfo &MRI,
14061421
int64_t SplatValue, bool AllowUndef) {
14071422
return isBuildVectorConstantSplat(MI.getOperand(0).getReg(), MRI, SplatValue,
14081423
AllowUndef);
14091424
}
14101425

1426+
bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
1427+
const MachineRegisterInfo &MRI,
1428+
APInt SplatValue, bool AllowUndef) {
1429+
return isBuildVectorConstantSplat(MI.getOperand(0).getReg(), MRI, SplatValue,
1430+
AllowUndef);
1431+
}
1432+
14111433
std::optional<APInt>
14121434
llvm::getIConstantSplatVal(const Register Reg, const MachineRegisterInfo &MRI) {
14131435
if (auto SplatValAndReg =

llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,17 +634,25 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstant) {
634634
auto FortyTwo = B.buildConstant(LLT::scalar(64), 42);
635635
EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(42)));
636636
EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(123)));
637+
EXPECT_TRUE(
638+
mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 42))));
639+
EXPECT_FALSE(
640+
mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 123))));
637641

638642
// Test that this works inside of a more complex pattern.
639643
LLT s64 = LLT::scalar(64);
640644
auto MIBAdd = B.buildAdd(s64, Copies[0], FortyTwo);
641645
EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(42)));
646+
EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 42))));
642647

643648
// Wrong constant.
644649
EXPECT_FALSE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(123)));
650+
EXPECT_FALSE(
651+
mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 123))));
645652

646653
// No constant on the LHS.
647654
EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(42)));
655+
EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(APInt(64, 42))));
648656
}
649657

650658
TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
@@ -664,6 +672,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
664672
mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstSplat(43)));
665673
EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(42)));
666674

675+
EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI,
676+
m_SpecificICstSplat(APInt(64, 42))));
677+
EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI,
678+
m_SpecificICstSplat(APInt(64, 43))));
679+
EXPECT_FALSE(
680+
mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(APInt(64, 42))));
681+
667682
MachineInstrBuilder NonConstantSplat =
668683
B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
669684

@@ -673,8 +688,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
673688
EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(43)));
674689
EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(42)));
675690

691+
EXPECT_TRUE(
692+
mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42))));
693+
EXPECT_FALSE(
694+
mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 43))));
695+
EXPECT_FALSE(
696+
mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(APInt(64, 42))));
697+
676698
MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
677699
EXPECT_FALSE(mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(42)));
700+
EXPECT_FALSE(
701+
mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42))));
678702
}
679703

680704
TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
@@ -695,6 +719,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
695719
mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstOrSplat(43)));
696720
EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(42)));
697721

722+
EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI,
723+
m_SpecificICstOrSplat(APInt(64, 42))));
724+
EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI,
725+
m_SpecificICstOrSplat(APInt(64, 43))));
726+
EXPECT_TRUE(
727+
mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
728+
698729
MachineInstrBuilder NonConstantSplat =
699730
B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
700731

@@ -704,8 +735,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
704735
EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(43)));
705736
EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(42)));
706737

738+
EXPECT_TRUE(
739+
mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
740+
EXPECT_FALSE(
741+
mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 43))));
742+
EXPECT_FALSE(
743+
mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
744+
707745
MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
708746
EXPECT_TRUE(mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(42)));
747+
EXPECT_TRUE(
748+
mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
709749
}
710750

711751
TEST_F(AArch64GISelMITest, MatchZeroInt) {

0 commit comments

Comments
 (0)