Skip to content

Conversation

@kuhar
Copy link
Member

@kuhar kuhar commented Feb 1, 2025

Optimize the .Cases and .CasesLower functions to avoid needlessly recursing on each case and copying the associated values. We can instead take Value by reference and short-circuit by using the || operator.

Note that while the implementation uses variadic templates, we cannot simplify the public functions in the same way. This is because the current API forces the arguments to be converted to StringLiterals and places the Value parameter at the very end. Even if we did some tricks like split the parameter pack to separate out the Value, I do not see how we could force conversion to StringLiteral.

Optimize the `.Cases` and `.CasesLower` functions to avoid needlessly
recuring on each case and copying the associated values. We can instead
take `Value` by reference and short-circuit by using the `||` operator.

Note that while the implementation uses variadic templates, we cannot
simplify the public functions in the same way. This is because the
current API forces the arguments to be converted to `StringLiterals`,
and places the `Value` parameter at the very end. Even if we did some
tricks like split the parameter pack to separate out the `Value`, I do
not see how we could force conversion to `StringLiteral`.
@llvmbot
Copy link
Member

llvmbot commented Feb 1, 2025

@llvm/pr-subscribers-llvm-adt

Author: Jakub Kuderski (kuhar)

Changes

Optimize the .Cases and .CasesLower functions to avoid needlessly recuring on each case and copying the associated values. We can instead take Value by reference and short-circuit by using the || operator.

Note that while the implementation uses variadic templates, we cannot simplify the public functions in the same way. This is because the current API forces the arguments to be converted to StringLiterals, and places the Value parameter at the very end. Even if we did some tricks like split the parameter pack to separate out the Value, I do not see how we could force conversion to StringLiteral.


Full diff: https://github.com/llvm/llvm-project/pull/125362.diff

2 Files Affected:

  • (modified) llvm/include/llvm/ADT/StringSwitch.h (+48-20)
  • (modified) llvm/unittests/ADT/StringSwitchTest.cpp (+25)
diff --git a/llvm/include/llvm/ADT/StringSwitch.h b/llvm/include/llvm/ADT/StringSwitch.h
index 7093da07663a0c..86e591c71c92ec 100644
--- a/llvm/include/llvm/ADT/StringSwitch.h
+++ b/llvm/include/llvm/ADT/StringSwitch.h
@@ -14,7 +14,6 @@
 #define LLVM_ADT_STRINGSWITCH_H
 
 #include "llvm/ADT/StringRef.h"
-#include "llvm/Support/Compiler.h"
 #include <cassert>
 #include <cstring>
 #include <optional>
@@ -67,9 +66,7 @@ class StringSwitch {
 
   // Case-sensitive case matchers
   StringSwitch &Case(StringLiteral S, T Value) {
-    if (!Result && Str == S) {
-      Result = std::move(Value);
-    }
+    CaseImpl(Value, S);
     return *this;
   }
 
@@ -88,61 +85,59 @@ class StringSwitch {
   }
 
   StringSwitch &Cases(StringLiteral S0, StringLiteral S1, T Value) {
-    return Case(S0, Value).Case(S1, Value);
+    return CasesImpl(Value, S0, S1);
   }
 
   StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
                       T Value) {
-    return Case(S0, Value).Cases(S1, S2, Value);
+    return CasesImpl(Value, S0, S1, S2);
   }
 
   StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
                       StringLiteral S3, T Value) {
-    return Case(S0, Value).Cases(S1, S2, S3, Value);
+    return CasesImpl(Value, S0, S1, S2, S3);
   }
 
   StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
                       StringLiteral S3, StringLiteral S4, T Value) {
-    return Case(S0, Value).Cases(S1, S2, S3, S4, Value);
+    return CasesImpl(Value, S0, S1, S2, S3, S4);
   }
 
   StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
                       StringLiteral S3, StringLiteral S4, StringLiteral S5,
                       T Value) {
-    return Case(S0, Value).Cases(S1, S2, S3, S4, S5, Value);
+    return CasesImpl(Value, S0, S1, S2, S3, S4, S5);
   }
 
   StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
                       StringLiteral S3, StringLiteral S4, StringLiteral S5,
                       StringLiteral S6, T Value) {
-    return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, Value);
+    return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6);
   }
 
   StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
                       StringLiteral S3, StringLiteral S4, StringLiteral S5,
                       StringLiteral S6, StringLiteral S7, T Value) {
-    return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, Value);
+    return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6, S7);
   }
 
   StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
                       StringLiteral S3, StringLiteral S4, StringLiteral S5,
                       StringLiteral S6, StringLiteral S7, StringLiteral S8,
                       T Value) {
-    return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, S8, Value);
+    return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6, S7, S8);
   }
 
   StringSwitch &Cases(StringLiteral S0, StringLiteral S1, StringLiteral S2,
                       StringLiteral S3, StringLiteral S4, StringLiteral S5,
                       StringLiteral S6, StringLiteral S7, StringLiteral S8,
                       StringLiteral S9, T Value) {
-    return Case(S0, Value).Cases(S1, S2, S3, S4, S5, S6, S7, S8, S9, Value);
+    return CasesImpl(Value, S0, S1, S2, S3, S4, S5, S6, S7, S8, S9);
   }
 
   // Case-insensitive case matchers.
   StringSwitch &CaseLower(StringLiteral S, T Value) {
-    if (!Result && Str.equals_insensitive(S))
-      Result = std::move(Value);
-
+    CaseLowerImpl(Value, S);
     return *this;
   }
 
@@ -161,22 +156,22 @@ class StringSwitch {
   }
 
   StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, T Value) {
-    return CaseLower(S0, Value).CaseLower(S1, Value);
+    return CasesLowerImpl(Value, S0, S1);
   }
 
   StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2,
                            T Value) {
-    return CaseLower(S0, Value).CasesLower(S1, S2, Value);
+    return CasesLowerImpl(Value, S0, S1, S2);
   }
 
   StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2,
                            StringLiteral S3, T Value) {
-    return CaseLower(S0, Value).CasesLower(S1, S2, S3, Value);
+    return CasesLowerImpl(Value, S0, S1, S2, S3);
   }
 
   StringSwitch &CasesLower(StringLiteral S0, StringLiteral S1, StringLiteral S2,
                            StringLiteral S3, StringLiteral S4, T Value) {
-    return CaseLower(S0, Value).CasesLower(S1, S2, S3, S4, Value);
+    return CasesLowerImpl(Value, S0, S1, S2, S3, S4);
   }
 
   [[nodiscard]] R Default(T Value) {
@@ -189,6 +184,39 @@ class StringSwitch {
     assert(Result && "Fell off the end of a string-switch");
     return std::move(*Result);
   }
+
+private:
+  // Returns true when `Str` matches the `S` argument, and stores the result.
+  bool CaseImpl(T &Value, StringLiteral S) {
+    if (!Result && Str == S) {
+      Result = std::move(Value);
+      return true;
+    }
+    return false;
+  }
+
+  // Returns true when `Str` matches the `S` argument (case-insensitive), and
+  // stores the result.
+  bool CaseLowerImpl(T &Value, StringLiteral S) {
+    if (!Result && Str.equals_insensitive(S)) {
+      Result = std::move(Value);
+      return true;
+    }
+    return false;
+  }
+
+  template <typename... Args> StringSwitch &CasesImpl(T &Value, Args... Cases) {
+    // Stop matching after the string is found.
+    (... || CaseImpl(Value, Cases));
+    return *this;
+  }
+
+  template <typename... Args>
+  StringSwitch &CasesLowerImpl(T &Value, Args... Cases) {
+    // Stop matching after the string is found.
+    (... || CaseLowerImpl(Value, Cases));
+    return *this;
+  }
 };
 
 } // end namespace llvm
diff --git a/llvm/unittests/ADT/StringSwitchTest.cpp b/llvm/unittests/ADT/StringSwitchTest.cpp
index 2ce6cdca8d36a3..2953f4b0a381bf 100644
--- a/llvm/unittests/ADT/StringSwitchTest.cpp
+++ b/llvm/unittests/ADT/StringSwitchTest.cpp
@@ -205,3 +205,28 @@ TEST(StringSwitchTest, CasesLower) {
   EXPECT_EQ(OSType::Unknown, Translate("wind"));
   EXPECT_EQ(OSType::Unknown, Translate(""));
 }
+
+TEST(StringSwitchTest, CasesCopies) {
+  struct Copyable {
+    unsigned &NumCopies;
+    Copyable(unsigned &Value) : NumCopies(Value) {}
+    Copyable(const Copyable &Other) : NumCopies(Other.NumCopies) {
+      ++NumCopies;
+    }
+    Copyable &operator=(const Copyable &Other) {
+      ++NumCopies;
+      return *this;
+    }
+  };
+
+  // Check that evaluating multiple cases does not cause unnecessary copies.
+  unsigned NumCopies = 0;
+  llvm::StringSwitch<Copyable, void>("baz").Cases("foo", "bar", "baz", "qux",
+                                                  Copyable{NumCopies});
+  EXPECT_EQ(NumCopies, 1u);
+
+  NumCopies = 0;
+  llvm::StringSwitch<Copyable, void>("baz").CasesLower(
+      "Foo", "Bar", "Baz", "Qux", Copyable{NumCopies});
+  EXPECT_EQ(NumCopies, 1u);
+}

@kuhar kuhar merged commit 67696a1 into llvm:main Feb 3, 2025
8 of 10 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
Optimize the `.Cases` and `.CasesLower` functions to avoid needlessly
recursing on each case and copying the associated values. We can instead
take `Value` by reference and short-circuit by using the `||` operator.

Note that while the implementation uses variadic templates, we cannot
simplify the public functions in the same way. This is because the
current API forces the arguments to be converted to `StringLiterals` and
places the `Value` parameter at the very end. Even if we did some tricks
like split the parameter pack to separate out the `Value`, I do not see
how we could force conversion to `StringLiteral`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants