diff --git a/llvm/include/llvm/ADT/StringSwitch.h b/llvm/include/llvm/ADT/StringSwitch.h index 7093da07663a0..86e591c71c92e 100644 --- a/llvm/include/llvm/ADT/StringSwitch.h +++ b/llvm/include/llvm/ADT/StringSwitch.h @@ -14,7 +14,6 @@ #define LLVM_ADT_STRINGSWITCH_H #include "llvm/ADT/StringRef.h" -#include "llvm/Support/Compiler.h" #include #include #include @@ -67,9 +66,7 @@ class StringSwitch { // Case-sensitive case matchers StringSwitch &Case(StringLiteral S, T Value) { - if (!Result && Str == S) { - Result = std::move(Value); - } + CaseImpl(Value, S); return *this; } @@ -88,61 +85,59 @@ class StringSwitch { } StringSwitch &Cases(StringLiteral S0, StringLiteral S1, T Value) { - return Case(S0, Value).Case(S1, Value); + return CasesImpl(Value, S0, S1); } StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, T Value) { - return Case(S0, Value).Cases(S1, S2, Value); + return CasesImpl(Value, S0, S1, S2); } StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, StringLiteral S3, T Value) { - return Case(S0, Value).Cases(S1, S2, S3, Value); + return CasesImpl(Value, S0, S1, S2, S3); } StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, StringLiteral S3, StringLiteral S4, T Value) { - return Case(S0, Value).Cases(S1, S2, S3, S4, Value); + return CasesImpl(Value, S0, S1, S2, S3, S4); } StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, StringLiteral S3, StringLiteral S4, StringLiteral S5, T Value) { - return Case(S0, Value).Cases(S1, S2, S3, S4, S5, Value); + return CasesImpl(Value, S0, S1, S2, S3, S4, S5); } StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, StringLiteral S3, StringLiteral S4, StringLiteral S5, StringLiteral S6, T Value) { - return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, Value); + return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6); } StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, StringLiteral S3, StringLiteral S4, StringLiteral S5, StringLiteral S6, StringLiteral S7, T Value) { - return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, Value); + return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6, S7); } StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, StringLiteral S3, StringLiteral S4, StringLiteral S5, StringLiteral S6, StringLiteral S7, StringLiteral S8, T Value) { - return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, S8, Value); + return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6, S7, S8); } StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2, StringLiteral S3, StringLiteral S4, StringLiteral S5, StringLiteral S6, StringLiteral S7, StringLiteral S8, StringLiteral S9, T Value) { - return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, S8, S9, Value); + return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6, S7, S8, S9); } // Case-insensitive case matchers. StringSwitch &CaseLower(StringLiteral S, T Value) { - if (!Result && Str.equals_insensitive(S)) - Result = std::move(Value); - + CaseLowerImpl(Value, S); return *this; } @@ -161,22 +156,22 @@ class StringSwitch { } StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, T Value) { - return CaseLower(S0, Value).CaseLower(S1, Value); + return CasesLowerImpl(Value, S0, S1); } StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2, T Value) { - return CaseLower(S0, Value).CasesLower(S1, S2, Value); + return CasesLowerImpl(Value, S0, S1, S2); } StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2, StringLiteral S3, T Value) { - return CaseLower(S0, Value).CasesLower(S1, S2, S3, Value); + return CasesLowerImpl(Value, S0, S1, S2, S3); } StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2, StringLiteral S3, StringLiteral S4, T Value) { - return CaseLower(S0, Value).CasesLower(S1, S2, S3, S4, Value); + return CasesLowerImpl(Value, S0, S1, S2, S3, S4); } [[nodiscard]] R Default(T Value) { @@ -189,6 +184,39 @@ class StringSwitch { assert(Result && "Fell off the end of a string-switch"); return std::move(*Result); } + +private: + // Returns true when `Str` matches the `S` argument, and stores the result. + bool CaseImpl(T &Value, StringLiteral S) { + if (!Result && Str == S) { + Result = std::move(Value); + return true; + } + return false; + } + + // Returns true when `Str` matches the `S` argument (case-insensitive), and + // stores the result. + bool CaseLowerImpl(T &Value, StringLiteral S) { + if (!Result && Str.equals_insensitive(S)) { + Result = std::move(Value); + return true; + } + return false; + } + + template StringSwitch &CasesImpl(T &Value, Args... Cases) { + // Stop matching after the string is found. + (... || CaseImpl(Value, Cases)); + return *this; + } + + template + StringSwitch &CasesLowerImpl(T &Value, Args... Cases) { + // Stop matching after the string is found. + (... || CaseLowerImpl(Value, Cases)); + return *this; + } }; } // end namespace llvm diff --git a/llvm/unittests/ADT/StringSwitchTest.cpp b/llvm/unittests/ADT/StringSwitchTest.cpp index 2ce6cdca8d36a..2953f4b0a381b 100644 --- a/llvm/unittests/ADT/StringSwitchTest.cpp +++ b/llvm/unittests/ADT/StringSwitchTest.cpp @@ -205,3 +205,28 @@ TEST(StringSwitchTest, CasesLower) { EXPECT_EQ(OSType::Unknown, Translate("wind")); EXPECT_EQ(OSType::Unknown, Translate("")); } + +TEST(StringSwitchTest, CasesCopies) { + struct Copyable { + unsigned &NumCopies; + Copyable(unsigned &Value) : NumCopies(Value) {} + Copyable(const Copyable &Other) : NumCopies(Other.NumCopies) { + ++NumCopies; + } + Copyable &operator=(const Copyable &Other) { + ++NumCopies; + return *this; + } + }; + + // Check that evaluating multiple cases does not cause unnecessary copies. + unsigned NumCopies = 0; + llvm::StringSwitch("baz").Cases("foo", "bar", "baz", "qux", + Copyable{NumCopies}); + EXPECT_EQ(NumCopies, 1u); + + NumCopies = 0; + llvm::StringSwitch("baz").CasesLower( + "Foo", "Bar", "Baz", "Qux", Copyable{NumCopies}); + EXPECT_EQ(NumCopies, 1u); +}