Skip to content

Commit bab91cb

Browse files
committed
.pre-commit-config.py
1 parent 8bd556c commit bab91cb

File tree

3 files changed

+132
-154
lines changed

3 files changed

+132
-154
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ repos:
3838

3939
# Ruff: fast linter (combines flake8-like checks + autofix)
4040
- repo: https://github.com/charliermarsh/ruff-pre-commit
41-
rev: v0.14.2
41+
rev: v0.14.3
4242
hooks:
4343
- id: ruff
4444
args: ["--fix"] # remove `--fix` if you prefer only diagnostics

src/dashboard/app.py

Lines changed: 93 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,15 @@
44
55
Flask dashboard application entrypoint.
66
7-
This file exposes the dashboard and includes a secure route to serve explanation
8-
artifacts from the project's reports/explain directory in-place (no copying).
9-
10-
Run (development):
11-
PYTHONPATH=src flask --app src/dashboard/app.py run
12-
or
13-
PYTHONPATH=src python src/dashboard/app.py
14-
15-
Routes:
16-
- GET / -> dashboard index
17-
- POST /predict -> single prediction (form or JSON)
18-
- POST /predict_batch -> batch CSV upload
19-
- GET /reports/explain/<path:filename> -> serve files from reports/explain (secure)
20-
- GET /api/explain_files -> list explain files (filename + mtime)
21-
- GET /api/metrics -> JSON metrics
7+
Key endpoints:
8+
- GET / -> dashboard page
9+
- POST /predict -> single-record prediction (JSON or form)
10+
- POST /predict_batch -> CSV upload (multipart/form-data)
11+
- GET /api/explain_files -> JSON list of files in reports/explain
12+
- GET /reports/explain/<f> -> serve files from reports/explain securely
13+
14+
Run:
15+
PYTHONPATH=src flask --app src/dashboard/app.py run
2216
"""
2317

2418
from __future__ import annotations
@@ -28,7 +22,6 @@
2822
import logging
2923
import os
3024
from logging.handlers import RotatingFileHandler
31-
from typing import Any, Dict, List
3225

3326
import pandas as pd
3427
from flask import (
@@ -39,138 +32,63 @@
3932
redirect,
4033
render_template,
4134
request,
42-
send_file,
4335
send_from_directory,
4436
url_for,
4537
)
4638
from werkzeug.utils import secure_filename
4739

48-
# Ensure src package imports work when running this file directly.
49-
_this_dir = os.path.dirname(os.path.abspath(__file__)) # .../project/src/dashboard
50-
_src_root = os.path.dirname(_this_dir) # .../project/src
40+
from dashboard.predict import ModelWrapper, find_model, list_explain_files
41+
42+
# ensure src import works
43+
_this_dir = os.path.dirname(os.path.abspath(__file__)) # .../src/dashboard
44+
_src_root = os.path.dirname(_this_dir) # .../src
5145
if _src_root not in os.sys.path:
5246
os.sys.path.insert(0, _src_root)
47+
# Ensure server-side plotting uses Agg (lowest-risk backend for production/dev)
48+
os.environ.setdefault("MPLBACKEND", "Agg")
5349

54-
from dashboard.predict import ModelWrapper, find_model # noqa: E402
5550

56-
BASE_DIR = _this_dir
57-
REPORTS_EXPLAIN_DIR = os.path.abspath(
58-
os.path.join(BASE_DIR, "..", "..", "reports", "explain")
59-
)
60-
UPLOAD_TMP = os.path.abspath(
61-
os.path.join(BASE_DIR, "..", "..", "reports", "tmp_uploads")
62-
)
63-
64-
os.makedirs(UPLOAD_TMP, exist_ok=True)
51+
BASE_REPO = os.path.abspath(os.path.join(_src_root, ".."))
52+
REPORTS_EXPLAIN_DIR = os.path.join(BASE_REPO, "reports", "explain")
53+
UPLOAD_TMP = os.path.join(BASE_REPO, "reports", "tmp_uploads")
6554
os.makedirs(REPORTS_EXPLAIN_DIR, exist_ok=True)
55+
os.makedirs(UPLOAD_TMP, exist_ok=True)
6656

67-
# point Flask at repo-level template dir (src/templates)
68-
TEMPLATE_DIR = os.path.join(_src_root, "templates") # _src_root is .../project/src
6957
app = Flask(
7058
__name__,
7159
template_folder=os.path.join(_src_root, "templates"),
7260
static_folder=os.path.join(_src_root, "static"),
7361
)
74-
75-
app.config["MAX_CONTENT_LENGTH"] = 5 * 1024 * 1024
62+
app.config["MAX_CONTENT_LENGTH"] = 8 * 1024 * 1024
7663
app.config["UPLOAD_FOLDER"] = UPLOAD_TMP
7764
app.secret_key = os.environ.get("FLASK_SECRET", "dev-secret-key")
7865

66+
67+
# helper
68+
def load_wrapper(model_path=None):
69+
"""Utility to consistently load ModelWrapper."""
70+
return ModelWrapper(model_path)
71+
72+
7973
# logging
8074
log_path = os.path.abspath(os.path.join(_src_root, "..", "reports", "dashboard.log"))
8175
os.makedirs(os.path.dirname(log_path), exist_ok=True)
82-
handler = RotatingFileHandler(log_path, maxBytes=5_000_000, backupCount=2)
83-
handler.setLevel(logging.INFO)
76+
handler = RotatingFileHandler(log_path, maxBytes=2_000_000, backupCount=2)
8477
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s: %(message)s"))
8578
app.logger.addHandler(handler)
8679
app.logger.setLevel(logging.INFO)
8780

8881

89-
def load_wrapper(preferred: str | None = None) -> ModelWrapper:
90-
return ModelWrapper(preferred)
91-
92-
93-
@app.route("/reports/explain/<path:filename>")
94-
def explain_file(filename: str):
95-
"""
96-
Securely serve files from reports/explain without copying them into src/static.
97-
98-
- Validates that the final absolute path is inside REPORTS_EXPLAIN_DIR.
99-
- Returns 404 if file missing, 403 if attempt to escape directory.
100-
"""
101-
safe_dir = REPORTS_EXPLAIN_DIR
102-
requested = os.path.abspath(os.path.join(safe_dir, filename))
103-
104-
# Prevent directory traversal
105-
if not os.path.commonpath([safe_dir, requested]) == safe_dir:
106-
app.logger.warning("Forbidden file access attempt: %s", requested)
107-
abort(403)
108-
if not os.path.exists(requested):
109-
abort(404)
110-
return send_from_directory(safe_dir, filename)
111-
112-
113-
def _list_explain_files(extensions: List[str] = None) -> List[Dict[str, Any]]:
114-
"""
115-
Return a list of files in REPORTS_EXPLAIN_DIR filtered by extensions
116-
with their modification times (as integer timestamps).
117-
"""
118-
if extensions is None:
119-
extensions = [".png", ".jpg", ".jpeg", ".gif", ".svg", ".html"]
120-
out: List[Dict[str, Any]] = []
121-
try:
122-
for name in os.listdir(REPORTS_EXPLAIN_DIR):
123-
path = os.path.join(REPORTS_EXPLAIN_DIR, name)
124-
if not os.path.isfile(path):
125-
continue
126-
if extensions and not any(name.lower().endswith(ext) for ext in extensions):
127-
continue
128-
try:
129-
mtime = int(os.path.getmtime(path))
130-
except Exception:
131-
mtime = 0
132-
out.append({"filename": name, "mtime": mtime})
133-
out.sort(key=lambda r: r["mtime"], reverse=True)
134-
except FileNotFoundError:
135-
return []
136-
except Exception:
137-
app.logger.exception("Error listing explain files")
138-
return out
139-
140-
141-
@app.route("/api/explain_files", methods=["GET"])
142-
def api_explain_files():
143-
"""
144-
Return JSON with the list of explain files and the latest file (if any).
145-
146-
Example response:
147-
{
148-
"ok": True,
149-
"files": [{"filename":"shap_summary.png","mtime":163...}, ...],
150-
"latest": {"filename":"shap_summary.png","mtime":163...} or null
151-
}
152-
"""
153-
files = _list_explain_files()
154-
latest = files[0] if files else None
155-
return jsonify({"ok": True, "files": files, "latest": latest})
156-
157-
15882
@app.route("/", methods=["GET"])
15983
def index():
160-
# Try to detect a model and metrics
16184
model_path = find_model()
16285
metrics = {}
16386
if model_path:
164-
# try to find a metrics file in reports (convention: <model>_metrics.json)
87+
# try to find metrics file next to reports, e.g. reports/<model>_metrics.json
16588
base = os.path.basename(model_path)
16689
name = os.path.splitext(base)[0]
167-
# metrics in reports/ (sibling of reports/models)
16890
candidates = [
169-
os.path.join(
170-
os.path.dirname(os.path.dirname(model_path)),
171-
"..",
172-
f"{name}_metrics.json",
173-
),
91+
os.path.join(os.path.dirname(model_path), "..", f"{name}_metrics.json"),
17492
os.path.join("reports", f"{name}_metrics.json"),
17593
]
17694
for p in candidates:
@@ -185,23 +103,50 @@ def index():
185103
return render_template("index.html", model_path=model_path, metrics=metrics)
186104

187105

106+
@app.route("/reports/explain/<path:filename>")
107+
def explain_file(filename: str):
108+
safe_dir = REPORTS_EXPLAIN_DIR
109+
requested = os.path.abspath(os.path.join(safe_dir, filename))
110+
# directory traversal protection
111+
if not os.path.commonpath([safe_dir, requested]) == safe_dir:
112+
abort(403)
113+
if not os.path.exists(requested):
114+
abort(404)
115+
return send_from_directory(safe_dir, filename)
116+
117+
118+
@app.route("/api/explain_files", methods=["GET"])
119+
def api_explain_files():
120+
try:
121+
files = list_explain_files()
122+
return jsonify(
123+
{"ok": True, "files": files, "latest": files[0] if files else None}
124+
)
125+
except Exception as e:
126+
app.logger.exception("Failed to list explain files")
127+
return jsonify({"ok": False, "error": str(e)}), 500
128+
129+
188130
@app.route("/predict", methods=["POST"])
189131
def predict():
190-
preferred_model = request.form.get("model_path") or request.args.get("model_path")
191-
data = None
132+
"""
133+
Accept JSON or form data for a single record.
134+
Returns: { ok: True, result: { prediction, probability, user_message, explanation_files[...] }, model_info: {...} }
135+
"""
192136
if request.is_json:
193137
data = request.get_json()
194138
else:
195-
data = {k: v for k, v in request.form.items() if k != "model_path"}
139+
data = {k: v for k, v in request.form.items()}
140+
196141
try:
197-
wrapper = load_wrapper(preferred_model)
198-
if isinstance(data, dict):
199-
df = pd.DataFrame([data])
200-
elif isinstance(data, list):
201-
df = pd.DataFrame(data)
202-
else:
203-
return jsonify({"ok": False, "error": "Invalid input format"}), 400
142+
wrapper = ModelWrapper(os.environ.get("DASHBOARD_MODEL"))
143+
except Exception as e:
144+
app.logger.exception("Model load failed")
145+
return jsonify({"ok": False, "error": "Model loading failed: " + str(e)}), 500
204146

147+
try:
148+
# coerce to DataFrame and prepare with expected features
149+
df = pd.DataFrame([data])
205150
res = wrapper.predict_single(df)
206151
return jsonify(
207152
{"ok": True, "result": res, "model_info": wrapper.get_model_info()}
@@ -219,7 +164,12 @@ def allowed_file(filename: str) -> bool:
219164

220165

221166
@app.route("/predict_batch", methods=["POST"])
222-
def predict_batch():
167+
def predict_batch_route():
168+
"""
169+
Flask route to accept a CSV upload and return batch predictions as JSON.
170+
Keeps route at /predict_batch so existing clients/tests are unaffected.
171+
The function name is changed to avoid a name collision with ModelWrapper.predict_batch.
172+
"""
223173
if "file" not in request.files:
224174
flash("No file part")
225175
return redirect(url_for("index"))
@@ -238,11 +188,25 @@ def predict_batch():
238188
return redirect(url_for("index"))
239189
try:
240190
wrapper = load_wrapper(None)
241-
out_df = wrapper.predict_batch(df)
242-
out_path = save_path.replace(".csv", "_predictions.csv")
243-
out_df.to_csv(out_path, index=False)
244-
return send_file(
245-
out_path, as_attachment=True, download_name=os.path.basename(out_path)
191+
# returns dict with DataFrame in res['predictions'] and explanation_files
192+
res = wrapper.predict_batch(df)
193+
# Return JSON: n_rows, mean_probability, predictions (records), and explanation_files
194+
# Convert DataFrame -> dict (records) for JSON serialization
195+
preds_df = res.get("predictions")
196+
preds_records = (
197+
preds_df.to_dict(orient="records") if preds_df is not None else []
198+
)
199+
return jsonify(
200+
{
201+
"ok": True,
202+
"result": {
203+
"n_rows": res.get("n_rows"),
204+
"mean_probability": res.get("mean_probability"),
205+
"predictions": preds_records,
206+
"explanation_files": res.get("explanation_files"),
207+
},
208+
"model_info": wrapper.get_model_info(),
209+
}
246210
)
247211
except Exception as e:
248212
app.logger.exception("Batch prediction failed")
@@ -253,30 +217,6 @@ def predict_batch():
253217
return redirect(url_for("index"))
254218

255219

256-
@app.route("/api/metrics", methods=["GET"])
257-
def api_metrics():
258-
model = request.args.get("model")
259-
metrics_path = None
260-
if model:
261-
metrics_path = os.path.join("reports", f"{model}_metrics.json")
262-
else:
263-
candidates = (
264-
[p for p in os.listdir("reports") if p.endswith("_metrics.json")]
265-
if os.path.isdir("reports")
266-
else []
267-
)
268-
metrics_path = os.path.join("reports", candidates[0]) if candidates else None
269-
270-
if metrics_path and os.path.exists(metrics_path):
271-
try:
272-
with open(metrics_path) as fh:
273-
return jsonify({"ok": True, "metrics": json.load(fh)})
274-
except Exception as e:
275-
app.logger.exception("Failed reading metrics")
276-
return jsonify({"ok": False, "error": str(e)}), 500
277-
return jsonify({"ok": False, "error": "No metrics file found"}), 404
278-
279-
280220
if __name__ == "__main__":
281221
parser = argparse.ArgumentParser()
282222
parser.add_argument("--host", default="127.0.0.1")

tests/test_app_endpoints.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""
2+
Integration tests for Flask endpoints (predict, predict_batch, api_explain_files).
3+
"""
4+
5+
6+
def test_predict_endpoint_single(client):
7+
payload = {
8+
"Pregnancies": 0,
9+
"Glucose": 110,
10+
"BloodPressure": 72,
11+
"SkinThickness": 20,
12+
"Insulin": 85,
13+
"BMI": 28.5,
14+
"DiabetesPedigreeFunction": 0.4,
15+
"Age": 30,
16+
}
17+
res = client.post("/predict", json=payload)
18+
assert res.status_code == 200
19+
data = res.get_json()
20+
assert data["ok"] is True
21+
assert "result" in data
22+
assert "user_message" in data["result"]
23+
24+
25+
def test_predict_batch_endpoint(client, sample_csv_path):
26+
# upload sample csv
27+
with open(sample_csv_path, "rb") as fh:
28+
data = {"file": (fh, "sample.csv")}
29+
res = client.post(
30+
"/predict_batch", data=data, content_type="multipart/form-data"
31+
)
32+
assert res.status_code == 200
33+
payload = res.get_json()
34+
assert payload["ok"] is True
35+
assert "result" in payload
36+
r = payload["result"]
37+
assert "n_rows" in r and r["n_rows"] > 0
38+
assert "mean_probability" in r

0 commit comments

Comments
 (0)