Skip to content

Commit 8bd556c

Browse files
committed
Implement ModelWrapper class
1 parent e0ef75f commit 8bd556c

File tree

1 file changed

+292
-66
lines changed

1 file changed

+292
-66
lines changed

src/dashboard/predict.py

Lines changed: 292 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,101 +1,327 @@
11
"""
22
src/dashboard/predict.py
33
4-
Lightweight wrapper to load the trained model and run predictions.
5-
This isolates model handling for tests and the Flask app.
4+
ModelWrapper: loads a saved joblib/sklearn pipeline and exposes:
5+
- get_expected_features()
6+
- predict_single(df)
7+
- predict_batch(df)
8+
9+
The methods return rich dictionaries used by the dashboard endpoints.
10+
11+
Notes:
12+
- If the saved model cannot be found, the wrapper will attempt to locate the
13+
first '*_best.joblib' under 'reports/models/'.
14+
- For explainability, this module will generate a minimal visualization PNG for
15+
single and batch predictions (simple bar/histogram) and save into 'reports/explain/'.
16+
- Optional SHAP support is used if installed, but not required.
617
"""
718

819
from __future__ import annotations
920

1021
import glob
1122
import logging
1223
import os
13-
from typing import Optional
24+
from typing import Any, Dict, List, Optional
1425

1526
import joblib
27+
import matplotlib.pyplot as plt
1628
import numpy as np
1729
import pandas as pd
1830

19-
logger = logging.getLogger(__name__)
31+
from .utils import artifact_name, list_files_with_mtime, safe_prepare_df
2032

33+
# Optional SHAP
34+
try:
35+
import shap # type: ignore
2136

22-
def find_model(
23-
models_dir: str = "reports/models", preferred: Optional[str] = None
24-
) -> Optional[str]:
25-
"""
26-
Find the first matching model artifact. If preferred is given, prefer that file.
27-
"""
28-
if preferred and os.path.exists(preferred):
29-
return preferred
30-
31-
if not os.path.isdir(models_dir):
32-
return None
33-
34-
# try explicit patterns
35-
patterns = [
36-
os.path.join(models_dir, "*_best.joblib"),
37-
os.path.join(models_dir, "*_best.pkl"),
38-
os.path.join(models_dir, "*.joblib"),
39-
os.path.join(models_dir, "*.pkl"),
40-
]
41-
for pat in patterns:
42-
matches = glob.glob(pat)
43-
if matches:
44-
return matches[0]
45-
return None
37+
SHAP_AVAILABLE = True
38+
except Exception:
39+
SHAP_AVAILABLE = False
4640

41+
LOGGER = logging.getLogger(__name__)
42+
BASE_REPO = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
43+
REPORTS_EXPLAIN_DIR = os.path.join(BASE_REPO, "reports", "explain")
44+
REPORTS_MODELS_DIR = os.path.join(BASE_REPO, "reports", "models")
45+
os.makedirs(REPORTS_EXPLAIN_DIR, exist_ok=True)
4746

48-
class ModelWrapper:
49-
"""
50-
Responsible for loading the model artifact and exposing predict/predict_proba methods.
51-
Accepts sklearn Pipelines and raw estimators saved with joblib.
52-
"""
5347

48+
class ModelWrapper:
5449
def __init__(self, model_path: Optional[str] = None):
55-
self.model_path = model_path or find_model()
56-
if self.model_path is None:
57-
raise FileNotFoundError("No model artifact found in reports/models/")
58-
logger.info("Loading model from %s", self.model_path)
50+
"""
51+
Load model. If model_path not provided, try environment var DASHBOARD_MODEL,
52+
then look for reports/models/*_best.joblib.
53+
"""
54+
path = model_path or os.environ.get("DASHBOARD_MODEL")
55+
if path is None:
56+
# search reports/models
57+
candidates = glob.glob(os.path.join(REPORTS_MODELS_DIR, "*_best.*"))
58+
path = candidates[0] if candidates else None
59+
if path is None or not os.path.exists(path):
60+
raise FileNotFoundError(
61+
"Could not find model file. Provide path or set DASHBOARD_MODEL."
62+
)
63+
self.model_path = os.path.abspath(path)
5964
self.model = joblib.load(self.model_path)
65+
# determine expected features
66+
self.expected_features = self._extract_feature_names()
67+
LOGGER.info(
68+
"Loaded model %s expecting features: %s",
69+
self.model_path,
70+
self.expected_features,
71+
)
6072

61-
def predict_single(self, X: pd.DataFrame) -> dict:
73+
def _extract_feature_names(self) -> List[str]:
6274
"""
63-
Predict a single-row DataFrame (or 1d array-like converted to DataFrame).
64-
Returns {"prediction": int, "probability": float}
75+
Try multiple heuristics to extract feature names the model was trained on:
76+
- model.feature_names_in_
77+
- if Pipeline, final estimator.feature_names_in_
78+
- fall back to a common diabetes feature set if unknown
6579
"""
66-
if isinstance(X, (list, tuple, np.ndarray)):
67-
X = pd.DataFrame([X])
68-
elif isinstance(X, dict):
69-
X = pd.DataFrame([X])
70-
if not isinstance(X, pd.DataFrame):
71-
raise ValueError("X must be a pandas DataFrame, dict, list or array")
72-
73-
pred = int(self.model.predict(X)[0])
74-
prob = None
80+
# common fallback for this project
81+
fallback = [
82+
"Pregnancies",
83+
"Glucose",
84+
"BloodPressure",
85+
"SkinThickness",
86+
"Insulin",
87+
"BMI",
88+
"DiabetesPedigreeFunction",
89+
"Age",
90+
]
91+
92+
# direct attr
7593
try:
76-
prob = float(self.model.predict_proba(X)[0, 1])
94+
fn = getattr(self.model, "feature_names_in_", None)
95+
if fn is not None:
96+
return list(fn)
7797
except Exception:
78-
# some models don't support predict_proba
79-
prob = None
80-
return {"prediction": pred, "probability": prob}
98+
pass
8199

82-
def predict_batch(self, df: pd.DataFrame) -> pd.DataFrame:
100+
# pipeline final estimator
101+
try:
102+
# pipeline: named_steps or steps
103+
if hasattr(self.model, "named_steps"):
104+
final = list(self.model.named_steps.values())[-1]
105+
fn = getattr(final, "feature_names_in_", None)
106+
if fn is not None:
107+
return list(fn)
108+
# some estimators store feature names in coef_ shapes etc; not reliable
109+
except Exception:
110+
pass
111+
112+
return fallback
113+
114+
def get_expected_features(self) -> List[str]:
115+
return list(self.expected_features)
116+
117+
def get_model_info(self) -> Dict[str, Any]:
118+
return {"estimator": type(self.model).__name__, "model_path": self.model_path}
119+
120+
def predict_single(self, df: pd.DataFrame) -> Dict[str, Any]:
83121
"""
84-
Predict on a DataFrame and return the DataFrame with added columns:
85-
'prediction' and 'probability' (if available).
122+
Accepts a one-row DataFrame or DataFrame with one record.
123+
Returns a dict with keys:
124+
- prediction (int)
125+
- probability (float 0..1)
126+
- user_message (string)
127+
- explanation_files: list of {filename, mtime}
86128
"""
87-
df_copy = df.copy()
88-
preds = self.model.predict(df_copy)
89-
df_copy["prediction"] = preds
129+
if df.shape[0] < 1:
130+
raise ValueError("Empty input for single prediction")
131+
expected = self.get_expected_features()
132+
df_prepped = safe_prepare_df(df.iloc[[0]], expected)
133+
134+
# predict
135+
pred = None
136+
prob = None
137+
try:
138+
pred_arr = self.model.predict(df_prepped)
139+
pred = int(pred_arr[0])
140+
except Exception:
141+
pred = None
142+
try:
143+
prob_arr = self.model.predict_proba(df_prepped)[:, 1]
144+
prob = float(prob_arr[0])
145+
except Exception:
146+
prob = None
147+
148+
# friendly message
149+
pct = (prob * 100) if prob is not None else None
150+
if pct is None:
151+
user_message = "The model could not compute a probability for this input."
152+
else:
153+
user_message = (
154+
f"Based on the details you shared, our model estimates there’s about "
155+
f"{pct:.2f}% chance you may be at risk of developing diabetes. "
156+
"This isn’t a medical diagnosis — consult a healthcare professional for personalised advice."
157+
)
158+
159+
# produce a small bar chart png for this single prediction
160+
pngname = artifact_name("shap_single_pred", "png")
161+
pngpath = os.path.join(REPORTS_EXPLAIN_DIR, pngname)
90162
try:
91-
probs = self.model.predict_proba(df_copy)[:, 1]
92-
df_copy["probability"] = probs
163+
plt.figure(figsize=(4, 2))
164+
val = pct if pct is not None else 0.0
165+
plt.barh([0], [val], height=0.6)
166+
plt.xlim(0, 100)
167+
plt.xlabel("Risk (%)")
168+
plt.yticks([])
169+
plt.title("Predicted risk (%)")
170+
plt.tight_layout()
171+
plt.savefig(pngpath, bbox_inches="tight")
172+
plt.close()
93173
except Exception:
94-
df_copy["probability"] = pd.NA
95-
return df_copy
174+
LOGGER.exception("Failed to create single prediction PNG")
175+
pngname = None
176+
177+
files = []
178+
if pngname:
179+
try:
180+
m = int(os.path.getmtime(os.path.join(REPORTS_EXPLAIN_DIR, pngname)))
181+
files.append({"filename": pngname, "mtime": m})
182+
except Exception:
183+
files.append({"filename": pngname, "mtime": 0})
184+
185+
# optional SHAP explanation (best-effort)
186+
if SHAP_AVAILABLE:
187+
try:
188+
expl = shap.Explainer(self.model, df_prepped)
189+
sv = expl(df_prepped)
190+
htmlname = artifact_name("shap_single_force", "html")
191+
htmlpath = os.path.join(REPORTS_EXPLAIN_DIR, htmlname)
192+
# try to produce a self-contained html via shap (best-effort)
193+
try:
194+
force = shap.plots.force(sv, matplotlib=False)
195+
with open(htmlpath, "w", encoding="utf-8") as fh:
196+
fh.write(force.html())
197+
m = int(os.path.getmtime(htmlpath))
198+
files.append({"filename": htmlname, "mtime": m})
199+
except Exception:
200+
# fallback: save a simple text file
201+
pass
202+
except Exception:
203+
LOGGER.debug("SHAP explain not produced (optional)")
96204

97-
def get_model_info(self) -> dict:
205+
return {
206+
"prediction": pred,
207+
"probability": prob if prob is not None else 0.0,
208+
"user_message": user_message,
209+
"explanation_files": files,
210+
"model_info": self.get_model_info(),
211+
}
212+
213+
def predict_batch(self, df: pd.DataFrame) -> Dict[str, Any]:
98214
"""
99-
Return meta info about loaded model (path, class name).
215+
Accepts a DataFrame (possibly with Outcome). Drops Outcome if present,
216+
prepares DataFrame to expected features, and returns summary dict:
217+
- n_rows, n_positive, mean_probability, hist_bins, hist_counts, explanation_files, model_info
218+
Also saves a histogram PNG under reports/explain/.
100219
"""
101-
return {"model_path": self.model_path, "estimator": type(self.model).__name__}
220+
if df.shape[0] < 1:
221+
raise ValueError("Empty DataFrame uploaded")
222+
expected = self.get_expected_features()
223+
# drop Outcome if present and ensure expected columns
224+
if "Outcome" in df.columns:
225+
df = df.drop(columns=["Outcome"])
226+
227+
df_prepped = safe_prepare_df(df, expected)
228+
229+
# predict probabilities if possible
230+
probs = None
231+
preds = None
232+
try:
233+
probs = self.model.predict_proba(df_prepped)[:, 1]
234+
preds = (probs >= 0.5).astype(int)
235+
except Exception:
236+
try:
237+
preds = self.model.predict(df_prepped)
238+
probs = np.zeros_like(preds, dtype=float)
239+
except Exception:
240+
raise
241+
242+
n_rows = int(len(df_prepped))
243+
n_positive = int(int(np.sum(preds)))
244+
mean_prob = float(float(np.mean(probs))) if len(probs) > 0 else 0.0
245+
246+
# histogram
247+
counts, bins = np.histogram(probs, bins=10, range=(0.0, 1.0))
248+
bin_labels = [
249+
f"{int(b*100)}-{int(bins[i+1]*100)}%" for i, b in enumerate(bins[:-1])
250+
]
251+
252+
# save histogram png
253+
pngname = artifact_name("batch_pred_hist", "png")
254+
pngpath = os.path.join(REPORTS_EXPLAIN_DIR, pngname)
255+
try:
256+
plt.figure(figsize=(6, 3))
257+
plt.bar(range(len(counts)), counts)
258+
plt.xticks(range(len(counts)), bin_labels, rotation=45, ha="right")
259+
plt.ylabel("Count")
260+
plt.title("Prediction probability distribution")
261+
plt.tight_layout()
262+
plt.savefig(pngpath, bbox_inches="tight")
263+
plt.close()
264+
except Exception:
265+
LOGGER.exception("Failed to save batch histogram")
266+
pngname = None
267+
268+
files = []
269+
if pngname:
270+
try:
271+
m = int(os.path.getmtime(os.path.join(REPORTS_EXPLAIN_DIR, pngname)))
272+
files.append({"filename": pngname, "mtime": m})
273+
except Exception:
274+
files.append({"filename": pngname, "mtime": 0})
275+
276+
# produce permutation importance csv (best-effort, may be slow)
277+
try:
278+
from sklearn.inspection import permutation_importance
279+
280+
r = permutation_importance(
281+
self.model,
282+
df_prepped,
283+
np.asarray(preds),
284+
n_repeats=5,
285+
random_state=42,
286+
n_jobs=1,
287+
)
288+
imp_df = pd.DataFrame(
289+
{
290+
"feature": expected,
291+
"importance_mean": r.importances_mean,
292+
"importance_std": r.importances_std,
293+
}
294+
)
295+
csvname = artifact_name("permutation_importance", "csv")
296+
csvpath = os.path.join(REPORTS_EXPLAIN_DIR, csvname)
297+
imp_df.to_csv(csvpath, index=False)
298+
m = int(os.path.getmtime(csvpath))
299+
files.append({"filename": csvname, "mtime": m})
300+
except Exception:
301+
LOGGER.debug("Permutation importance not generated (optional)")
302+
303+
return {
304+
"n_rows": n_rows,
305+
"n_positive": n_positive,
306+
"mean_probability": mean_prob,
307+
"hist_counts": counts.tolist(),
308+
"hist_bins": bin_labels,
309+
"explanation_files": files,
310+
"model_info": self.get_model_info(),
311+
}
312+
313+
314+
def find_model() -> Optional[str]:
315+
"""
316+
Return currently configured model path (env DASHBOARD_MODEL) or first candidate.
317+
"""
318+
path = os.environ.get("DASHBOARD_MODEL")
319+
if path and os.path.exists(path):
320+
return os.path.abspath(path)
321+
candidates = glob.glob(os.path.join(REPORTS_MODELS_DIR, "*_best.*"))
322+
return candidates[0] if candidates else None
323+
324+
325+
# Convenience for listing explain files
326+
def list_explain_files() -> List[Dict[str, Any]]:
327+
return list_files_with_mtime(REPORTS_EXPLAIN_DIR)

0 commit comments

Comments
 (0)