From 66454feb05d9f498f1f8fb49c35a99a42808cfb7 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Wed, 30 Jul 2025 15:36:42 +0000 Subject: [PATCH 1/3] [GlobalISel] Add constant matcher for APInt --- .../llvm/CodeGen/GlobalISel/MIPatternMatch.h | 67 +++++++++++++------ llvm/include/llvm/CodeGen/GlobalISel/Utils.h | 12 ++++ llvm/lib/CodeGen/GlobalISel/Utils.cpp | 22 ++++++ .../CodeGen/GlobalISel/PatternMatchTest.cpp | 40 +++++++++++ 4 files changed, 122 insertions(+), 19 deletions(-) diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h index c0d3a12cbcb41..e8d9bc03f6428 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h @@ -192,24 +192,35 @@ m_GFCstOrSplat(std::optional &FPValReg) { /// Matcher for a specific constant value. struct SpecificConstantMatch { - int64_t RequestedVal; - SpecificConstantMatch(int64_t RequestedVal) : RequestedVal(RequestedVal) {} + APInt RequestedVal; + SpecificConstantMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {} bool match(const MachineRegisterInfo &MRI, Register Reg) { - int64_t MatchedVal; - return mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal; + APInt MatchedVal; + if (mi_match(Reg, MRI, m_ICst(MatchedVal))) { + if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth()) + RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth()); + else + MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth()); + + return APInt::isSameValue(MatchedVal, RequestedVal); + } + return false; } }; /// Matches a constant equal to \p RequestedValue. +inline SpecificConstantMatch m_SpecificICst(APInt RequestedValue) { + return SpecificConstantMatch(std::move(RequestedValue)); +} + inline SpecificConstantMatch m_SpecificICst(int64_t RequestedValue) { - return SpecificConstantMatch(RequestedValue); + return SpecificConstantMatch(APInt(64, RequestedValue, /* isSigned */ true)); } /// Matcher for a specific constant splat. struct SpecificConstantSplatMatch { - int64_t RequestedVal; - SpecificConstantSplatMatch(int64_t RequestedVal) - : RequestedVal(RequestedVal) {} + APInt RequestedVal; + SpecificConstantSplatMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {} bool match(const MachineRegisterInfo &MRI, Register Reg) { return isBuildVectorConstantSplat(Reg, MRI, RequestedVal, /* AllowUndef */ false); @@ -217,19 +228,31 @@ struct SpecificConstantSplatMatch { }; /// Matches a constant splat of \p RequestedValue. +inline SpecificConstantSplatMatch m_SpecificICstSplat(APInt RequestedValue) { + return SpecificConstantSplatMatch(std::move(RequestedValue)); +} + inline SpecificConstantSplatMatch m_SpecificICstSplat(int64_t RequestedValue) { - return SpecificConstantSplatMatch(RequestedValue); + return SpecificConstantSplatMatch( + APInt(64, RequestedValue, /* isSigned */ true)); } /// Matcher for a specific constant or constant splat. struct SpecificConstantOrSplatMatch { - int64_t RequestedVal; - SpecificConstantOrSplatMatch(int64_t RequestedVal) + APInt RequestedVal; + SpecificConstantOrSplatMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {} bool match(const MachineRegisterInfo &MRI, Register Reg) { - int64_t MatchedVal; - if (mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal) - return true; + APInt MatchedVal; + if (mi_match(Reg, MRI, m_ICst(MatchedVal))) { + if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth()) + RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth()); + else + MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth()); + + if (APInt::isSameValue(MatchedVal, RequestedVal)) + return true; + } return isBuildVectorConstantSplat(Reg, MRI, RequestedVal, /* AllowUndef */ false); } @@ -237,18 +260,24 @@ struct SpecificConstantOrSplatMatch { /// Matches a \p RequestedValue constant or a constant splat of \p /// RequestedValue. +inline SpecificConstantOrSplatMatch +m_SpecificICstOrSplat(APInt RequestedValue) { + return SpecificConstantOrSplatMatch(std::move(RequestedValue)); +} + inline SpecificConstantOrSplatMatch m_SpecificICstOrSplat(int64_t RequestedValue) { - return SpecificConstantOrSplatMatch(RequestedValue); + return SpecificConstantOrSplatMatch( + APInt(64, RequestedValue, /* isSigned */ true)); } -///{ /// Convenience matchers for specific integer values. -inline SpecificConstantMatch m_ZeroInt() { return SpecificConstantMatch(0); } +inline SpecificConstantMatch m_ZeroInt() { + return SpecificConstantMatch(APInt(64, 0)); +} inline SpecificConstantMatch m_AllOnesInt() { - return SpecificConstantMatch(-1); + return SpecificConstantMatch(APInt(64, -1, /* isSigned */ true)); } -///} /// Matcher for a specific register. struct SpecificRegisterMatch { diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h index 66c960fe12c68..5c27605c26883 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h @@ -459,12 +459,24 @@ LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg, const MachineRegisterInfo &MRI, int64_t SplatValue, bool AllowUndef); +/// Return true if the specified register is defined by G_BUILD_VECTOR or +/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef. +LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg, + const MachineRegisterInfo &MRI, + APInt SplatValue, bool AllowUndef); + /// Return true if the specified instruction is a G_BUILD_VECTOR or /// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef. LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI, const MachineRegisterInfo &MRI, int64_t SplatValue, bool AllowUndef); +/// Return true if the specified instruction is a G_BUILD_VECTOR or +/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef. +LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI, + const MachineRegisterInfo &MRI, + APInt SplatValue, bool AllowUndef); + /// Return true if the specified instruction is a G_BUILD_VECTOR or /// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef. LLVM_ABI bool isBuildVectorAllZeros(const MachineInstr &MI, diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp index f48bfc06c14be..8955dd0370539 100644 --- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp +++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp @@ -1401,6 +1401,21 @@ bool llvm::isBuildVectorConstantSplat(const Register Reg, return false; } +bool llvm::isBuildVectorConstantSplat(const Register Reg, + const MachineRegisterInfo &MRI, + APInt SplatValue, bool AllowUndef) { + if (auto SplatValAndReg = getAnyConstantSplat(Reg, MRI, AllowUndef)) { + if (SplatValAndReg->Value.getBitWidth() < SplatValue.getBitWidth()) + return APInt::isSameValue( + SplatValAndReg->Value.sext(SplatValue.getBitWidth()), SplatValue); + return APInt::isSameValue( + SplatValAndReg->Value, + SplatValue.sext(SplatValAndReg->Value.getBitWidth())); + } + + return false; +} + bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI, const MachineRegisterInfo &MRI, int64_t SplatValue, bool AllowUndef) { @@ -1408,6 +1423,13 @@ bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI, AllowUndef); } +bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI, + const MachineRegisterInfo &MRI, + APInt SplatValue, bool AllowUndef) { + return isBuildVectorConstantSplat(MI.getOperand(0).getReg(), MRI, SplatValue, + AllowUndef); +} + std::optional llvm::getIConstantSplatVal(const Register Reg, const MachineRegisterInfo &MRI) { if (auto SplatValAndReg = diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp index 25eb67e981588..1e0653b61e8f8 100644 --- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp @@ -634,17 +634,25 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstant) { auto FortyTwo = B.buildConstant(LLT::scalar(64), 42); EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(42))); EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(123))); + EXPECT_TRUE( + mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 42)))); + EXPECT_FALSE( + mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 123)))); // Test that this works inside of a more complex pattern. LLT s64 = LLT::scalar(64); auto MIBAdd = B.buildAdd(s64, Copies[0], FortyTwo); EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(42))); + EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 42)))); // Wrong constant. EXPECT_FALSE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(123))); + EXPECT_FALSE( + mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 123)))); // No constant on the LHS. EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(42))); + EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(APInt(64, 42)))); } TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) { @@ -664,6 +672,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) { mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstSplat(43))); EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(42))); + EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI, + m_SpecificICstSplat(APInt(64, 42)))); + EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI, + m_SpecificICstSplat(APInt(64, 43)))); + EXPECT_FALSE( + mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(APInt(64, 42)))); + MachineInstrBuilder NonConstantSplat = B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]}); @@ -673,8 +688,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) { EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(43))); EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(42))); + EXPECT_TRUE( + mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42)))); + EXPECT_FALSE( + mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 43)))); + EXPECT_FALSE( + mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(APInt(64, 42)))); + MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo); EXPECT_FALSE(mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(42))); + EXPECT_FALSE( + mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42)))); } TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) { @@ -695,6 +719,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) { mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstOrSplat(43))); EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(42))); + EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI, + m_SpecificICstOrSplat(APInt(64, 42)))); + EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI, + m_SpecificICstOrSplat(APInt(64, 43)))); + EXPECT_TRUE( + mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(APInt(64, 42)))); + MachineInstrBuilder NonConstantSplat = B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]}); @@ -704,8 +735,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) { EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(43))); EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(42))); + EXPECT_TRUE( + mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42)))); + EXPECT_FALSE( + mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 43)))); + EXPECT_FALSE( + mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(APInt(64, 42)))); + MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo); EXPECT_TRUE(mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(42))); + EXPECT_TRUE( + mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42)))); } TEST_F(AArch64GISelMITest, MatchZeroInt) { From 57005c01a39dce83232e264dc45de5099563da93 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Fri, 1 Aug 2025 14:05:00 +0000 Subject: [PATCH 2/3] Address Comments --- llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h index e8d9bc03f6428..66f0a61ec316a 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h @@ -193,7 +193,7 @@ m_GFCstOrSplat(std::optional &FPValReg) { /// Matcher for a specific constant value. struct SpecificConstantMatch { APInt RequestedVal; - SpecificConstantMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {} + SpecificConstantMatch(const APInt RequestedVal) : RequestedVal(RequestedVal) {} bool match(const MachineRegisterInfo &MRI, Register Reg) { APInt MatchedVal; if (mi_match(Reg, MRI, m_ICst(MatchedVal))) { @@ -220,7 +220,7 @@ inline SpecificConstantMatch m_SpecificICst(int64_t RequestedValue) { /// Matcher for a specific constant splat. struct SpecificConstantSplatMatch { APInt RequestedVal; - SpecificConstantSplatMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {} + SpecificConstantSplatMatch(const APInt RequestedVal) : RequestedVal(RequestedVal) {} bool match(const MachineRegisterInfo &MRI, Register Reg) { return isBuildVectorConstantSplat(Reg, MRI, RequestedVal, /* AllowUndef */ false); @@ -240,7 +240,7 @@ inline SpecificConstantSplatMatch m_SpecificICstSplat(int64_t RequestedValue) { /// Matcher for a specific constant or constant splat. struct SpecificConstantOrSplatMatch { APInt RequestedVal; - SpecificConstantOrSplatMatch(APInt RequestedVal) + SpecificConstantOrSplatMatch(const APInt RequestedVal) : RequestedVal(RequestedVal) {} bool match(const MachineRegisterInfo &MRI, Register Reg) { APInt MatchedVal; @@ -273,10 +273,10 @@ m_SpecificICstOrSplat(int64_t RequestedValue) { /// Convenience matchers for specific integer values. inline SpecificConstantMatch m_ZeroInt() { - return SpecificConstantMatch(APInt(64, 0)); + return SpecificConstantMatch(APInt::getZero(64)); } inline SpecificConstantMatch m_AllOnesInt() { - return SpecificConstantMatch(APInt(64, -1, /* isSigned */ true)); + return SpecificConstantMatch(APInt::getAllOnes(64)); } /// Matcher for a specific register. From 35455904b5673028c428d98553d4af5f65184a06 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Fri, 1 Aug 2025 14:08:43 +0000 Subject: [PATCH 3/3] formatting --- llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h index 66f0a61ec316a..827cdbdb23c51 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h @@ -193,7 +193,8 @@ m_GFCstOrSplat(std::optional &FPValReg) { /// Matcher for a specific constant value. struct SpecificConstantMatch { APInt RequestedVal; - SpecificConstantMatch(const APInt RequestedVal) : RequestedVal(RequestedVal) {} + SpecificConstantMatch(const APInt RequestedVal) + : RequestedVal(RequestedVal) {} bool match(const MachineRegisterInfo &MRI, Register Reg) { APInt MatchedVal; if (mi_match(Reg, MRI, m_ICst(MatchedVal))) { @@ -220,7 +221,8 @@ inline SpecificConstantMatch m_SpecificICst(int64_t RequestedValue) { /// Matcher for a specific constant splat. struct SpecificConstantSplatMatch { APInt RequestedVal; - SpecificConstantSplatMatch(const APInt RequestedVal) : RequestedVal(RequestedVal) {} + SpecificConstantSplatMatch(const APInt RequestedVal) + : RequestedVal(RequestedVal) {} bool match(const MachineRegisterInfo &MRI, Register Reg) { return isBuildVectorConstantSplat(Reg, MRI, RequestedVal, /* AllowUndef */ false);