Skip to content

Commit 53f5590

Browse files
Merge pull request #85 from hubverse-org/mc/time-series-parquet-files/70
support time-series.parquet files in addition to .csv
2 parents c2a96b3 + db0ae81 commit 53f5590

File tree

5 files changed

+32
-41
lines changed

5 files changed

+32
-41
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dependencies = [
1818
"structlog",
1919
"pyarrow",
2020
"jsonschema",
21-
"hubdata>=0.1.3",
21+
"hubdata>=0.2.0",
2222
]
2323

2424
[project.optional-dependencies]

requirements/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ click==8.2.1
1515
# hub-dashboard-predtimechart (pyproject.toml)
1616
# hubdata
1717
# pip-tools
18-
hubdata==0.1.3
18+
hubdata==0.2.0
1919
# via hub-dashboard-predtimechart (pyproject.toml)
2020
iniconfig==2.0.0
2121
# via pytest

src/hub_predtimechart/app/generate_target_json_files.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def main(hub_dir, ptc_config_file, target_out_dir, regenerate):
5151
try:
5252
target_data_df = hub_config.get_target_data_df()
5353
except FileNotFoundError as error:
54-
logger.error(f"target data file not found. {hub_config.get_target_data_file_name()=}, {error=}")
54+
logger.error(f"target data file not found. {error=}")
5555
sys.exit(1)
5656

5757
json_files = _generate_target_json_files(hub_config, target_data_df, target_out_dir, regenerate)
@@ -131,7 +131,7 @@ def ptc_target_data(model_task: ModelTask, target_data_df: pl.DataFrame, task_id
131131
if max_as_of is None:
132132
return None
133133
else:
134-
target_data_df = target_data_df.filter(pl.col('as_of') == max_as_of.isoformat())
134+
target_data_df = target_data_df.filter(pl.col('as_of') == max_as_of)
135135
else:
136136
# the file is one that is assumed to be updated weekly and so we can
137137
# assume that the effective as_of date for this file is the same as
@@ -158,30 +158,30 @@ def ptc_target_data(model_task: ModelTask, target_data_df: pl.DataFrame, task_id
158158
if len(target_data_df) == 0:
159159
return None
160160

161+
# date column type depends on data source: date objects from `connect_target_data()`, strings from custom CSV files.
162+
# convert date objects to ISO strings for JSON serialization; pass through strings as-is.
161163
return {
162-
'date': target_data_df[target_date_col_name].to_list(),
164+
'date': [d.isoformat() if isinstance(d, date) else d for d in target_data_df[target_date_col_name].to_list()],
163165
'y': target_data_df[observation_col_name].to_list()
164166
}
165167

166168

167-
def _max_as_of_le_reference_date(target_data_df: pl.DataFrame, viz_target_id: str, reference_date: str) -> date:
169+
def _max_as_of_le_reference_date(target_data_df: pl.DataFrame, viz_target_id: str, reference_date: str) -> date | None:
168170
"""
169171
ptc_target_data() helper
170172
171-
:param target_data_df: a pl.DataFrame that loaded from HubConfigPtc.target_data_file_name. assumes follows our new
173+
:param target_data_df: a pl.DataFrame from connect_target_data(). assumes follows the
172174
time-series target data standard - has `as_of` column, etc.
173175
:param viz_target_id: the target of interest. via ModelTask.viz_target_id
174176
:param reference_date: string naming the reference_date of interest
175177
:return: max as_of that's <= `reference_date` for `viz_target_id`. return None if not found
176178
"""
177179
reference_date = date.fromisoformat(reference_date)
178-
unique_as_ofs = [date.fromisoformat(as_of) for as_of in pl.Series(target_data_df
179-
.filter(pl.col('target') == viz_target_id)
180-
.unique('as_of')
181-
.select('as_of')
182-
.sort('as_of'))] # sort for debugging
183-
le_as_ofs = [as_of for as_of in unique_as_ofs if as_of <= reference_date]
184-
return max(le_as_ofs) if le_as_ofs else None
180+
return (target_data_df
181+
.filter(pl.col('target') == viz_target_id)
182+
.filter(pl.col('as_of') <= reference_date)
183+
.select(pl.col('as_of').max())
184+
.item())
185185

186186

187187
#

src/hub_predtimechart/hub_config_ptc.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import polars as pl
99
import yaml
1010
from hubdata import HubConnection
11+
from hubdata.connect_target_data import TargetType, connect_target_data
1112
from jsonschema import FormatChecker, ValidationError, validate
1213

1314
from hub_predtimechart.ptc_schema import ptc_config_schema
@@ -117,10 +118,24 @@ def model_output_file_for_ref_date(self, model_id: str, reference_date: str) ->
117118

118119
def get_target_data_df(self) -> pl.DataFrame:
119120
"""
120-
Loads the target data csv file from the hub repo for now, file path for target data is hard coded to 'target-data'.
121-
Raises FileNotFoundError if target data file does not exist.
121+
Loads the target data file from the hub repo. Uses `hubdata.connect_target_data()` for standard target data
122+
locations (time-series.csv, time-series.parquet, or time-series/ directory). Falls back to custom file reading
123+
when `target_data_file_name` is specified in the predtimechart config.
124+
125+
:return: target data as a polars DataFrame
126+
:raises FileNotFoundError: if target data file does not exist
127+
:raises ValueError: if target data file has unsupported format (custom file only)
122128
"""
123-
target_data_file_path = self.hub_path / 'target-data' / self.get_target_data_file_name()
129+
# non-custom file name case: use `hubdata.connect_target_data()` for standard file locations
130+
if not self.target_data_file_name:
131+
try:
132+
target_conn = connect_target_data(self.hub_path, TargetType.TIME_SERIES)
133+
return pl.from_arrow(target_conn.to_table())
134+
except RuntimeError as error:
135+
raise FileNotFoundError(f"target data not found via hubdata. {error=}")
136+
137+
# custom file name case
138+
target_data_file_path = self.hub_path / 'target-data' / self.target_data_file_name
124139
try:
125140
# the override schema handles the 'US' location (the only location that doesn't parse as Int64)
126141
# todo hard-coded column names
@@ -132,13 +147,6 @@ def get_target_data_df(self) -> pl.DataFrame:
132147
raise FileNotFoundError(f"target data file not found. {target_data_file_path=}, {error=}")
133148

134149

135-
def get_target_data_file_name(self):
136-
"""
137-
:return: the target data file name under the "target-data" dir to use
138-
"""
139-
return self.target_data_file_name if self.target_data_file_name else 'time-series.csv'
140-
141-
142150
def _validate_predtimechart_config(ptc_config: dict, tasks: dict):
143151
"""
144152
Validates `ptc_config` against the schema in ptc_schema.py.

tests/hub_predtimechart/test_hub_config_ptc.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -215,23 +215,6 @@ def test_task_id_text_covid19_forecast_hub():
215215
}
216216

217217

218-
def test_get_target_data_file_name():
219-
# hub that predates the new target data standard file name. specifies file name in hub_config.target_data_file_name
220-
hub_path = Path('tests/hubs/covid19-forecast-hub')
221-
hub_config = HubConfigPtc(hub_path, hub_path / 'hub-config/predtimechart-config.yml')
222-
assert hub_config.get_target_data_file_name() == 'covid-hospital-admissions.csv'
223-
224-
# hub that predates the new target data standard file name. specifies file name in hub_config.target_data_file_name
225-
hub_path = Path('tests/hubs/FluSight-forecast-hub')
226-
hub_config = HubConfigPtc(hub_path, hub_path / 'hub-config/predtimechart-config.yml')
227-
assert hub_config.get_target_data_file_name() == 'target-hospital-admissions.csv'
228-
229-
# hub that uses the new target data standard file name: "target-data/time-series.csv"
230-
hub_path = Path('tests/hubs/flu-metrocast')
231-
hub_config = HubConfigPtc(hub_path, hub_path / 'hub-config/predtimechart-config.yml')
232-
assert hub_config.get_target_data_file_name() == 'time-series.csv'
233-
234-
235218
def test_model_tasks_instance_list():
236219
for hub_path, exp_len in [('tests/hubs/covid19-forecast-hub', 1),
237220
('tests/hubs/example-complex-forecast-hub', 1),

0 commit comments

Comments
 (0)