diff --git a/llvm/include/llvm/Demangle/Demangle.h b/llvm/include/llvm/Demangle/Demangle.h index fe129603c0785..132e5088b5514 100644 --- a/llvm/include/llvm/Demangle/Demangle.h +++ b/llvm/include/llvm/Demangle/Demangle.h @@ -10,6 +10,7 @@ #define LLVM_DEMANGLE_DEMANGLE_H #include +#include #include #include @@ -54,6 +55,9 @@ enum MSDemangleFlags { char *microsoftDemangle(std::string_view mangled_name, size_t *n_read, int *status, MSDemangleFlags Flags = MSDF_None); +std::optional +getArm64ECInsertionPointInMangledName(std::string_view MangledName); + // Demangles a Rust v0 mangled symbol. char *rustDemangle(std::string_view MangledName); diff --git a/llvm/include/llvm/Demangle/MicrosoftDemangle.h b/llvm/include/llvm/Demangle/MicrosoftDemangle.h index 6891185a28e57..276efa7603690 100644 --- a/llvm/include/llvm/Demangle/MicrosoftDemangle.h +++ b/llvm/include/llvm/Demangle/MicrosoftDemangle.h @@ -9,6 +9,7 @@ #ifndef LLVM_DEMANGLE_MICROSOFTDEMANGLE_H #define LLVM_DEMANGLE_MICROSOFTDEMANGLE_H +#include "llvm/Demangle/Demangle.h" #include "llvm/Demangle/MicrosoftDemangleNodes.h" #include @@ -141,6 +142,9 @@ enum class FunctionIdentifierCodeGroup { Basic, Under, DoubleUnder }; // It has a set of functions to parse mangled symbols into Type instances. // It also has a set of functions to convert Type instances to strings. class Demangler { + friend std::optional + llvm::getArm64ECInsertionPointInMangledName(std::string_view MangledName); + public: Demangler() = default; virtual ~Demangler() = default; diff --git a/llvm/include/llvm/IR/Mangler.h b/llvm/include/llvm/IR/Mangler.h index f28ffc961b6db..33af40c5ae98d 100644 --- a/llvm/include/llvm/IR/Mangler.h +++ b/llvm/include/llvm/IR/Mangler.h @@ -56,6 +56,12 @@ void emitLinkerFlagsForUsedCOFF(raw_ostream &OS, const GlobalValue *GV, std::optional getArm64ECMangledFunctionName(StringRef Name); std::optional getArm64ECDemangledFunctionName(StringRef Name); +/// Check if an ARM64EC function name is mangled. +bool inline isArm64ECMangledFunctionName(StringRef Name) { + return Name[0] == '#' || + (Name[0] == '?' && Name.find("@$$h") != StringRef::npos); +} + } // End llvm namespace #endif diff --git a/llvm/lib/Demangle/MicrosoftDemangle.cpp b/llvm/lib/Demangle/MicrosoftDemangle.cpp index c5835e8c2e989..d35902a333767 100644 --- a/llvm/lib/Demangle/MicrosoftDemangle.cpp +++ b/llvm/lib/Demangle/MicrosoftDemangle.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -2424,6 +2425,24 @@ void Demangler::dumpBackReferences() { std::printf("\n"); } +std::optional +llvm::getArm64ECInsertionPointInMangledName(std::string_view MangledName) { + std::string_view ProcessedName{MangledName}; + + // We only support this for MSVC-style C++ symbols. + if (!consumeFront(ProcessedName, '?')) + return std::nullopt; + + // The insertion point is just after the name of the symbol, so parse that to + // remove it from the processed name. + Demangler D; + D.demangleFullyQualifiedSymbolName(ProcessedName); + if (D.Error) + return std::nullopt; + + return MangledName.length() - ProcessedName.length(); +} + char *llvm::microsoftDemangle(std::string_view MangledName, size_t *NMangled, int *Status, MSDemangleFlags Flags) { Demangler D; diff --git a/llvm/lib/IR/Mangler.cpp b/llvm/lib/IR/Mangler.cpp index e6c3ea9d56883..884739b3212c6 100644 --- a/llvm/lib/IR/Mangler.cpp +++ b/llvm/lib/IR/Mangler.cpp @@ -14,6 +14,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/Twine.h" +#include "llvm/Demangle/Demangle.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -291,30 +292,25 @@ void llvm::emitLinkerFlagsForUsedCOFF(raw_ostream &OS, const GlobalValue *GV, } std::optional llvm::getArm64ECMangledFunctionName(StringRef Name) { - bool IsCppFn = Name[0] == '?'; - if (IsCppFn && Name.contains("$$h")) - return std::nullopt; - if (!IsCppFn && Name[0] == '#') + if (Name[0] != '?') { + // For non-C++ symbols, prefix the name with "#" unless it's already + // mangled. + if (Name[0] == '#') + return std::nullopt; + return std::optional(("#" + Name).str()); + } + + // If the name contains $$h, then it is already mangled. + if (Name.contains("$$h")) return std::nullopt; - StringRef Prefix = "$$h"; - size_t InsertIdx = 0; - if (IsCppFn) { - InsertIdx = Name.find("@@"); - size_t ThreeAtSignsIdx = Name.find("@@@"); - if (InsertIdx != std::string::npos && InsertIdx != ThreeAtSignsIdx) { - InsertIdx += 2; - } else { - InsertIdx = Name.find("@"); - if (InsertIdx != std::string::npos) - InsertIdx++; - } - } else { - Prefix = "#"; - } + // Ask the demangler where we should insert "$$h". + auto InsertIdx = getArm64ECInsertionPointInMangledName(Name); + if (!InsertIdx) + return std::nullopt; return std::optional( - (Name.substr(0, InsertIdx) + Prefix + Name.substr(InsertIdx)).str()); + (Name.substr(0, *InsertIdx) + "$$h" + Name.substr(*InsertIdx)).str()); } std::optional diff --git a/llvm/unittests/IR/ManglerTest.cpp b/llvm/unittests/IR/ManglerTest.cpp index f2b78a1f98769..f8a3152564fd9 100644 --- a/llvm/unittests/IR/ManglerTest.cpp +++ b/llvm/unittests/IR/ManglerTest.cpp @@ -174,4 +174,81 @@ TEST(ManglerTest, GOFF) { "L#foo"); } +TEST(ManglerTest, Arm64EC) { + constexpr std::string_view Arm64ECNames[] = { + // Basic C name. + "#Foo", + + // Basic C++ name. + "?foo@@$$hYAHXZ", + + // Regression test: https://github.com/llvm/llvm-project/issues/115231 + "?GetValue@?$Wrapper@UA@@@@$$hQEBAHXZ", + + // Symbols from: + // ``` + // namespace A::B::C::D { + // struct Base { + // virtual int f() { return 0; } + // }; + // } + // struct Derived : public A::B::C::D::Base { + // virtual int f() override { return 1; } + // }; + // A::B::C::D::Base* MakeObj() { return new Derived(); } + // ``` + // void * __cdecl operator new(unsigned __int64) + "??2@$$hYAPEAX_K@Z", + // public: virtual int __cdecl A::B::C::D::Base::f(void) + "?f@Base@D@C@B@A@@$$hUEAAHXZ", + // public: __cdecl A::B::C::D::Base::Base(void) + "??0Base@D@C@B@A@@$$hQEAA@XZ", + // public: virtual int __cdecl Derived::f(void) + "?f@Derived@@$$hUEAAHXZ", + // public: __cdecl Derived::Derived(void) + "??0Derived@@$$hQEAA@XZ", + // struct A::B::C::D::Base * __cdecl MakeObj(void) + "?MakeObj@@$$hYAPEAUBase@D@C@B@A@@XZ", + + // Symbols from: + // ``` + // template struct WW { struct Z{}; }; + // template struct Wrapper { + // int GetValue(typename WW::Z) const; + // }; + // struct A { }; + // template int Wrapper::GetValue(typename WW::Z) const + // { return 3; } + // template class Wrapper; + // ``` + // public: int __cdecl Wrapper::GetValue(struct WW::Z)const + "?GetValue@?$Wrapper@UA@@@@$$hQEBAHUZ@?$WW@UA@@@@@Z", + }; + + for (const auto &Arm64ECName : Arm64ECNames) { + // Check that this is a mangled name. + EXPECT_TRUE(isArm64ECMangledFunctionName(Arm64ECName)) + << "Test case: " << Arm64ECName; + // Refuse to mangle it again. + EXPECT_FALSE(getArm64ECMangledFunctionName(Arm64ECName).has_value()) + << "Test case: " << Arm64ECName; + + // Demangle. + auto Arm64Name = getArm64ECDemangledFunctionName(Arm64ECName); + EXPECT_TRUE(Arm64Name.has_value()) << "Test case: " << Arm64ECName; + // Check that it is not mangled. + EXPECT_FALSE(isArm64ECMangledFunctionName(Arm64Name.value())) + << "Test case: " << Arm64ECName; + // Refuse to demangle it again. + EXPECT_FALSE(getArm64ECDemangledFunctionName(Arm64Name.value()).has_value()) + << "Test case: " << Arm64ECName; + + // Round-trip. + auto RoundTripArm64ECName = + getArm64ECMangledFunctionName(Arm64Name.value()); + EXPECT_EQ(RoundTripArm64ECName, Arm64ECName); + } +} + } // end anonymous namespace