Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions llvm/include/llvm/ADT/StringSwitch.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#define LLVM_ADT_STRINGSWITCH_H

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Compiler.h"
#include <cassert>
#include <cstring>
#include <optional>
Expand Down Expand Up @@ -67,9 +66,7 @@ class StringSwitch {

// Case-sensitive case matchers
StringSwitch &Case(StringLiteral S, T Value) {
if (!Result && Str == S) {
Result = std::move(Value);
}
CaseImpl(Value, S);
return *this;
}

Expand All @@ -88,61 +85,59 @@ class StringSwitch {
}

StringSwitch &Cases(StringLiteral S0, StringLiteral S1, T Value) {
return Case(S0, Value).Case(S1, Value);
return CasesImpl(Value, S0, S1);
}

StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
T Value) {
return Case(S0, Value).Cases(S1, S2, Value);
return CasesImpl(Value, S0, S1, S2);
}

StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
StringLiteral S3, T Value) {
return Case(S0, Value).Cases(S1, S2, S3, Value);
return CasesImpl(Value, S0, S1, S2, S3);
}

StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
StringLiteral S3, StringLiteral S4, T Value) {
return Case(S0, Value).Cases(S1, S2, S3, S4, Value);
return CasesImpl(Value, S0, S1, S2, S3, S4);
}

StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
StringLiteral S3, StringLiteral S4, StringLiteral S5,
T Value) {
return Case(S0, Value).Cases(S1, S2, S3, S4, S5, Value);
return CasesImpl(Value, S0, S1, S2, S3, S4, S5);
}

StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
StringLiteral S3, StringLiteral S4, StringLiteral S5,
StringLiteral S6, T Value) {
return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, Value);
return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6);
}

StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
StringLiteral S3, StringLiteral S4, StringLiteral S5,
StringLiteral S6, StringLiteral S7, T Value) {
return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, Value);
return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6, S7);
}

StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
StringLiteral S3, StringLiteral S4, StringLiteral S5,
StringLiteral S6, StringLiteral S7, StringLiteral S8,
T Value) {
return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, S8, Value);
return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6, S7, S8);
}

StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
StringLiteral S3, StringLiteral S4, StringLiteral S5,
StringLiteral S6, StringLiteral S7, StringLiteral S8,
StringLiteral S9, T Value) {
return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, S8, S9, Value);
return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6, S7, S8, S9);
}

// Case-insensitive case matchers.
StringSwitch &CaseLower(StringLiteral S, T Value) {
if (!Result && Str.equals_insensitive(S))
Result = std::move(Value);

CaseLowerImpl(Value, S);
return *this;
}

Expand All @@ -161,22 +156,22 @@ class StringSwitch {
}

StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, T Value) {
return CaseLower(S0, Value).CaseLower(S1, Value);
return CasesLowerImpl(Value, S0, S1);
}

StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2,
T Value) {
return CaseLower(S0, Value).CasesLower(S1, S2, Value);
return CasesLowerImpl(Value, S0, S1, S2);
}

StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2,
StringLiteral S3, T Value) {
return CaseLower(S0, Value).CasesLower(S1, S2, S3, Value);
return CasesLowerImpl(Value, S0, S1, S2, S3);
}

StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2,
StringLiteral S3, StringLiteral S4, T Value) {
return CaseLower(S0, Value).CasesLower(S1, S2, S3, S4, Value);
return CasesLowerImpl(Value, S0, S1, S2, S3, S4);
}

[[nodiscard]] R Default(T Value) {
Expand All @@ -189,6 +184,39 @@ class StringSwitch {
assert(Result && "Fell off the end of a string-switch");
return std::move(*Result);
}

private:
// Returns true when `Str` matches the `S` argument, and stores the result.
bool CaseImpl(T &Value, StringLiteral S) {
if (!Result && Str == S) {
Result = std::move(Value);
return true;
}
return false;
}

// Returns true when `Str` matches the `S` argument (case-insensitive), and
// stores the result.
bool CaseLowerImpl(T &Value, StringLiteral S) {
if (!Result && Str.equals_insensitive(S)) {
Result = std::move(Value);
return true;
}
return false;
}

template <typename... Args> StringSwitch &CasesImpl(T &Value, Args... Cases) {
// Stop matching after the string is found.
(... || CaseImpl(Value, Cases));
return *this;
}

template <typename... Args>
StringSwitch &CasesLowerImpl(T &Value, Args... Cases) {
// Stop matching after the string is found.
(... || CaseLowerImpl(Value, Cases));
return *this;
}
};

} // end namespace llvm
Expand Down
25 changes: 25 additions & 0 deletions llvm/unittests/ADT/StringSwitchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,28 @@ TEST(StringSwitchTest, CasesLower) {
EXPECT_EQ(OSType::Unknown, Translate("wind"));
EXPECT_EQ(OSType::Unknown, Translate(""));
}

TEST(StringSwitchTest, CasesCopies) {
struct Copyable {
unsigned &NumCopies;
Copyable(unsigned &Value) : NumCopies(Value) {}
Copyable(const Copyable &Other) : NumCopies(Other.NumCopies) {
++NumCopies;
}
Copyable &operator=(const Copyable &Other) {
++NumCopies;
return *this;
}
};

// Check that evaluating multiple cases does not cause unnecessary copies.
unsigned NumCopies = 0;
llvm::StringSwitch<Copyable, void>("baz").Cases("foo", "bar", "baz", "qux",
Copyable{NumCopies});
EXPECT_EQ(NumCopies, 1u);

NumCopies = 0;
llvm::StringSwitch<Copyable, void>("baz").CasesLower(
"Foo", "Bar", "Baz", "Qux", Copyable{NumCopies});
EXPECT_EQ(NumCopies, 1u);
}