|
10 | 10 | from typing import List, Sequence
|
11 | 11 | from unittest import mock
|
12 | 12 |
|
| 13 | +import numpy as np |
| 14 | + |
13 | 15 | import pandas as pd
|
14 | 16 | from ax.core.map_data import MapData
|
| 17 | +from ax.core.metric import MetricFetchE |
15 | 18 | from ax.metrics.tensorboard import TensorboardMetric
|
16 | 19 | from ax.utils.common.testutils import TestCase
|
17 | 20 | from ax.utils.testing.core_stubs import get_trial
|
@@ -82,6 +85,61 @@ def test_fetch_trial_data(self) -> None:
|
82 | 85 |
|
83 | 86 | self.assertTrue(df.equals(expected_df))
|
84 | 87 |
|
| 88 | + def test_fetch_trial_data_with_bad_data(self) -> None: |
| 89 | + nan_data = [1, 2, np.nan, 4] |
| 90 | + nan_multiplexer = _get_fake_multiplexer(fake_data=nan_data) |
| 91 | + |
| 92 | + with mock.patch.object( |
| 93 | + TensorboardMetric, |
| 94 | + "_get_event_multiplexer_for_trial", |
| 95 | + return_value=nan_multiplexer, |
| 96 | + ): |
| 97 | + metric = TensorboardMetric( |
| 98 | + name="loss", |
| 99 | + tag="loss", |
| 100 | + ) |
| 101 | + |
| 102 | + trial = get_trial() |
| 103 | + |
| 104 | + result = metric.fetch_trial_data(trial=trial) |
| 105 | + |
| 106 | + err = assert_is_instance(result.unwrap_err(), MetricFetchE) |
| 107 | + self.assertEqual( |
| 108 | + err.message, |
| 109 | + "Failed to fetch data for loss", |
| 110 | + ) |
| 111 | + self.assertEqual( |
| 112 | + str(err.exception), |
| 113 | + "Found NaNs or Infs in data", |
| 114 | + ) |
| 115 | + |
| 116 | + inf_data = [1, 2, np.inf, 4] |
| 117 | + inf_multiplexer = _get_fake_multiplexer(fake_data=inf_data) |
| 118 | + |
| 119 | + with mock.patch.object( |
| 120 | + TensorboardMetric, |
| 121 | + "_get_event_multiplexer_for_trial", |
| 122 | + return_value=inf_multiplexer, |
| 123 | + ): |
| 124 | + metric = TensorboardMetric( |
| 125 | + name="loss", |
| 126 | + tag="loss", |
| 127 | + ) |
| 128 | + |
| 129 | + trial = get_trial() |
| 130 | + |
| 131 | + result = metric.fetch_trial_data(trial=trial) |
| 132 | + |
| 133 | + err = assert_is_instance(result.unwrap_err(), MetricFetchE) |
| 134 | + self.assertEqual( |
| 135 | + err.message, |
| 136 | + "Failed to fetch data for loss", |
| 137 | + ) |
| 138 | + self.assertEqual( |
| 139 | + str(err.exception), |
| 140 | + "Found NaNs or Infs in data", |
| 141 | + ) |
| 142 | + |
85 | 143 | def test_smoothing(self) -> None:
|
86 | 144 | fake_data = [8.0, 4.0, 2.0, 1.0]
|
87 | 145 | smoothing = 0.5
|
|
0 commit comments