Skip to content

Commit 46464a5

Browse files
mpolson64 authored and
facebook-github-bot committed
Check for NaNs and Infs in TensorboardMetric (#2628)
Summary: Pull Request resolved: #2628. Raises a ValueError in bulk_fetch_trial_data if a NaN or an Inf is found. This will get wrapped up in a MetricFetchE and handled appropriately in the Scheduler (e.g., INFO if we intend to try and fetch again, WARN if coming from a tracking metric, mark trial as ABANDONED if the metric is needed for the optimization; see https://fburl.com/code/eq37gghi). Reviewed By: Balandat. Differential Revision: D60670356. fbshipit-source-id: 7f011f87c9ade9f1bf8b11a9ad2f2c34857bfd20
1 parent 8587587 commit 46464a5

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

ax/metrics/tensorboard.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@
1313
from logging import Logger
1414
from typing import Any, Dict, List, Optional
1515

16+
import numpy as np
17+
1618
import pandas as pd
1719
from ax.core.base_trial import BaseTrial
1820
from ax.core.map_data import MapData, MapKeyInfo
@@ -166,6 +168,10 @@ def bulk_fetch_trial_data(
166168
.reset_index()
167169
)
168170

171+
# If there are any NaNs or Infs in the data, raise an Exception
172+
if np.any(~np.isfinite(df["mean"])):
173+
raise ValueError("Found NaNs or Infs in data")
174+
169175
# Apply per-metric post-processing
170176
# Apply cumulative "best" (min if lower_is_better)
171177
if metric.cumulative_best:

ax/metrics/tests/test_tensorboard.py

Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,11 @@
1010
from typing import List, Sequence
1111
from unittest import mock
1212

13+
import numpy as np
14+
1315
import pandas as pd
1416
from ax.core.map_data import MapData
17+
from ax.core.metric import MetricFetchE
1518
from ax.metrics.tensorboard import TensorboardMetric
1619
from ax.utils.common.testutils import TestCase
1720
from ax.utils.testing.core_stubs import get_trial
@@ -82,6 +85,61 @@ def test_fetch_trial_data(self) -> None:
8285

8386
self.assertTrue(df.equals(expected_df))
8487

88+
def test_fetch_trial_data_with_bad_data(self) -> None:
    """Non-finite values in the fetched data surface as a ``MetricFetchE``.

    Each case feeds a fake multiplexer whose series contains exactly one
    non-finite entry and checks that ``fetch_trial_data`` returns an error
    result whose message and wrapped exception text match what the
    metric reports for NaNs/Infs.

    Negative infinity is exercised alongside NaN and positive infinity
    because it is equally non-finite and must be rejected the same way.
    """
    bad_values = {"nan": np.nan, "inf": np.inf, "-inf": -np.inf}

    for label, bad_value in bad_values.items():
        # subTest keeps the cases independent: a failure for one kind of
        # bad value does not prevent the remaining ones from running.
        with self.subTest(bad_value=label):
            multiplexer = _get_fake_multiplexer(fake_data=[1, 2, bad_value, 4])

            with mock.patch.object(
                TensorboardMetric,
                "_get_event_multiplexer_for_trial",
                return_value=multiplexer,
            ):
                metric = TensorboardMetric(
                    name="loss",
                    tag="loss",
                )

                trial = get_trial()

                result = metric.fetch_trial_data(trial=trial)

                err = assert_is_instance(result.unwrap_err(), MetricFetchE)
                self.assertEqual(
                    err.message,
                    "Failed to fetch data for loss",
                )
                self.assertEqual(
                    str(err.exception),
                    "Found NaNs or Infs in data",
                )
142+
85143
def test_smoothing(self) -> None:
86144
fake_data = [8.0, 4.0, 2.0, 1.0]
87145
smoothing = 0.5

0 commit comments

Comments
 (0)