@@ -5,23 +5,16 @@
 predictions.
 """

-from argparse import Namespace
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Any, TypedDict
+from typing import Any

-import joblib
 import pandas as pd
-import torch
-import yaml
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from sklearn.preprocessing import StandardScaler

-from gaca_ews.core.data_extraction import fetch_last_hours
-from gaca_ews.core.preprocessing import preprocess_for_inference
-from gaca_ews.model.gcngru import GCNGRU
+from gaca_ews.core.inference import InferenceEngine


 app = FastAPI(
@@ -40,26 +33,6 @@
 )


-class Config:
-    """Configuration holder for model and paths."""
-
-    def __init__(self, config_path: str = "config.yaml") -> None:
-        """Initialize configuration from YAML file."""
-        with open(config_path, "r") as f:
-            cfg_dict = yaml.safe_load(f)
-
-        # Convert to namespace-like object
-        def to_namespace(obj: Any) -> Any:
-            if isinstance(obj, dict):
-                ns = Namespace()
-                for key, value in obj.items():
-                    setattr(ns, key, to_namespace(value))
-                return ns
-            return obj
-
-        self.config = to_namespace(cfg_dict)
-
-
 class PredictionRequest(BaseModel):
     """Request model for inference."""

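The Config holder removed above, together with the load_model_artifacts() machinery removed in the next hunk, is absorbed into gaca_ews.core.inference.InferenceEngine, which this diff does not show. Pieced together from its call sites in the rest of the diff, the engine's surface presumably looks like the sketch below — an assumed interface, not the actual module. Note that engine.config["pred_offsets"] implies the engine keeps the raw YAML dict rather than the old Namespace conversion.

from datetime import datetime
from typing import Any

import pandas as pd


class InferenceEngine:
    """Assumed interface; the real gaca_ews.core.inference may differ."""

    def __init__(self, config_path: str = "config.yaml") -> None:
        # dict-style access is implied by engine.config["pred_offsets"]
        self.config: dict[str, Any] = {}
        # node lat/lon lookup table used when formatting /predict output
        self.nodes_df = pd.DataFrame(columns=["lat", "lon"])

    def load_artifacts(self) -> None:
        """Load model weights, scalers, and graph tensors once at startup."""

    def get_model_info(self) -> dict[str, Any]:
        """Return the model_architecture, num_nodes, input_features,
        prediction_horizons, and region keys consumed by /model/info."""
        return {}

    def run_full_pipeline(self) -> tuple[Any, datetime]:
        """Fetch, preprocess, predict, and inverse-transform in one call.

        Returns (predictions, latest_ts); /predict indexes predictions
        as [batch, horizon, node, target].
        """
        raise NotImplementedError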
@@ -87,94 +60,11 @@ class ModelInfo(BaseModel):
     status: str


-class ModelState(TypedDict, total=False):
-    """Type definition for global model state."""
-
-    model: GCNGRU
-    config: Namespace
-    artifacts_loaded: bool
-    feature_scaler: StandardScaler
-    target_scaler: StandardScaler
-    edge_index: torch.Tensor
-    edge_weight: torch.Tensor
-    nodes_df: pd.DataFrame
-    device: torch.device
-
-
-# Global state
-model_state: ModelState = {"artifacts_loaded": False}
-
-
-def load_model_artifacts() -> None:
-    """Load model and artifacts into memory."""
-    if model_state["artifacts_loaded"]:
-        return
-
-    config = Config()
-    cfg = config.config
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # Load scalers
-    feature_scaler = joblib.load(cfg.model.feature_scaler_path)
-    target_scaler = joblib.load(cfg.model.target_scaler_path)
-
-    # Load graph components
-    edge_index = torch.load(cfg.graph.edge_index_path, map_location=device)
-    edge_weight = torch.load(cfg.graph.edge_weight_path, map_location=device)
-    nodes_df = pd.read_csv(cfg.graph.nodes_csv_path)
-
-    # Load model
-    checkpoint = torch.load(cfg.model.model_path, map_location=device)
-    model_class = checkpoint["model_class"]
-    model_args = checkpoint["model_args"]
-    raw_state = checkpoint["model_state_dict"]
-
-    # Clean state dict
-    clean_state = {}
-    for k, v in raw_state.items():
-        if k.startswith("module."):
-            clean_state[k.replace("module.", "", 1)] = v
-        else:
-            clean_state[k] = v
-
-    # Adjust model args
-    if "pred_offsets" in model_args:
-        model_args["pred_horizons"] = len(model_args["pred_offsets"])
-        del model_args["pred_offsets"]
-    if "model_name" in model_args:
-        del model_args["model_name"]
-
-    model_args["num_nodes"] = edge_index.max().item() + 1
-
-    if model_class == "DistributedDataParallel":
-        model_class = cfg.model_arch
-
-    if model_class == "GCNGRU":
-        model = GCNGRU(**model_args)
-    else:
-        raise ValueError(f"Unknown model: {model_class}")
-
-    model.load_state_dict(clean_state)
-    model = model.to(device)
-    model.eval()
-
-    # Store in global state
-    model_state["model"] = model
-    model_state["feature_scaler"] = feature_scaler
-    model_state["target_scaler"] = target_scaler
-    model_state["edge_index"] = edge_index
-    model_state["edge_weight"] = edge_weight
-    model_state["nodes_df"] = nodes_df
-    model_state["config"] = cfg
-    model_state["device"] = device
-    model_state["artifacts_loaded"] = True
-
-
 @app.on_event("startup")
 async def startup_event() -> None:
     """Load model on startup."""
-    load_model_artifacts()
+    app.state.engine = InferenceEngine("config.yaml")
+    app.state.engine.load_artifacts()


 @app.get("/")
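One wiring note: @app.on_event("startup") is deprecated in recent FastAPI releases (0.93+) in favor of a lifespan handler. The same engine setup could move over like this — a minimal sketch, not part of this change:

from contextlib import asynccontextmanager

from fastapi import FastAPI

from gaca_ews.core.inference import InferenceEngine


@asynccontextmanager
async def lifespan(app: FastAPI):
    # startup: load artifacts once before serving traffic
    app.state.engine = InferenceEngine("config.yaml")
    app.state.engine.load_artifacts()
    yield
    # shutdown: nothing to release


app = FastAPI(lifespan=lifespan)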
@@ -190,93 +80,51 @@ async def root() -> dict[str, str]:
 @app.get("/health")
 async def health() -> dict[str, str | bool]:
     """Health check endpoint."""
-    return {"status": "healthy", "model_loaded": model_state["artifacts_loaded"]}
+    return {"status": "healthy", "model_loaded": hasattr(app.state, "engine")}


 @app.get("/model/info", response_model=ModelInfo)
 async def get_model_info() -> ModelInfo:
     """Get information about the loaded model."""
-    if not model_state["artifacts_loaded"]:
+    if not hasattr(app.state, "engine"):
         raise HTTPException(status_code=503, detail="Model not loaded")

-    cfg = model_state["config"]
-    nodes_df = model_state["nodes_df"]
-
-    assert cfg is not None, "Config not loaded"
-    assert nodes_df is not None, "Nodes dataframe not loaded"
+    info = app.state.engine.get_model_info()

     return ModelInfo(
-        model_architecture=cfg.model_arch,
-        num_nodes=len(nodes_df),
-        input_features=cfg.features,
-        prediction_horizons=cfg.pred_offsets,
-        region={
-            "lat_min": cfg.region.lat_min,
-            "lat_max": cfg.region.lat_max,
-            "lon_min": cfg.region.lon_min,
-            "lon_max": cfg.region.lon_max,
-        },
+        model_architecture=info["model_architecture"],
+        num_nodes=info["num_nodes"],
+        input_features=info["input_features"],
+        prediction_horizons=info["prediction_horizons"],
+        region=info["region"],
         status="loaded",
     )


 @app.post("/predict", response_model=list[PredictionResponse])
 async def run_inference(request: PredictionRequest) -> list[PredictionResponse]:
     """Run inference and return predictions."""
-    if not model_state["artifacts_loaded"]:
+    if not hasattr(app.state, "engine"):
         raise HTTPException(status_code=503, detail="Model not loaded")

     try:
-        # Fetch data
-        cfg = model_state["config"]
-        assert cfg is not None, "Config not loaded"
-        df, latest_ts = fetch_last_hours(cfg)
-
-        # Preprocess
-        feature_scaler = model_state["feature_scaler"]
-        assert feature_scaler is not None, "Feature scaler not loaded"
-        X_seq, timestamps, in_channels, num_nodes = preprocess_for_inference(
-            data=df, feature_scaler=feature_scaler, args=cfg
-        )
-
-        # Run inference
-        device = model_state["device"]
-        assert device is not None, "Device not set"
-        X_test = X_seq.to(device)
-
-        model = model_state["model"]
-        edge_index = model_state["edge_index"]
-        edge_weight = model_state["edge_weight"]
-        assert model is not None, "Model not loaded"
-        assert edge_index is not None, "Edge index not loaded"
-        assert edge_weight is not None, "Edge weight not loaded"
-
-        with torch.no_grad():
-            preds_scaled = model(X_test, edge_index, edge_weight)
-
-        # Inverse transform
-        target_scaler = model_state["target_scaler"]
-        assert target_scaler is not None, "Target scaler not loaded"
-        preds_np = preds_scaled.cpu().numpy()
-        preds_unscaled = target_scaler.inverse_transform(
-            preds_np.reshape(-1, preds_np.shape[-1])
-        ).reshape(preds_np.shape)
+        # Run full pipeline
+        predictions, latest_ts = app.state.engine.run_full_pipeline()

         # Format response
-        nodes_df = model_state["nodes_df"]
-        assert nodes_df is not None, "Nodes dataframe not loaded"
-        num_pred_nodes = preds_unscaled.shape[2]
+        num_pred_nodes = predictions.shape[2]
+        pred_offsets = app.state.engine.config["pred_offsets"]

-        predictions = []
-        for h_idx, horizon in enumerate(cfg.pred_offsets):
+        response = []
+        for h_idx, horizon in enumerate(pred_offsets):
             forecast_time = latest_ts + timedelta(hours=horizon)

             for node_idx in range(num_pred_nodes):
-                lat = nodes_df.iloc[node_idx]["lat"]
-                lon = nodes_df.iloc[node_idx]["lon"]
-                pred_val = float(preds_unscaled[0, h_idx, node_idx, 0])
+                lat = app.state.engine.nodes_df.iloc[node_idx]["lat"]
+                lon = app.state.engine.nodes_df.iloc[node_idx]["lon"]
+                pred_val = float(predictions[0, h_idx, node_idx, 0])

-                predictions.append(
+                response.append(
                     PredictionResponse(
                         forecast_time=forecast_time.isoformat(),
                         horizon_hours=horizon,
@@ -286,7 +134,7 @@ async def run_inference(request: PredictionRequest) -> list[PredictionResponse]:
                     )
                 )

-        return predictions
+        return response

     except Exception as e:
         raise HTTPException(
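To sanity-check the refactor end to end, an in-process smoke test along these lines should work. Assumptions: this module is importable as `api` (a hypothetical name), httpx is installed for fastapi.testclient, and since the PredictionRequest fields are collapsed in this diff, the empty payload is only a placeholder:

from fastapi.testclient import TestClient

from api import app  # hypothetical module name for the file in this diff

# entering the context manager runs the startup event, loading the engine
with TestClient(app) as client:
    assert client.get("/health").json()["model_loaded"] is True
    print(client.get("/model/info").json())

    # /predict returns one entry per (horizon, node) pair
    resp = client.post("/predict", json={})  # placeholder body
    print(len(resp.json()), "predictions")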