
Commit 58300e7

Merge pull request #7 from VectorInstitute/refactor_inference
Refactor code to create inference module, improve CLI
2 parents 1208a73 + 157a1b1

File tree

9 files changed, +1164 -928 lines


backend/app/main.py

Lines changed: 24 additions & 176 deletions
@@ -5,23 +5,16 @@
 predictions.
 """
 
-from argparse import Namespace
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Any, TypedDict
+from typing import Any
 
-import joblib
 import pandas as pd
-import torch
-import yaml
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from sklearn.preprocessing import StandardScaler
 
-from gaca_ews.core.data_extraction import fetch_last_hours
-from gaca_ews.core.preprocessing import preprocess_for_inference
-from gaca_ews.model.gcngru import GCNGRU
+from gaca_ews.core.inference import InferenceEngine
 
 
 app = FastAPI(
@@ -40,26 +33,6 @@
 )
 
 
-class Config:
-    """Configuration holder for model and paths."""
-
-    def __init__(self, config_path: str = "config.yaml") -> None:
-        """Initialize configuration from YAML file."""
-        with open(config_path, "r") as f:
-            cfg_dict = yaml.safe_load(f)
-
-        # Convert to namespace-like object
-        def to_namespace(obj: Any) -> Any:
-            if isinstance(obj, dict):
-                ns = Namespace()
-                for key, value in obj.items():
-                    setattr(ns, key, to_namespace(value))
-                return ns
-            return obj
-
-        self.config = to_namespace(cfg_dict)
-
-
 class PredictionRequest(BaseModel):
     """Request model for inference."""
 
@@ -87,94 +60,11 @@ class ModelInfo(BaseModel):
     status: str
 
 
-class ModelState(TypedDict, total=False):
-    """Type definition for global model state."""
-
-    model: GCNGRU
-    config: Namespace
-    artifacts_loaded: bool
-    feature_scaler: StandardScaler
-    target_scaler: StandardScaler
-    edge_index: torch.Tensor
-    edge_weight: torch.Tensor
-    nodes_df: pd.DataFrame
-    device: torch.device
-
-
-# Global state
-model_state: ModelState = {"artifacts_loaded": False}
-
-
-def load_model_artifacts() -> None:
-    """Load model and artifacts into memory."""
-    if model_state["artifacts_loaded"]:
-        return
-
-    config = Config()
-    cfg = config.config
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # Load scalers
-    feature_scaler = joblib.load(cfg.model.feature_scaler_path)
-    target_scaler = joblib.load(cfg.model.target_scaler_path)
-
-    # Load graph components
-    edge_index = torch.load(cfg.graph.edge_index_path, map_location=device)
-    edge_weight = torch.load(cfg.graph.edge_weight_path, map_location=device)
-    nodes_df = pd.read_csv(cfg.graph.nodes_csv_path)
-
-    # Load model
-    checkpoint = torch.load(cfg.model.model_path, map_location=device)
-    model_class = checkpoint["model_class"]
-    model_args = checkpoint["model_args"]
-    raw_state = checkpoint["model_state_dict"]
-
-    # Clean state dict
-    clean_state = {}
-    for k, v in raw_state.items():
-        if k.startswith("module."):
-            clean_state[k.replace("module.", "", 1)] = v
-        else:
-            clean_state[k] = v
-
-    # Adjust model args
-    if "pred_offsets" in model_args:
-        model_args["pred_horizons"] = len(model_args["pred_offsets"])
-        del model_args["pred_offsets"]
-    if "model_name" in model_args:
-        del model_args["model_name"]
-
-    model_args["num_nodes"] = edge_index.max().item() + 1
-
-    if model_class == "DistributedDataParallel":
-        model_class = cfg.model_arch
-
-    if model_class == "GCNGRU":
-        model = GCNGRU(**model_args)
-    else:
-        raise ValueError(f"Unknown model: {model_class}")
-
-    model.load_state_dict(clean_state)
-    model = model.to(device)
-    model.eval()
-
-    # Store in global state
-    model_state["model"] = model
-    model_state["feature_scaler"] = feature_scaler
-    model_state["target_scaler"] = target_scaler
-    model_state["edge_index"] = edge_index
-    model_state["edge_weight"] = edge_weight
-    model_state["nodes_df"] = nodes_df
-    model_state["config"] = cfg
-    model_state["device"] = device
-    model_state["artifacts_loaded"] = True
-
-
 @app.on_event("startup")
 async def startup_event() -> None:
     """Load model on startup."""
-    load_model_artifacts()
+    app.state.engine = InferenceEngine("config.yaml")
+    app.state.engine.load_artifacts()
 
 
 @app.get("/")
@@ -190,93 +80,51 @@ async def root() -> dict[str, str]:
 @app.get("/health")
 async def health() -> dict[str, str | bool]:
     """Health check endpoint."""
-    return {"status": "healthy", "model_loaded": model_state["artifacts_loaded"]}
+    return {"status": "healthy", "model_loaded": hasattr(app.state, "engine")}
 
 
 @app.get("/model/info", response_model=ModelInfo)
 async def get_model_info() -> ModelInfo:
     """Get information about the loaded model."""
-    if not model_state["artifacts_loaded"]:
+    if not hasattr(app.state, "engine"):
         raise HTTPException(status_code=503, detail="Model not loaded")
 
-    cfg = model_state["config"]
-    nodes_df = model_state["nodes_df"]
-
-    assert cfg is not None, "Config not loaded"
-    assert nodes_df is not None, "Nodes dataframe not loaded"
+    info = app.state.engine.get_model_info()
 
     return ModelInfo(
-        model_architecture=cfg.model_arch,
-        num_nodes=len(nodes_df),
-        input_features=cfg.features,
-        prediction_horizons=cfg.pred_offsets,
-        region={
-            "lat_min": cfg.region.lat_min,
-            "lat_max": cfg.region.lat_max,
-            "lon_min": cfg.region.lon_min,
-            "lon_max": cfg.region.lon_max,
-        },
+        model_architecture=info["model_architecture"],
+        num_nodes=info["num_nodes"],
+        input_features=info["input_features"],
+        prediction_horizons=info["prediction_horizons"],
+        region=info["region"],
         status="loaded",
     )
 
 
 @app.post("/predict", response_model=list[PredictionResponse])
 async def run_inference(request: PredictionRequest) -> list[PredictionResponse]:
     """Run inference and return predictions."""
-    if not model_state["artifacts_loaded"]:
+    if not hasattr(app.state, "engine"):
         raise HTTPException(status_code=503, detail="Model not loaded")
 
     try:
-        # Fetch data
-        cfg = model_state["config"]
-        assert cfg is not None, "Config not loaded"
-        df, latest_ts = fetch_last_hours(cfg)
-
-        # Preprocess
-        feature_scaler = model_state["feature_scaler"]
-        assert feature_scaler is not None, "Feature scaler not loaded"
-        X_seq, timestamps, in_channels, num_nodes = preprocess_for_inference(
-            data=df, feature_scaler=feature_scaler, args=cfg
-        )
-
-        # Run inference
-        device = model_state["device"]
-        assert device is not None, "Device not set"
-        X_test = X_seq.to(device)
-
-        model = model_state["model"]
-        edge_index = model_state["edge_index"]
-        edge_weight = model_state["edge_weight"]
-        assert model is not None, "Model not loaded"
-        assert edge_index is not None, "Edge index not loaded"
-        assert edge_weight is not None, "Edge weight not loaded"
-
-        with torch.no_grad():
-            preds_scaled = model(X_test, edge_index, edge_weight)
-
-        # Inverse transform
-        target_scaler = model_state["target_scaler"]
-        assert target_scaler is not None, "Target scaler not loaded"
-        preds_np = preds_scaled.cpu().numpy()
-        preds_unscaled = target_scaler.inverse_transform(
-            preds_np.reshape(-1, preds_np.shape[-1])
-        ).reshape(preds_np.shape)
+        # Run full pipeline
+        predictions, latest_ts = app.state.engine.run_full_pipeline()
 
         # Format response
-        nodes_df = model_state["nodes_df"]
-        assert nodes_df is not None, "Nodes dataframe not loaded"
-        num_pred_nodes = preds_unscaled.shape[2]
+        num_pred_nodes = predictions.shape[2]
+        pred_offsets = app.state.engine.config["pred_offsets"]
 
-        predictions = []
-        for h_idx, horizon in enumerate(cfg.pred_offsets):
+        response = []
+        for h_idx, horizon in enumerate(pred_offsets):
             forecast_time = latest_ts + timedelta(hours=horizon)
 
             for node_idx in range(num_pred_nodes):
-                lat = nodes_df.iloc[node_idx]["lat"]
-                lon = nodes_df.iloc[node_idx]["lon"]
-                pred_val = float(preds_unscaled[0, h_idx, node_idx, 0])
+                lat = app.state.engine.nodes_df.iloc[node_idx]["lat"]
+                lon = app.state.engine.nodes_df.iloc[node_idx]["lon"]
+                pred_val = float(predictions[0, h_idx, node_idx, 0])
 
-                predictions.append(
+                response.append(
                     PredictionResponse(
                         forecast_time=forecast_time.isoformat(),
                         horizon_hours=horizon,
@@ -286,7 +134,7 @@ async def run_inference(request: PredictionRequest) -> list[PredictionResponse]:
                     )
                 )
 
-        return predictions
+        return response
 
     except Exception as e:
         raise HTTPException(
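The endpoints above rely on only a small surface of the new engine: the constructor, load_artifacts(), get_model_info(), and run_full_pipeline(), plus its config and nodes_df attributes. Below is a minimal sketch of that interface as this diff implies it; the gaca_ews.core.inference module itself is not part of the excerpt shown here, so the signatures, attribute types, and stub bodies are assumptions for illustration only.

# Hypothetical sketch of the InferenceEngine surface that backend/app/main.py
# now depends on. Only the names actually called in the diff above come from
# the commit; everything else is an assumption.
from datetime import datetime
from typing import Any

import numpy as np
import pandas as pd


class InferenceEngine:
    """Bundles config, scalers, graph tensors, and model behind one object."""

    config: dict[str, Any]   # parsed config.yaml; main.py reads config["pred_offsets"]
    nodes_df: pd.DataFrame   # node table; main.py reads nodes_df.iloc[i]["lat"] and ["lon"]

    def __init__(self, config_path: str = "config.yaml") -> None:
        self.config_path = config_path

    def load_artifacts(self) -> None:
        """Load scalers, graph tensors, and model weights (replaces load_model_artifacts)."""
        ...

    def get_model_info(self) -> dict[str, Any]:
        """Return model_architecture, num_nodes, input_features, prediction_horizons, region."""
        ...

    def run_full_pipeline(self) -> tuple[np.ndarray, datetime]:
        """Fetch recent data, preprocess, predict, and inverse-transform.

        Returns unscaled predictions shaped (batch, horizons, nodes, targets)
        plus the latest input timestamp, matching how /predict indexes the result.
        """
        ...

Moving this state onto app.state rather than the old module-level model_state dict is what lets every endpoint drop its "is not None" asserts, which accounts for most of the 176 deleted lines in this file.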

config.yaml

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 model:
-  feature_scaler_path: "./model/feature_scaler.pkl"
-  target_scaler_path: "./model/target_scaler.pkl"
-  model_path: "./model/final_model.pth"
+  feature_scaler_path: "./src/gaca_ews/model/feature_scaler.pkl"
+  target_scaler_path: "./src/gaca_ews/model/target_scaler.pkl"
+  model_path: "./src/gaca_ews/model/final_model.pth"
 
 graph:
   edge_index_path: "./data/edge_index.pt"
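With the artifact paths updated, the refactored service can be smoke-tested end to end from a small client script. This is a sketch, assuming the app runs locally on port 8000 and that PredictionRequest tolerates an empty JSON body; the request model's fields are elided in this diff, so the payload may need to be adjusted to match them.

# Minimal smoke test for the refactored API. The base URL and the empty
# /predict payload are assumptions; the endpoint paths come from the
# commit above.
import requests

BASE = "http://localhost:8000"

# /health now reports model_loaded via hasattr(app.state, "engine").
print(requests.get(f"{BASE}/health").json())

# /model/info proxies InferenceEngine.get_model_info().
info = requests.get(f"{BASE}/model/info").json()
print(info["model_architecture"], info["num_nodes"])

# /predict returns one entry per (horizon, node) pair; each entry carries
# at least forecast_time and horizon_hours per the diff.
preds = requests.post(f"{BASE}/predict", json={}).json()
print(len(preds), preds[0]["forecast_time"], preds[0]["horizon_hours"])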
