Skip to content

Commit db934f7

Browse files
David Eriksson authored and facebook-github-bot committed
Modify get_data to error out on nan/inf (#2633)
Summary: Pull Request resolved: #2633. This method (`get_data`) is leveraged by `StandardizeY`, `Winsorize`, `PowerTransformY`, and `PercentileY`. This change will improve the robustness of our transform layer to non-finite values and error out before we pass those down to BoTorch. Reviewed By: Balandat. Differential Revision: D60681606. fbshipit-source-id: 6b423729a0d13cf6a833802a28a5645be78b1a08
1 parent 6775e8a commit db934f7

File tree

6 files changed

+81
-6
lines changed

6 files changed

+81
-6
lines changed

ax/modelbridge/transforms/tests/test_percentile_y_transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ax.exceptions.core import DataRequiredError
1414
from ax.modelbridge.transforms.percentile_y import PercentileY
1515
from ax.utils.common.testutils import TestCase
16+
from ax.utils.testing.core_stubs import get_observations_with_invalid_value
1617

1718

1819
class PercentileYTransformTest(TestCase):
@@ -126,3 +127,13 @@ def test_TransformObservationsWithWinsorization(self) -> None:
126127
np.allclose(mean_results, expected),
127128
msg=f"Unexpected mean Results: {mean_results}. Expected: {expected}.",
128129
)
130+
131+
def test_non_finite_data_raises(self) -> None:
    """Check that building PercentileY on NaN/inf data raises a ValueError."""
    for bad_value in (float("nan"), float("inf")):
        # The error message is expected to name the metric and echo the
        # offending value.
        expected_msg = f"Non-finite data found for metric m1: {bad_value}"
        bad_observations = get_observations_with_invalid_value(
            invalid_value=bad_value
        )
        with self.assertRaisesRegex(ValueError, expected_msg):
            PercentileY(observations=bad_observations, config={"metrics": ["m1"]})

ax/modelbridge/transforms/tests/test_power_y_transform.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from ax.modelbridge.transforms.utils import get_data, match_ci_width_truncated
2828
from ax.utils.common.testutils import TestCase
29+
from ax.utils.testing.core_stubs import get_observations_with_invalid_value
2930
from sklearn.preprocessing import PowerTransformer
3031

3132

@@ -328,3 +329,11 @@ def test_TransformOptimizationConfig(self) -> None:
328329
"that are part of a ScalarizedOutcomeConstraint.",
329330
str(cm.exception),
330331
)
332+
333+
def test_non_finite_data_raises(self) -> None:
    """Check that building PowerTransformY on NaN/inf data raises a ValueError."""
    for bad_value in (float("nan"), float("inf")):
        # The error message is expected to name the metric and echo the
        # offending value.
        expected_msg = f"Non-finite data found for metric m1: {bad_value}"
        bad_observations = get_observations_with_invalid_value(
            invalid_value=bad_value
        )
        with self.assertRaisesRegex(ValueError, expected_msg):
            PowerTransformY(
                observations=bad_observations, config={"metrics": ["m1"]}
            )

ax/modelbridge/transforms/tests/test_standardize_y_transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ax.exceptions.core import DataRequiredError
2121
from ax.modelbridge.transforms.standardize_y import StandardizeY
2222
from ax.utils.common.testutils import TestCase
23+
from ax.utils.testing.core_stubs import get_observations_with_invalid_value
2324

2425

2526
class StandardizeYTransformTest(TestCase):
@@ -153,6 +154,16 @@ def test_TransformOptimizationConfig(self) -> None:
153154
with self.assertRaises(ValueError):
154155
oc = self.t.transform_optimization_config(oc, None, None)
155156

157+
def test_non_finite_data_raises(self) -> None:
    """Check that building StandardizeY on NaN/inf data raises a ValueError."""
    for bad_value in (float("nan"), float("inf")):
        # The error message is expected to name the metric and echo the
        # offending value.
        expected_msg = f"Non-finite data found for metric m1: {bad_value}"
        bad_observations = get_observations_with_invalid_value(
            invalid_value=bad_value
        )
        with self.assertRaisesRegex(ValueError, expected_msg):
            StandardizeY(observations=bad_observations, config={"metrics": ["m1"]})
166+
156167

157168
def osd_allclose(osd1: ObservationData, osd2: ObservationData) -> bool:
158169
if osd1.metric_names != osd2.metric_names:

ax/modelbridge/transforms/tests/test_winsorize_transform.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import warnings
1010
from copy import deepcopy
11-
from typing import Dict, Optional, Tuple
11+
from typing import Any, Dict, Optional, Tuple
1212
from unittest import mock
1313

1414
import numpy as np
@@ -40,7 +40,10 @@
4040
)
4141
from ax.models.winsorization_config import WinsorizationConfig
4242
from ax.utils.common.testutils import TestCase
43-
from ax.utils.testing.core_stubs import get_optimization_config
43+
from ax.utils.testing.core_stubs import (
44+
get_observations_with_invalid_value,
45+
get_optimization_config,
46+
)
4447
from typing_extensions import SupportsIndex
4548

4649
INF = float("inf")
@@ -642,6 +645,19 @@ def test_relative_constraints(
642645
)
643646
self.assertDictEqual(t.cutoffs, {"a": (-INF, 3.5), "b": (-INF, 12.0)})
644647

648+
def test_non_finite_data_raises(self) -> None:
    """Check that building Winsorize on NaN/inf data raises a ValueError."""
    for bad_value in (float("nan"), float("inf")):
        # The error message is expected to name the metric and echo the
        # offending value.
        expected_msg = f"Non-finite data found for metric m1: {bad_value}"
        bad_observations = get_observations_with_invalid_value(
            invalid_value=bad_value
        )
        # A fresh config is built on every iteration, as each Winsorize
        # instance receives its own dict.
        winsorization = WinsorizationConfig(upper_quantile_margin=0.2)
        config: Dict[str, Any] = {"winsorization_config": winsorization}
        with self.assertRaisesRegex(ValueError, expected_msg):
            Winsorize(
                search_space=None, observations=bad_observations, config=config
            )
660+
645661

646662
# pyre-fixme[2]: Parameter must be annotated.
647663
def get_transform(observation_data, config=None, optimization_config=None) -> Winsorize:

ax/modelbridge/transforms/utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,30 @@ def __getitem__(self, key: Number) -> Any:
6666

6767

6868
def get_data(
    observation_data: List[ObservationData],
    metric_names: Union[List[str], None] = None,
    raise_on_non_finite_data: bool = True,
) -> Dict[str, List[float]]:
    """Collect the observed means of each metric across all observations.

    Extracts all metrics if `metric_names` is None; otherwise only the metrics
    listed in `metric_names` are collected.

    Args:
        observation_data: List of observation data.
        metric_names: List of metric names to extract. If None, every metric
            found in `observation_data` is extracted.
        raise_on_non_finite_data: If true, raises an exception as soon as a
            nan/inf mean is encountered for an extracted metric. If false,
            non-finite means are passed through unchanged.

    Returns:
        A dictionary mapping metric names to lists of metric values, in the
        order the values are encountered. The returned mapping is a
        `defaultdict(list)`.

    Raises:
        ValueError: If `raise_on_non_finite_data` is true and the mean of an
            extracted metric is nan or inf. Metrics filtered out by
            `metric_names` are not checked.
    """
    # The filter is consulted once per (observation, metric) pair, so build a
    # set up front instead of scanning the list on every membership test.
    selected = None if metric_names is None else set(metric_names)
    Ys = defaultdict(list)
    for obsd in observation_data:
        for i, m in enumerate(obsd.metric_names):
            if selected is not None and m not in selected:
                continue
            val = obsd.means[i]
            if raise_on_non_finite_data and not np.isfinite(val):
                raise ValueError(f"Non-finite data found for metric {m}: {val}")
            Ys[m].append(val)
    return Ys
7894

7995

ax/utils/testing/core_stubs.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from ax.core.metric import Metric
4242
from ax.core.multi_type_experiment import MultiTypeExperiment
4343
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
44-
from ax.core.observation import ObservationFeatures
44+
from ax.core.observation import Observation, ObservationData, ObservationFeatures
4545
from ax.core.optimization_config import (
4646
MultiObjectiveOptimizationConfig,
4747
OptimizationConfig,
@@ -1987,6 +1987,18 @@ def get_map_data(trial_index: int = 0) -> MapData:
19871987
)
19881988

19891989

1990+
def get_observations_with_invalid_value(invalid_value: float) -> List[Observation]:
    """Build a one-element observation list containing `invalid_value`.

    The single observation carries four means for metric "m1"
    (-100, 4, `invalid_value`, 2) with an identity covariance, and has empty
    observation features.

    Args:
        invalid_value: The value placed at the third position of the means,
            e.g. nan or inf.

    Returns:
        A list holding the single constructed `Observation`.
    """
    means = np.array([-100, 4, invalid_value, 2])
    data = ObservationData(
        metric_names=["m1", "m1", "m1", "m1"],
        means=means,
        covariance=np.eye(4),
    )
    return [Observation(features=ObservationFeatures({}), data=data)]
2000+
2001+
19902002
# pyre-fixme[24]: Generic type `MapKeyInfo` expects 1 type parameter.
19912003
def get_map_key_info() -> MapKeyInfo:
19922004
return MapKeyInfo(key="epoch", default_value=0.0)

0 commit comments

Comments
 (0)