diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h index c0d3a12cbcb41..827cdbdb23c51 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h @@ -192,23 +192,36 @@ m_GFCstOrSplat(std::optional &FPValReg) { /// Matcher for a specific constant value. struct SpecificConstantMatch { - int64_t RequestedVal; - SpecificConstantMatch(int64_t RequestedVal) : RequestedVal(RequestedVal) {} + APInt RequestedVal; + SpecificConstantMatch(const APInt RequestedVal) + : RequestedVal(RequestedVal) {} bool match(const MachineRegisterInfo &MRI, Register Reg) { - int64_t MatchedVal; - return mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal; + APInt MatchedVal; + if (mi_match(Reg, MRI, m_ICst(MatchedVal))) { + if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth()) + RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth()); + else + MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth()); + + return APInt::isSameValue(MatchedVal, RequestedVal); + } + return false; } }; /// Matches a constant equal to \p RequestedValue. +inline SpecificConstantMatch m_SpecificICst(APInt RequestedValue) { + return SpecificConstantMatch(std::move(RequestedValue)); +} + inline SpecificConstantMatch m_SpecificICst(int64_t RequestedValue) { - return SpecificConstantMatch(RequestedValue); + return SpecificConstantMatch(APInt(64, RequestedValue, /* isSigned */ true)); } /// Matcher for a specific constant splat. struct SpecificConstantSplatMatch { - int64_t RequestedVal; - SpecificConstantSplatMatch(int64_t RequestedVal) + APInt RequestedVal; + SpecificConstantSplatMatch(const APInt RequestedVal) : RequestedVal(RequestedVal) {} bool match(const MachineRegisterInfo &MRI, Register Reg) { return isBuildVectorConstantSplat(Reg, MRI, RequestedVal, @@ -217,19 +230,31 @@ struct SpecificConstantSplatMatch { }; /// Matches a constant splat of \p RequestedValue. +inline SpecificConstantSplatMatch m_SpecificICstSplat(APInt RequestedValue) { + return SpecificConstantSplatMatch(std::move(RequestedValue)); +} + inline SpecificConstantSplatMatch m_SpecificICstSplat(int64_t RequestedValue) { - return SpecificConstantSplatMatch(RequestedValue); + return SpecificConstantSplatMatch( + APInt(64, RequestedValue, /* isSigned */ true)); } /// Matcher for a specific constant or constant splat. struct SpecificConstantOrSplatMatch { - int64_t RequestedVal; - SpecificConstantOrSplatMatch(int64_t RequestedVal) + APInt RequestedVal; + SpecificConstantOrSplatMatch(const APInt RequestedVal) : RequestedVal(RequestedVal) {} bool match(const MachineRegisterInfo &MRI, Register Reg) { - int64_t MatchedVal; - if (mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal) - return true; + APInt MatchedVal; + if (mi_match(Reg, MRI, m_ICst(MatchedVal))) { + if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth()) + RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth()); + else + MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth()); + + if (APInt::isSameValue(MatchedVal, RequestedVal)) + return true; + } return isBuildVectorConstantSplat(Reg, MRI, RequestedVal, /* AllowUndef */ false); } @@ -237,18 +262,24 @@ struct SpecificConstantOrSplatMatch { /// Matches a \p RequestedValue constant or a constant splat of \p /// RequestedValue. +inline SpecificConstantOrSplatMatch +m_SpecificICstOrSplat(APInt RequestedValue) { + return SpecificConstantOrSplatMatch(std::move(RequestedValue)); +} + inline SpecificConstantOrSplatMatch m_SpecificICstOrSplat(int64_t RequestedValue) { - return SpecificConstantOrSplatMatch(RequestedValue); + return SpecificConstantOrSplatMatch( + APInt(64, RequestedValue, /* isSigned */ true)); } -///{ /// Convenience matchers for specific integer values. -inline SpecificConstantMatch m_ZeroInt() { return SpecificConstantMatch(0); } +inline SpecificConstantMatch m_ZeroInt() { + return SpecificConstantMatch(APInt::getZero(64)); +} inline SpecificConstantMatch m_AllOnesInt() { - return SpecificConstantMatch(-1); + return SpecificConstantMatch(APInt::getAllOnes(64)); } -///} /// Matcher for a specific register. struct SpecificRegisterMatch { diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h index 66c960fe12c68..5c27605c26883 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h @@ -459,12 +459,24 @@ LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg, const MachineRegisterInfo &MRI, int64_t SplatValue, bool AllowUndef); +/// Return true if the specified register is defined by G_BUILD_VECTOR or +/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef. +LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg, + const MachineRegisterInfo &MRI, + APInt SplatValue, bool AllowUndef); + /// Return true if the specified instruction is a G_BUILD_VECTOR or /// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef. LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI, const MachineRegisterInfo &MRI, int64_t SplatValue, bool AllowUndef); +/// Return true if the specified instruction is a G_BUILD_VECTOR or +/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef. +LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI, + const MachineRegisterInfo &MRI, + APInt SplatValue, bool AllowUndef); + /// Return true if the specified instruction is a G_BUILD_VECTOR or /// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef. LLVM_ABI bool isBuildVectorAllZeros(const MachineInstr &MI, diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp index f48bfc06c14be..8955dd0370539 100644 --- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp +++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp @@ -1401,6 +1401,21 @@ bool llvm::isBuildVectorConstantSplat(const Register Reg, return false; } +bool llvm::isBuildVectorConstantSplat(const Register Reg, + const MachineRegisterInfo &MRI, + APInt SplatValue, bool AllowUndef) { + if (auto SplatValAndReg = getAnyConstantSplat(Reg, MRI, AllowUndef)) { + if (SplatValAndReg->Value.getBitWidth() < SplatValue.getBitWidth()) + return APInt::isSameValue( + SplatValAndReg->Value.sext(SplatValue.getBitWidth()), SplatValue); + return APInt::isSameValue( + SplatValAndReg->Value, + SplatValue.sext(SplatValAndReg->Value.getBitWidth())); + } + + return false; +} + bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI, const MachineRegisterInfo &MRI, int64_t SplatValue, bool AllowUndef) { @@ -1408,6 +1423,13 @@ bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI, AllowUndef); } +bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI, + const MachineRegisterInfo &MRI, + APInt SplatValue, bool AllowUndef) { + return isBuildVectorConstantSplat(MI.getOperand(0).getReg(), MRI, SplatValue, + AllowUndef); +} + std::optional llvm::getIConstantSplatVal(const Register Reg, const MachineRegisterInfo &MRI) { if (auto SplatValAndReg = diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp index 25eb67e981588..1e0653b61e8f8 100644 --- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp @@ -634,17 +634,25 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstant) { auto FortyTwo = B.buildConstant(LLT::scalar(64), 42); EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(42))); EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(123))); + EXPECT_TRUE( + mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 42)))); + EXPECT_FALSE( + mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 123)))); // Test that this works inside of a more complex pattern. LLT s64 = LLT::scalar(64); auto MIBAdd = B.buildAdd(s64, Copies[0], FortyTwo); EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(42))); + EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 42)))); // Wrong constant. EXPECT_FALSE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(123))); + EXPECT_FALSE( + mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 123)))); // No constant on the LHS. EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(42))); + EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(APInt(64, 42)))); } TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) { @@ -664,6 +672,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) { mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstSplat(43))); EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(42))); + EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI, + m_SpecificICstSplat(APInt(64, 42)))); + EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI, + m_SpecificICstSplat(APInt(64, 43)))); + EXPECT_FALSE( + mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(APInt(64, 42)))); + MachineInstrBuilder NonConstantSplat = B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]}); @@ -673,8 +688,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) { EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(43))); EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(42))); + EXPECT_TRUE( + mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42)))); + EXPECT_FALSE( + mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 43)))); + EXPECT_FALSE( + mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(APInt(64, 42)))); + MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo); EXPECT_FALSE(mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(42))); + EXPECT_FALSE( + mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42)))); } TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) { @@ -695,6 +719,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) { mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstOrSplat(43))); EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(42))); + EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI, + m_SpecificICstOrSplat(APInt(64, 42)))); + EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI, + m_SpecificICstOrSplat(APInt(64, 43)))); + EXPECT_TRUE( + mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(APInt(64, 42)))); + MachineInstrBuilder NonConstantSplat = B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]}); @@ -704,8 +735,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) { EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(43))); EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(42))); + EXPECT_TRUE( + mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42)))); + EXPECT_FALSE( + mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 43)))); + EXPECT_FALSE( + mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(APInt(64, 42)))); + MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo); EXPECT_TRUE(mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(42))); + EXPECT_TRUE( + mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42)))); } TEST_F(AArch64GISelMITest, MatchZeroInt) {