Skip to content

Commit cdfc739

Browse files
ItsMrLinmeta-codesync[bot]
authored andcommitted
Add @ sanitization to string_utils for sympy compatibility (#5070)
Summary: Pull Request resolved: #5070 The recent migration of ScalarizedObjective to expression-based Objective.__init__() introduced sympy parsing of metric names. Metric names containing @ (e.g. metric@USCA) cause sympify() to interpret @ as Python's matrix multiplication operator (PEP 465), raising: TypeError: unsupported operand type(s) for @: 'Mul' and 'Symbol' Add @ -> __at__ sanitization to sanitize_name() and the corresponding reverse in unsanitize_name(), following the same pattern already used for `:, ., /, |, ~, -, ()`. Reviewed By: saitcakmak, Balandat Differential Revision: D97242521 fbshipit-source-id: 4b19b75ca8659cb209341322ebe886c3a344d7fa
1 parent c2d8610 commit cdfc739

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

ax/utils/common/string_utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
TILDE_PLACEHOLDER = "__tilde__"
1818
SPACE_PLACEHOLDER = "__space__"
1919
HYPHEN_PLACEHOLDER = "__hyphen__"
20+
AT_PLACEHOLDER = "__at__"
2021
LPAREN_PLACEHOLDER = "__lparen__"
2122
RPAREN_PLACEHOLDER = "__rparen__"
2223
_forbidden_re: re.Pattern[str] = re.compile(r"[\;\[\'\\]")
@@ -127,7 +128,16 @@ def sanitize_name(s: str, sanitize_parens: bool = False) -> str:
127128
rf"\1{HYPHEN_PLACEHOLDER}",
128129
sans_tilde,
129130
)
130-
result = sans_hyphen
131+
# Replace "@" when it appears between identifier characters
132+
# (e.g. "metric@region" in metric names). In Python, @ is the matrix
133+
# multiplication operator (PEP 465), so SymPy's sympify() will
134+
# misinterpret it and raise a TypeError.
135+
sans_at = re.sub(
136+
r"([a-zA-Z_][a-zA-Z0-9_]*)@(?=[a-zA-Z0-9_])",
137+
rf"\1{AT_PLACEHOLDER}",
138+
sans_hyphen,
139+
)
140+
result = sans_at
131141

132142
# Optionally sanitize parentheses that are part of metric/parameter names.
133143
# Matches ``identifier(content)`` where content is purely [a-zA-Z0-9_].
@@ -158,7 +168,8 @@ def unsanitize_name(s: str) -> str:
158168
# Unsanitize in the reverse order of sanitization
159169
with_rparen = re.sub(rf"{RPAREN_PLACEHOLDER}", ")", s)
160170
with_lparen = re.sub(rf"{LPAREN_PLACEHOLDER}", "(", with_rparen)
161-
with_hyphen = re.sub(rf"{HYPHEN_PLACEHOLDER}", "-", with_lparen)
171+
with_at = re.sub(rf"{AT_PLACEHOLDER}", "@", with_lparen)
172+
with_hyphen = re.sub(rf"{HYPHEN_PLACEHOLDER}", "-", with_at)
162173
with_tilde = re.sub(rf"{TILDE_PLACEHOLDER}", "~", with_hyphen)
163174
with_pipe = re.sub(rf"{PIPE_PLACEHOLDER}", "|", with_tilde)
164175
with_colon = re.sub(rf"{COLON_PLACEHOLDER}", ":", with_pipe)

ax/utils/common/tests/test_string_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,26 @@ def test_sanitize_parens(self) -> None:
104104
"metric_(p50)",
105105
)
106106

107+
def test_sanitize_at_sign(self) -> None:
108+
"""Test that @ in metric names is sanitized to avoid sympy misparse."""
109+
self.assertEqual(
110+
sanitize_name("metric@region"),
111+
"metric__at__region",
112+
)
113+
self.assertEqual(
114+
sanitize_name("scope:sub:metric@region"),
115+
"scope__colon__sub__colon__metric__at__region",
116+
)
117+
107118
def test_unsanitize_name_roundtrip(self) -> None:
108119
"""Test that unsanitize_name reverses sanitize_name including parens."""
109120
names = [
110121
"foo.bar.baz",
111122
"foo.bar/11:Baz|qux",
112123
"~treatment_percent_",
113124
"metric-name",
125+
"metric@region",
126+
"scope:sub:metric@region",
114127
]
115128
for name in names:
116129
with self.subTest(name=name):

0 commit comments

Comments
 (0)