|
11 | 11 | import logging
|
12 | 12 |
|
13 | 13 | from logging import Logger
|
14 |
| -from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Set, Union |
| 14 | +from typing import Any, Dict, List, Optional |
15 | 15 |
|
16 | 16 | import pandas as pd
|
17 | 17 | from ax.core.base_trial import BaseTrial
|
18 | 18 | from ax.core.map_data import MapData, MapKeyInfo
|
19 | 19 | from ax.core.map_metric import MapMetric
|
20 | 20 | from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
|
21 | 21 | from ax.core.trial import Trial
|
22 |
| -from ax.metrics.curve import AbstractCurveMetric |
23 | 22 | from ax.utils.common.logger import get_logger
|
24 | 23 | from ax.utils.common.result import Err, Ok
|
25 | 24 | from pyre_extensions import assert_is_instance
|
|
33 | 32 | from tensorboard.backend.event_processing import (
|
34 | 33 | plugin_event_multiplexer as event_multiplexer,
|
35 | 34 | )
|
36 |
| - from tensorboard.compat.proto import types_pb2 |
37 | 35 |
|
38 | 36 | logging.getLogger("tensorboard").setLevel(logging.CRITICAL)
|
39 | 37 |
|
@@ -218,120 +216,9 @@ def _get_event_multiplexer_for_trial(
|
218 | 216 |
|
219 | 217 | return mul
|
220 | 218 |
|
221 |
| - class TensorboardCurveMetric(AbstractCurveMetric): |
222 |
| - """A `CurveMetric` for getting Tensorboard curves.""" |
223 |
| - |
224 |
| - map_key_info: MapKeyInfo[float] = MapKeyInfo(key="steps", default_value=0.0) |
225 |
| - |
226 |
| - def get_curves_from_ids( |
227 |
| - self, |
228 |
| - ids: Iterable[Union[int, str]], |
229 |
| - names: Optional[Set[str]] = None, |
230 |
| - ) -> Dict[Union[int, str], Dict[str, pd.Series]]: |
231 |
| - """Get curve data from tensorboard logs. |
232 |
| -
|
233 |
| - NOTE: If the ids are not simple paths/posix locations, subclass this metric |
234 |
| - and replace this method with an appropriate one that retrieves the log |
235 |
| - results. |
236 |
| -
|
237 |
| - Args: |
238 |
| - ids: A list of string paths to tensorboard log directories. |
239 |
| - names: The names of the tags for which to fetch the curves. |
240 |
| - If omitted, all tags are returned. |
241 |
| -
|
242 |
| - Returns: |
243 |
| - A nested dictionary mapping ids (first level) and metric names (second |
244 |
| - level) to pandas Series of data. |
245 |
| - """ |
246 |
| - return {idx: get_tb_from_posix(path=str(idx), tags=names) for idx in ids} |
247 |
| - |
248 |
| - def get_tb_from_posix( |
249 |
| - path: str, |
250 |
| - tags: Optional[Set[str]] = None, |
251 |
| - ) -> Dict[str, pd.Series]: |
252 |
| - r"""Get Tensorboard data from a posix path. |
253 |
| -
|
254 |
| - Args: |
255 |
| - path: The posix path for the directory that contains the tensorboard logs. |
256 |
| - tags: The names of the tags for which to fetch the curves. If omitted, |
257 |
| - all tags are returned. |
258 |
| - Returns: |
259 |
| - A dictionary mapping tag names to pandas Series of data. |
260 |
| - """ |
261 |
| - logger.debug(f"Reading TB logs from {path}.") |
262 |
| - mul = event_multiplexer.EventMultiplexer(max_reload_threads=20) |
263 |
| - mul.AddRunsFromDirectory(path, None) |
264 |
| - mul.Reload() |
265 |
| - scalar_dict = mul.PluginRunToTagToContent("scalars") |
266 |
| - |
267 |
| - raw_result = [ |
268 |
| - {"tag": tag, "event": mul.Tensors(run, tag)} |
269 |
| - for run, run_dict in scalar_dict.items() |
270 |
| - for tag in run_dict |
271 |
| - if tags is None or tag in tags |
272 |
| - ] |
273 |
| - tb_run_data = {} |
274 |
| - for item in raw_result: |
275 |
| - latest_start_time = _get_latest_start_time(item["event"]) |
276 |
| - steps = [e.step for e in item["event"] if e.wall_time >= latest_start_time] |
277 |
| - vals = [ |
278 |
| - _get_event_value(e) |
279 |
| - for e in item["event"] |
280 |
| - if e.wall_time >= latest_start_time |
281 |
| - ] |
282 |
| - key = item["tag"] |
283 |
| - series = pd.Series(index=steps, data=vals).dropna() |
284 |
| - if key in tb_run_data: |
285 |
| - tb_run_data[key] = pd.concat(objs=[tb_run_data[key], series]) |
286 |
| - else: |
287 |
| - tb_run_data[key] = series |
288 |
| - for key, series in tb_run_data.items(): |
289 |
| - if any(series.index.duplicated()): |
290 |
| - # take average of repeated observations of the same "step" |
291 |
| - series = series.groupby(series.index).mean() |
292 |
| - logger.debug( |
293 |
| - f"Found duplicate steps for tag {key}. " |
294 |
| - "Removing duplicates by averaging." |
295 |
| - ) |
296 |
| - tb_run_data[key] = series |
297 |
| - return tb_run_data |
298 |
| - |
299 |
| - # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use |
300 |
| - # `typing.List` to avoid runtime subscripting errors. |
301 |
| - def _get_latest_start_time(events: List) -> float: |
302 |
| - """In each directory, there may be previous training runs due to restarting |
303 |
| - training jobs. |
304 |
| -
|
305 |
| - Args: |
306 |
| - events: A list of TensorEvents. |
307 |
| -
|
308 |
| - Returns: |
309 |
| - The start time of the latest training run. |
310 |
| - """ |
311 |
| - events.sort(key=lambda e: e.wall_time) |
312 |
| - start_time = events[0].wall_time |
313 |
| - for i in range(1, len(events)): |
314 |
| - # detect points in time where restarts occurred |
315 |
| - if events[i].step < events[i - 1].step: |
316 |
| - start_time = events[i].wall_time |
317 |
| - return start_time |
318 |
| - |
319 |
| - def _get_event_value(e: NamedTuple) -> float: |
320 |
| - r"""Helper function to check the dtype and then get the value |
321 |
| - stored in a TensorEvent.""" |
322 |
| - tensor = e.tensor_proto # pyre-ignore[16] |
323 |
| - if tensor.dtype == types_pb2.DT_FLOAT: |
324 |
| - return tensor.float_val[0] |
325 |
| - elif tensor.dtype == types_pb2.DT_DOUBLE: |
326 |
| - return tensor.double_val[0] |
327 |
| - elif tensor.dtype == types_pb2.DT_INT32: |
328 |
| - return tensor.int_val[0] |
329 |
| - else: |
330 |
| - raise ValueError(f"Tensorboard dtype {tensor.dtype} not supported.") |
331 |
| - |
332 | 219 | except ImportError:
|
333 | 220 | logger.warning(
|
334 | 221 | "tensorboard package not found. If you would like to use "
|
335 |
| - "TensorboardCurveMetric, please install tensorboard." |
| 222 | + "TensorboardMetric, please install tensorboard." |
336 | 223 | )
|
337 | 224 | pass
|
0 commit comments