Skip to content

Commit a07f586

Browse files
authored
Merge pull request #79 from hubverse-org/fix/location-dtype-inference
Fix dtype inference for models with numeric-only location codes
2 parents 5917de1 + 27de9dc commit a07f586

File tree

9 files changed

+94
-10
lines changed

9 files changed

+94
-10
lines changed

src/hub_predtimechart/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2.2.4"
1+
__version__ = "2.2.5"

src/hub_predtimechart/app/generate_json_files.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import click
77
import pandas as pd
8+
import pyarrow.compute as pc
89
import structlog
910

1011
from hub_predtimechart.generate_data import forecast_data_for_model_df
@@ -86,13 +87,14 @@ def _generate_forecast_json_files(hub_config: HubConfigPtc, output_dir: Path, is
8687
for model_id in hub_config.model_id_to_metadata: # ex: ['Flusight-baseline', 'MOBS-GLEAM_FLUH', ...]
8788
model_output_file = hub_config.model_output_file_for_ref_date(model_id, reference_date)
8889
if model_output_file:
89-
if model_output_file.suffix == '.csv':
90-
model_id_to_df[model_id] = pd.read_csv(model_output_file, usecols=df_cols_to_use)
91-
elif model_output_file.suffix in ['.parquet', '.pqt']:
92-
model_id_to_df[model_id] = pd.read_parquet(model_output_file, columns=df_cols_to_use)
93-
else:
94-
raise RuntimeError(f"unsupported model output file type: {model_output_file!r}. "
95-
f"Only .csv and .parquet are supported")
90+
# Use hubdata's to_table() method with filtering to load only this model's data
91+
# for this reference_date. This applies the schema from tasks.json, ensuring
92+
# task_id columns (like location) are properly typed as strings, preventing
93+
# dtype inference issues with numeric-only values like "01", "02"
94+
filter_expr = (pc.field('model_id') == model_id) & \
95+
(pc.field(hub_config.reference_date_col_name) == date.fromisoformat(reference_date))
96+
pa_table = hub_config.to_table(columns=df_cols_to_use, filter=filter_expr)
97+
model_id_to_df[model_id] = pa_table.to_pandas()
9698

9799
if not model_id_to_df: # no model outputs for reference_date
98100
continue

tests/expected/example-complex-forecast-hub/forecasts/wk-inc-flu-hosp_01_2022-10-22.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,13 @@
2222
"q0.5": [118, 175],
2323
"q0.75": [133, 193],
2424
"q0.975": [165, 233]
25+
},
26+
"Test-NumericOnly": {
27+
"target_end_date": ["2022-10-29", "2022-11-05"],
28+
"q0.025": [1542, 1443],
29+
"q0.25": [1542, 1443],
30+
"q0.5": [1724, 1724],
31+
"q0.75": [1906, 2006],
32+
"q0.975": [2334, 2562]
2533
}
2634
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
{
2+
"Test-NumericOnly": {
3+
"target_end_date": [
4+
"2022-10-29",
5+
"2022-11-05"
6+
],
7+
"q0.025": [
8+
100.0,
9+
110.0
10+
],
11+
"q0.25": [
12+
150.0,
13+
160.0
14+
],
15+
"q0.5": [
16+
200.0,
17+
210.0
18+
],
19+
"q0.75": [
20+
250.0,
21+
260.0
22+
],
23+
"q0.975": [
24+
300.0,
25+
310.0
26+
]
27+
}
28+
}

tests/hub_predtimechart/test_generate_json_files.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def test_generate_forecast_json_files_ecfh(tmp_path):
1616
json_files = _generate_forecast_json_files(hub_config, output_dir)
1717
assert set(json_files) == {output_dir / 'wk-inc-flu-hosp_US_2022-10-22.json',
1818
output_dir / 'wk-inc-flu-hosp_01_2022-10-22.json',
19+
output_dir / 'wk-inc-flu-hosp_02_2022-10-22.json',
1920
output_dir / 'wk-inc-flu-hosp_US_2022-11-19.json',
2021
output_dir / 'wk-inc-flu-hosp_01_2022-11-19.json',
2122
output_dir / 'wk-inc-flu-hosp_US_2022-12-17.json',
@@ -28,6 +29,7 @@ def test_generate_forecast_json_files_ecfh(tmp_path):
2829
assert act_data == exp_data
2930

3031

32+
3133
def test_generate_forecast_json_files_flu_metrocast(tmp_path):
3234
"""
3335
An integration test of `generate_json_files.py`'s `_generate_json_files()` for flu-metrocast.
@@ -90,6 +92,7 @@ def test_generate_forecast_json_files_skip_files(tmp_path):
9092
json_files = Path(output_dir).glob("*")
9193
assert set(json_files) == {output_dir / 'wk-inc-flu-hosp_US_2022-10-22.json',
9294
output_dir / 'wk-inc-flu-hosp_01_2022-10-22.json',
95+
output_dir / 'wk-inc-flu-hosp_02_2022-10-22.json',
9396
output_dir / 'wk-inc-flu-hosp_US_2022-11-19.json',
9497
output_dir / 'wk-inc-flu-hosp_01_2022-11-19.json',
9598
output_dir / 'wk-inc-flu-hosp_01_2022-12-17.json'}
@@ -122,6 +125,7 @@ def test_generate_forecast_json_files_regenerate(tmp_path):
122125
json_files = _generate_forecast_json_files(hub_config, output_dir, True)
123126
assert set(json_files) == {output_dir / 'wk-inc-flu-hosp_US_2022-10-22.json',
124127
output_dir / 'wk-inc-flu-hosp_01_2022-10-22.json',
128+
output_dir / 'wk-inc-flu-hosp_02_2022-10-22.json',
125129
output_dir / 'wk-inc-flu-hosp_US_2022-11-19.json',
126130
output_dir / 'wk-inc-flu-hosp_01_2022-11-19.json',
127131
output_dir / 'wk-inc-flu-hosp_US_2022-12-17.json',

tests/hub_predtimechart/test_hub_config_ptc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_hub_config_complex_forecast_hub():
2121
assert hub_config.initial_checked_models == ['Flusight-baseline']
2222
assert hub_config.disclaimer == "Most forecasts have failed to reliably predict rapid changes in the trends of reported cases and hospitalizations. Due to this limitation, they should not be relied upon for decisions about the possibility or timing of rapid changes in trends."
2323
assert (sorted(list(hub_config.model_id_to_metadata.keys())) ==
24-
sorted(['Flusight-baseline', 'MOBS-GLEAM_FLUH', 'PSI-DICE']))
24+
sorted(['Flusight-baseline', 'MOBS-GLEAM_FLUH', 'PSI-DICE', 'Test-NumericOnly']))
2525
assert hub_config.target_data_file_name == 'covid-hospital-admissions.csv'
2626

2727
model_task_0 = hub_config.model_tasks[0] # only one

tests/hubs/example-complex-forecast-hub/hub-config/admin.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"owner": "Infectious-Disease-Modeling-Hubs",
1212
"repository": "example-complex-forecast-hub"
1313
},
14-
"file_format": ["csv"],
14+
"file_format": ["csv", "parquet"],
1515
"timezone": "US/Eastern",
1616
"model_output_dir": "model-output",
1717
"cloud": {
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
team_name: "Test"
2+
team_abbr: "Test"
3+
model_name: "Test model for numeric-only location codes"
4+
model_abbr: "NumericOnly"
5+
model_version: "1.0"
6+
model_contributors: [
7+
{
8+
"name": "Test User",
9+
"email": "test@example.com"
10+
}
11+
]
12+
website_url: "https://github.com/hubverse-org/hub-dashboard-predtimechart"
13+
license: "MIT"
14+
team_funding: "N/A"
15+
designated_model: false
16+
methods: "Test model for regression testing dtype inference with numeric-only location codes like '01', '02'."
17+
data_inputs: "None - test data"
18+
methods_long: "This is a test model used to verify that models with exclusively numeric location codes (e.g., '01', '02') are properly handled and not excluded from dashboard visualizations due to dtype inference issues."
19+
ensemble_of_models: false
20+
ensemble_of_hub_models: false
21+
source_notes: "Test data for issue #78"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
location,horizon,output_type_id,value,target_end_date,reference_date,output_type,target
2+
"01",1,0.025,1542,2022-10-29,2022-10-22,quantile,wk inc flu hosp
3+
"01",1,0.25,1542,2022-10-29,2022-10-22,quantile,wk inc flu hosp
4+
"01",1,0.5,1724,2022-10-29,2022-10-22,quantile,wk inc flu hosp
5+
"01",1,0.75,1906,2022-10-29,2022-10-22,quantile,wk inc flu hosp
6+
"01",1,0.975,2334,2022-10-29,2022-10-22,quantile,wk inc flu hosp
7+
"01",2,0.025,1443,2022-11-05,2022-10-22,quantile,wk inc flu hosp
8+
"01",2,0.25,1443,2022-11-05,2022-10-22,quantile,wk inc flu hosp
9+
"01",2,0.5,1724,2022-11-05,2022-10-22,quantile,wk inc flu hosp
10+
"01",2,0.75,2006,2022-11-05,2022-10-22,quantile,wk inc flu hosp
11+
"01",2,0.975,2562,2022-11-05,2022-10-22,quantile,wk inc flu hosp
12+
"02",1,0.025,100,2022-10-29,2022-10-22,quantile,wk inc flu hosp
13+
"02",1,0.25,150,2022-10-29,2022-10-22,quantile,wk inc flu hosp
14+
"02",1,0.5,200,2022-10-29,2022-10-22,quantile,wk inc flu hosp
15+
"02",1,0.75,250,2022-10-29,2022-10-22,quantile,wk inc flu hosp
16+
"02",1,0.975,300,2022-10-29,2022-10-22,quantile,wk inc flu hosp
17+
"02",2,0.025,110,2022-11-05,2022-10-22,quantile,wk inc flu hosp
18+
"02",2,0.25,160,2022-11-05,2022-10-22,quantile,wk inc flu hosp
19+
"02",2,0.5,210,2022-11-05,2022-10-22,quantile,wk inc flu hosp
20+
"02",2,0.75,260,2022-11-05,2022-10-22,quantile,wk inc flu hosp
21+
"02",2,0.975,310,2022-11-05,2022-10-22,quantile,wk inc flu hosp

0 commit comments

Comments
 (0)