Skip to content

Commit c2d8610

Browse files
ItsMrLinmeta-codesync[bot]
authored andcommitted
Clear digits in UnitX and Logit transforms before updating bounds (#5071)
Summary: Pull Request resolved: #5071 When a RangeParameter has a `digits` value calibrated for its original scale (e.g., `digits=-3` to round to the nearest 1000 on a [5000, 500000] range), transforms that change the parameter's scale must clear `digits` before calling `update_range`. Otherwise, `RangeParameter.cast()` applies the original-scale rounding to the transformed bounds, which can collapse them (e.g., `round(1.0, -3) == 0.0`, making both bounds 0.0 and raising `UserInputError`). The `Log` transform already handles this correctly (D66670173). This diff applies the same fix to `UnitX` and `Logit`, and documents the pattern in the base `Transform.transform_search_space` docstring. Reviewed By: saitcakmak Differential Revision: D97223066 fbshipit-source-id: d3c46b6963ad2ef594486cea97e4cb49e48bfbff
1 parent 78ef3b3 commit c2d8610

File tree

5 files changed

+61
-0
lines changed

5 files changed

+61
-0
lines changed

ax/adapter/transforms/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,14 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
109109
This is typically done in-place. This class implements the identity
110110
transform (does nothing).
111111
112+
NOTE for subclasses: If a transform changes the *scale* of a
113+
RangeParameter (e.g., Log, UnitX, Logit), it must clear ``digits``
114+
via ``p.set_digits(digits=None)`` before calling ``update_range``.
115+
Otherwise, rounding calibrated for the original scale will corrupt
116+
the transformed bounds (e.g., ``digits=-3`` rounds to the nearest
117+
1000, which collapses [0, 1] to 0). The Cast transform re-applies
118+
``digits`` in the original space during untransform.
119+
112120
Args:
113121
search_space: The search space
114122

ax/adapter/transforms/logit.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ def transform_observation_features(
6666
def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
6767
for p_name, p in search_space.parameters.items():
6868
if p_name in self.transform_parameters and isinstance(p, RangeParameter):
69+
# Don't round in logit space; digits will be re-applied in
70+
# the original space by the Cast transform during untransform.
71+
if p.digits is not None:
72+
p.set_digits(digits=None)
6973
p.set_logit_scale(False).update_range(
7074
lower=logit(p.lower).item(), upper=logit(p.upper).item()
7175
)

ax/adapter/transforms/tests/test_logit_transform.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,28 @@ def test_TransformSearchSpace(self) -> None:
122122
self.assertEqual(x_param.lower, logit(0.1))
123123
self.assertEqual(x_param.upper, logit(0.3))
124124

125+
def test_transform_search_space_clears_digits(self) -> None:
126+
"""Test that digits is cleared during transform to avoid rounding
127+
in logit space."""
128+
ss = SearchSpace(
129+
parameters=[
130+
RangeParameter(
131+
"x",
132+
lower=0.1,
133+
upper=0.9,
134+
parameter_type=ParameterType.FLOAT,
135+
logit_scale=True,
136+
digits=3,
137+
),
138+
]
139+
)
140+
t = Logit(search_space=ss)
141+
ss = t.transform_search_space(ss)
142+
x = assert_is_instance(ss.parameters["x"], RangeParameter)
143+
self.assertIsNone(x.digits)
144+
self.assertAlmostEqual(x.lower, logit(0.1))
145+
self.assertAlmostEqual(x.upper, logit(0.9))
146+
125147
def test_transform_experiment_data(self) -> None:
126148
parameterizations = [
127149
{"x": 0.2, "a": 1, "b": "a"},

ax/adapter/transforms/tests/test_unit_x_transform.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,29 @@ def test_TransformSearchSpace(self) -> None:
133133
self.search_space_with_target.parameters["x"].target_value, 1.0
134134
)
135135

136+
def test_transform_search_space_clears_digits(self) -> None:
137+
"""Test that digits is cleared during transform to avoid rounding
138+
in unit space. Regression test for a bug where digits=-3 (round to
139+
nearest 1000) collapsed [0, 1] bounds to (0.0, 0.0)."""
140+
ss = SearchSpace(
141+
parameters=[
142+
RangeParameter(
143+
"w",
144+
lower=5000.0,
145+
upper=500000.0,
146+
parameter_type=ParameterType.FLOAT,
147+
digits=-3,
148+
),
149+
]
150+
)
151+
t = UnitX(search_space=ss)
152+
ss = t.transform_search_space(ss)
153+
w = assert_is_instance(ss.parameters["w"], RangeParameter)
154+
# digits must be cleared so rounding doesn't corrupt [0, 1] bounds.
155+
self.assertIsNone(w.digits)
156+
self.assertEqual(w.lower, 0.0)
157+
self.assertEqual(w.upper, 1.0)
158+
136159
def test_TransformNewSearchSpace(self) -> None:
137160
new_ss = SearchSpace(
138161
parameters=[

ax/adapter/transforms/unit_x.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
7373
if (p_bounds := self.bounds.get(p_name)) is not None and isinstance(
7474
p, RangeParameter
7575
):
76+
# Don't round in unit space; digits will be re-applied in
77+
# the original space by the Cast transform during untransform.
78+
if p.digits is not None:
79+
p.set_digits(digits=None)
7680
p.update_range(
7781
lower=self._normalize_value(value=p.lower, bounds=p_bounds),
7882
upper=self._normalize_value(value=p.upper, bounds=p_bounds),

0 commit comments

Comments
 (0)