77import pandas as pd
88import polars as pl
99import yaml
10- from jsonschema import ValidationError , validate , FormatChecker
10+ from hubdata import HubConnection
11+ from jsonschema import FormatChecker , ValidationError , validate
1112
12- from hub_predtimechart .hub_config import HubConfig
1313from hub_predtimechart .ptc_schema import ptc_config_schema
1414
1515
16- class HubConfigPtc (HubConfig ):
16+ class HubConfigPtc (HubConnection ):
1717 """
18- A HubConfig subclass that adds various visualization-related variables from a hub.
18+ A `hubdata.HubConnection` subclass that adds various visualization-related variables from a hub. Note that this
19+ class only works with local filesystems, and therefore only accepts a Path for `hub_path`.
1920
2021 Instance variables:
22+
23+ Via HubConnection:
24+ - hub_path: str or Path pointing to a hub's root directory as passed to `hubdata.connect_hub()`
25+ - tasks: the hub's `tasks.json` contents as a dict
26+ - model_metadata_schema: "" `model-metadata-schema.json` ""
27+
28+ Via this class:
2129 - rounds_idx: as loaded from `ptc_config_file`
2230 - reference_date_col_name: ""
2331 - horizon_col_name: ""
@@ -36,13 +44,16 @@ class HubConfigPtc(HubConfig):
3644 """
3745
3846
39- def __init__ (self , hub_dir : Path , ptc_config_file : Path ):
47+ def __init__ (self , hub_path : Path , ptc_config_file : Path ):
4048 """
41- :param hub_dir: as defined in HubConfig.__init__()
49+ :param hub_path: Path pointing to a hub's root directory as passed to `hubdata.connect_hub()`
4250 :param ptc_config_file: location of `predtimechart-config.yml` (or other named) file that matches ptc_schema.py.
43- this file specifies how to process `hub_dir ` to get predtimechart output
51+ this file specifies how to process `hub_path ` to get predtimechart output
4452 """
45- super ().__init__ (hub_dir )
53+ if not isinstance (hub_path , Path ):
54+ raise TypeError (f"hub_path was not a Path. hub_path={ hub_path !r} , type={ type (hub_path ).__name__ } " )
55+
56+ super ().__init__ (hub_path )
4657
4758 if not ptc_config_file .exists ():
4859 raise RuntimeError (f"predtimechart config file not found: { ptc_config_file } " )
@@ -67,8 +78,8 @@ def __init__(self, hub_dir: Path, ptc_config_file: Path):
6778
6879 # set model_id_to_metadata
6980 self .model_id_to_metadata : dict [str , dict ] = {}
70- for model_metadata_file in (list ((self .hub_dir / 'model-metadata' ).glob ('*.yml' )) +
71- list ((self .hub_dir / 'model-metadata' ).glob ('*.yaml' ))):
81+ for model_metadata_file in (list ((self .hub_path / 'model-metadata' ).glob ('*.yml' )) +
82+ list ((self .hub_path / 'model-metadata' ).glob ('*.yaml' ))):
7283 with open (model_metadata_file ) as fp :
7384 model_metadata = yaml .safe_load (fp )
7485 model_id = f"{ model_metadata ['team_abbr' ]} -{ model_metadata ['model_abbr' ]} "
@@ -90,8 +101,8 @@ def model_output_file_for_ref_date(self, model_id: str, reference_date: str) ->
90101 Returns a Path to the model output file corresponding to `model_id` and `reference_date`. Returns None if none
91102 found.
92103 """
93- poss_output_files = [self .hub_dir / 'model-output' / model_id / f"{ reference_date } -{ model_id } .csv" ,
94- self .hub_dir / 'model-output' / model_id / f"{ reference_date } -{ model_id } .parquet" ]
104+ poss_output_files = [self .hub_path / 'model-output' / model_id / f"{ reference_date } -{ model_id } .csv" ,
105+ self .hub_path / 'model-output' / model_id / f"{ reference_date } -{ model_id } .parquet" ]
95106 for poss_output_file in poss_output_files :
96107 if poss_output_file .exists ():
97108 return poss_output_file
@@ -104,7 +115,7 @@ def get_target_data_df(self) -> pl.DataFrame:
104115 Loads the target data csv file from the hub repo for now, file path for target data is hard coded to 'target-data'.
105116 Raises FileNotFoundError if target data file does not exist.
106117 """
107- target_data_file_path = self .hub_dir / 'target-data' / self .get_target_data_file_name ()
118+ target_data_file_path = self .hub_path / 'target-data' / self .get_target_data_file_name ()
108119 try :
109120 # the override schema handles the 'US' location (the only location that doesn't parse as Int64)
110121 # todo hard-coded column names
@@ -165,8 +176,10 @@ def _validate_hub_ptc_compatibility(hub_config_ptc: HubConfigPtc):
165176 if not hub_config_ptc .model_tasks :
166177 raise ValidationError (f"no applicable model_task entries were found" )
167178
168- # validate: model metadata must contain a boolean `designated_model` field
169- if 'designated_model' not in hub_config_ptc .model_metadata_schema ['required' ]:
179+ # validate: model metadata must be present and must contain a boolean `designated_model` field
180+ if hub_config_ptc .model_metadata_schema is None :
181+ raise ValidationError (f"model metadata schema not found" )
182+ elif 'designated_model' not in hub_config_ptc .model_metadata_schema ['required' ]:
170183 raise ValidationError (f"'designated_model' not found in model metadata schema's 'required' section" )
171184
172185 # validate: all model_task entries have the same task_ids. frozenset lets us make a set of sets
0 commit comments