|
1 | 1 | import abc |
2 | 2 | import collections |
3 | 3 | import contextlib |
| 4 | +import dataclasses |
4 | 5 | import datetime |
5 | 6 | import json |
6 | 7 | import logging |
|
9 | 10 | import warnings |
10 | 11 | from pathlib import Path |
11 | 12 | from threading import Thread |
12 | | -from typing import Callable, Dict, List, NamedTuple, Optional, Union |
| 13 | +from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union |
13 | 14 |
|
14 | 15 | import numpy |
15 | 16 | import pandas as pd |
@@ -104,6 +105,14 @@ def _start_job_default(row: pd.Series, connection: Connection, *args, **kwargs): |
104 | 105 | raise NotImplementedError("No 'start_job' callable provided") |
105 | 106 |
|
106 | 107 |
|
| 108 | +@dataclasses.dataclass(frozen=True) |
| 109 | +class _ColumnProperties: |
| 110 | + """Expected/required properties of a column in the job manager related dataframes""" |
| 111 | + |
| 112 | + dtype: str = "object" |
| 113 | + default: Any = None |
| 114 | + |
| 115 | + |
107 | 116 | class MultiBackendJobManager: |
108 | 117 | """ |
109 | 118 | Tracker for multiple jobs on multiple backends. |
@@ -171,6 +180,24 @@ def start_job( |
171 | 180 | Added ``cancel_running_job_after`` parameter. |
172 | 181 | """ |
173 | 182 |
|
| 183 | + # Expected columns in the job DB dataframes. |
| 184 | + # Mapping of column name to (dtype, default value) |
| 185 | + # TODO: make this part of public API when settled? |
| 186 | + _COLUMN_REQUIREMENTS: Mapping[str, _ColumnProperties] = { |
| 187 | + "id": _ColumnProperties(dtype="str"), |
| 188 | + "backend_name": _ColumnProperties(dtype="str"), |
| 189 | + "status": _ColumnProperties(dtype="str", default="not_started"), |
| 190 | +        # TODO: use proper date/time dtype instead of legacy str for start times? |
| 191 | + "start_time": _ColumnProperties(dtype="str"), |
| 192 | + "running_start_time": _ColumnProperties(dtype="str"), |
| 193 | + # TODO: these columns "cpu", "memory", "duration" are not referenced explicitly from MultiBackendJobManager, |
| 194 | + # but are indirectly coupled through handling of VITO-specific "usage" metadata in `_track_statuses`. |
| 195 | + # Since bfd99e34 they are not really required to be present anymore, can we make that more explicit? |
| 196 | + "cpu": _ColumnProperties(dtype="str"), |
| 197 | + "memory": _ColumnProperties(dtype="str"), |
| 198 | + "duration": _ColumnProperties(dtype="str"), |
| 199 | + } |
| 200 | + |
174 | 201 | def __init__( |
175 | 202 | self, |
176 | 203 | poll_sleep: int = 60, |
@@ -267,31 +294,16 @@ def _make_resilient(connection): |
267 | 294 | connection.session.mount("https://", HTTPAdapter(max_retries=retries)) |
268 | 295 | connection.session.mount("http://", HTTPAdapter(max_retries=retries)) |
269 | 296 |
|
270 | | - @staticmethod |
271 | | - def _normalize_df(df: pd.DataFrame) -> pd.DataFrame: |
| 297 | + @classmethod |
| 298 | + def _normalize_df(cls, df: pd.DataFrame) -> pd.DataFrame: |
272 | 299 | """ |
273 | 300 | Normalize given pandas dataframe (creating a new one): |
274 | 301 | ensure we have the required columns. |
275 | 302 |
|
276 | 303 | :param df: The dataframe to normalize. |
277 | 304 | :return: a new dataframe that is normalized. |
278 | 305 | """ |
279 | | - # check for some required columns. |
280 | | - required_with_default = [ |
281 | | - ("status", "not_started"), |
282 | | - ("id", None), |
283 | | - ("start_time", None), |
284 | | - ("running_start_time", None), |
285 | | - # TODO: columns "cpu", "memory", "duration" are not referenced directly |
286 | | - # within MultiBackendJobManager making it confusing to claim they are required. |
287 | | - # However, they are through assumptions about job "usage" metadata in `_track_statuses`. |
288 | | - # => proposed solution: allow to configure usage columns when adding a backend |
289 | | - ("cpu", None), |
290 | | - ("memory", None), |
291 | | - ("duration", None), |
292 | | - ("backend_name", None), |
293 | | - ] |
294 | | - new_columns = {col: val for (col, val) in required_with_default if col not in df.columns} |
| 306 | + new_columns = {col: req.default for (col, req) in cls._COLUMN_REQUIREMENTS.items() if col not in df.columns} |
295 | 307 | df = df.assign(**new_columns) |
296 | 308 |
|
297 | 309 | return df |
@@ -832,7 +844,10 @@ def _is_valid_wkt(self, wkt: str) -> bool: |
832 | 844 | return False |
833 | 845 |
|
834 | 846 | def read(self) -> pd.DataFrame: |
835 | | - df = pd.read_csv(self.path) |
| 847 | + df = pd.read_csv( |
| 848 | + self.path, |
| 849 | + dtype={c: r.dtype for (c, r) in MultiBackendJobManager._COLUMN_REQUIREMENTS.items()}, |
| 850 | + ) |
836 | 851 | if ( |
837 | 852 | "geometry" in df.columns |
838 | 853 | and df["geometry"].dtype.name != "geometry" |
|
0 commit comments