44
55Flask dashboard application entrypoint.
66
7- This file exposes the dashboard and includes a secure route to serve explanation
8- artifacts from the project's reports/explain directory in-place (no copying).
9-
10- Run (development):
11- PYTHONPATH=src flask --app src/dashboard/app.py run
12- or
13- PYTHONPATH=src python src/dashboard/app.py
14-
15- Routes:
16- - GET / -> dashboard index
17- - POST /predict -> single prediction (form or JSON)
18- - POST /predict_batch -> batch CSV upload
19- - GET /reports/explain/<path:filename> -> serve files from reports/explain (secure)
20- - GET /api/explain_files -> list explain files (filename + mtime)
21- - GET /api/metrics -> JSON metrics
7+ Key endpoints:
8+ - GET / -> dashboard page
9+ - POST /predict -> single-record prediction (JSON or form)
10+ - POST /predict_batch -> CSV upload (multipart/form-data)
11+ - GET /api/explain_files -> JSON list of files in reports/explain
12+ - GET /reports/explain/<f> -> serve files from reports/explain securely
13+
14+ Run:
15+ PYTHONPATH=src flask --app src/dashboard/app.py run
2216"""
2317
2418from __future__ import annotations
2822import logging
2923import os
3024from logging .handlers import RotatingFileHandler
31- from typing import Any , Dict , List
3225
3326import pandas as pd
3427from flask import (
3932 redirect ,
4033 render_template ,
4134 request ,
42- send_file ,
4335 send_from_directory ,
4436 url_for ,
4537)
4638from werkzeug .utils import secure_filename
4739
48- # Ensure src package imports work when running this file directly.
49- _this_dir = os .path .dirname (os .path .abspath (__file__ )) # .../project/src/dashboard
50- _src_root = os .path .dirname (_this_dir ) # .../project/src
40+ from dashboard .predict import ModelWrapper , find_model , list_explain_files
41+
42+ # ensure src import works
43+ _this_dir = os .path .dirname (os .path .abspath (__file__ )) # .../src/dashboard
44+ _src_root = os .path .dirname (_this_dir ) # .../src
5145if _src_root not in os .sys .path :
5246 os .sys .path .insert (0 , _src_root )
47+ # Ensure server-side plotting uses Agg (lowest-risk backend for production/dev)
48+ os .environ .setdefault ("MPLBACKEND" , "Agg" )
5349
54- from dashboard .predict import ModelWrapper , find_model # noqa: E402
5550
56- BASE_DIR = _this_dir
57- REPORTS_EXPLAIN_DIR = os .path .abspath (
58- os .path .join (BASE_DIR , ".." , ".." , "reports" , "explain" )
59- )
60- UPLOAD_TMP = os .path .abspath (
61- os .path .join (BASE_DIR , ".." , ".." , "reports" , "tmp_uploads" )
62- )
63-
64- os .makedirs (UPLOAD_TMP , exist_ok = True )
51+ BASE_REPO = os .path .abspath (os .path .join (_src_root , ".." ))
52+ REPORTS_EXPLAIN_DIR = os .path .join (BASE_REPO , "reports" , "explain" )
53+ UPLOAD_TMP = os .path .join (BASE_REPO , "reports" , "tmp_uploads" )
6554os .makedirs (REPORTS_EXPLAIN_DIR , exist_ok = True )
55+ os .makedirs (UPLOAD_TMP , exist_ok = True )
6656
67- # point Flask at repo-level template dir (src/templates)
68- TEMPLATE_DIR = os .path .join (_src_root , "templates" ) # _src_root is .../project/src
6957app = Flask (
7058 __name__ ,
7159 template_folder = os .path .join (_src_root , "templates" ),
7260 static_folder = os .path .join (_src_root , "static" ),
7361)
74-
75- app .config ["MAX_CONTENT_LENGTH" ] = 5 * 1024 * 1024
62+ app .config ["MAX_CONTENT_LENGTH" ] = 8 * 1024 * 1024
7663app .config ["UPLOAD_FOLDER" ] = UPLOAD_TMP
7764app .secret_key = os .environ .get ("FLASK_SECRET" , "dev-secret-key" )
7865
66+
67+ # helper
68+ def load_wrapper (model_path = None ):
69+ """Utility to consistently load ModelWrapper."""
70+ return ModelWrapper (model_path )
71+
72+
7973# logging
8074log_path = os .path .abspath (os .path .join (_src_root , ".." , "reports" , "dashboard.log" ))
8175os .makedirs (os .path .dirname (log_path ), exist_ok = True )
82- handler = RotatingFileHandler (log_path , maxBytes = 5_000_000 , backupCount = 2 )
83- handler .setLevel (logging .INFO )
76+ handler = RotatingFileHandler (log_path , maxBytes = 2_000_000 , backupCount = 2 )
8477handler .setFormatter (logging .Formatter ("%(asctime)s %(levelname)s: %(message)s" ))
8578app .logger .addHandler (handler )
8679app .logger .setLevel (logging .INFO )
8780
8881
89- def load_wrapper (preferred : str | None = None ) -> ModelWrapper :
90- return ModelWrapper (preferred )
91-
92-
93- @app .route ("/reports/explain/<path:filename>" )
94- def explain_file (filename : str ):
95- """
96- Securely serve files from reports/explain without copying them into src/static.
97-
98- - Validates that the final absolute path is inside REPORTS_EXPLAIN_DIR.
99- - Returns 404 if file missing, 403 if attempt to escape directory.
100- """
101- safe_dir = REPORTS_EXPLAIN_DIR
102- requested = os .path .abspath (os .path .join (safe_dir , filename ))
103-
104- # Prevent directory traversal
105- if not os .path .commonpath ([safe_dir , requested ]) == safe_dir :
106- app .logger .warning ("Forbidden file access attempt: %s" , requested )
107- abort (403 )
108- if not os .path .exists (requested ):
109- abort (404 )
110- return send_from_directory (safe_dir , filename )
111-
112-
113- def _list_explain_files (extensions : List [str ] = None ) -> List [Dict [str , Any ]]:
114- """
115- Return a list of files in REPORTS_EXPLAIN_DIR filtered by extensions
116- with their modification times (as integer timestamps).
117- """
118- if extensions is None :
119- extensions = [".png" , ".jpg" , ".jpeg" , ".gif" , ".svg" , ".html" ]
120- out : List [Dict [str , Any ]] = []
121- try :
122- for name in os .listdir (REPORTS_EXPLAIN_DIR ):
123- path = os .path .join (REPORTS_EXPLAIN_DIR , name )
124- if not os .path .isfile (path ):
125- continue
126- if extensions and not any (name .lower ().endswith (ext ) for ext in extensions ):
127- continue
128- try :
129- mtime = int (os .path .getmtime (path ))
130- except Exception :
131- mtime = 0
132- out .append ({"filename" : name , "mtime" : mtime })
133- out .sort (key = lambda r : r ["mtime" ], reverse = True )
134- except FileNotFoundError :
135- return []
136- except Exception :
137- app .logger .exception ("Error listing explain files" )
138- return out
139-
140-
141- @app .route ("/api/explain_files" , methods = ["GET" ])
142- def api_explain_files ():
143- """
144- Return JSON with the list of explain files and the latest file (if any).
145-
146- Example response:
147- {
148- "ok": True,
149- "files": [{"filename":"shap_summary.png","mtime":163...}, ...],
150- "latest": {"filename":"shap_summary.png","mtime":163...} or null
151- }
152- """
153- files = _list_explain_files ()
154- latest = files [0 ] if files else None
155- return jsonify ({"ok" : True , "files" : files , "latest" : latest })
156-
157-
15882@app .route ("/" , methods = ["GET" ])
15983def index ():
160- # Try to detect a model and metrics
16184 model_path = find_model ()
16285 metrics = {}
16386 if model_path :
164- # try to find a metrics file in reports (convention: <model>_metrics.json)
87+ # try to find metrics file next to reports, e.g. reports/ <model>_metrics.json
16588 base = os .path .basename (model_path )
16689 name = os .path .splitext (base )[0 ]
167- # metrics in reports/ (sibling of reports/models)
16890 candidates = [
169- os .path .join (
170- os .path .dirname (os .path .dirname (model_path )),
171- ".." ,
172- f"{ name } _metrics.json" ,
173- ),
91+ os .path .join (os .path .dirname (model_path ), ".." , f"{ name } _metrics.json" ),
17492 os .path .join ("reports" , f"{ name } _metrics.json" ),
17593 ]
17694 for p in candidates :
@@ -185,23 +103,50 @@ def index():
185103 return render_template ("index.html" , model_path = model_path , metrics = metrics )
186104
187105
106+ @app .route ("/reports/explain/<path:filename>" )
107+ def explain_file (filename : str ):
108+ safe_dir = REPORTS_EXPLAIN_DIR
109+ requested = os .path .abspath (os .path .join (safe_dir , filename ))
110+ # directory traversal protection
111+ if not os .path .commonpath ([safe_dir , requested ]) == safe_dir :
112+ abort (403 )
113+ if not os .path .exists (requested ):
114+ abort (404 )
115+ return send_from_directory (safe_dir , filename )
116+
117+
118+ @app .route ("/api/explain_files" , methods = ["GET" ])
119+ def api_explain_files ():
120+ try :
121+ files = list_explain_files ()
122+ return jsonify (
123+ {"ok" : True , "files" : files , "latest" : files [0 ] if files else None }
124+ )
125+ except Exception as e :
126+ app .logger .exception ("Failed to list explain files" )
127+ return jsonify ({"ok" : False , "error" : str (e )}), 500
128+
129+
188130@app .route ("/predict" , methods = ["POST" ])
189131def predict ():
190- preferred_model = request .form .get ("model_path" ) or request .args .get ("model_path" )
191- data = None
132+ """
133+ Accept JSON or form data for a single record.
134+ Returns: { ok: True, result: { prediction, probability, user_message, explanation_files[...] }, model_info: {...} }
135+ """
192136 if request .is_json :
193137 data = request .get_json ()
194138 else :
195- data = {k : v for k , v in request .form .items () if k != "model_path" }
139+ data = {k : v for k , v in request .form .items ()}
140+
196141 try :
197- wrapper = load_wrapper (preferred_model )
198- if isinstance (data , dict ):
199- df = pd .DataFrame ([data ])
200- elif isinstance (data , list ):
201- df = pd .DataFrame (data )
202- else :
203- return jsonify ({"ok" : False , "error" : "Invalid input format" }), 400
142+ wrapper = ModelWrapper (os .environ .get ("DASHBOARD_MODEL" ))
143+ except Exception as e :
144+ app .logger .exception ("Model load failed" )
145+ return jsonify ({"ok" : False , "error" : "Model loading failed: " + str (e )}), 500
204146
147+ try :
148+ # coerce to DataFrame and prepare with expected features
149+ df = pd .DataFrame ([data ])
205150 res = wrapper .predict_single (df )
206151 return jsonify (
207152 {"ok" : True , "result" : res , "model_info" : wrapper .get_model_info ()}
@@ -219,7 +164,12 @@ def allowed_file(filename: str) -> bool:
219164
220165
221166@app .route ("/predict_batch" , methods = ["POST" ])
222- def predict_batch ():
167+ def predict_batch_route ():
168+ """
169+ Flask route to accept a CSV upload and return batch predictions as JSON.
170+ Keeps route at /predict_batch so existing clients/tests are unaffected.
171+ The function name is changed to avoid a name collision with ModelWrapper.predict_batch.
172+ """
223173 if "file" not in request .files :
224174 flash ("No file part" )
225175 return redirect (url_for ("index" ))
@@ -238,11 +188,25 @@ def predict_batch():
238188 return redirect (url_for ("index" ))
239189 try :
240190 wrapper = load_wrapper (None )
241- out_df = wrapper .predict_batch (df )
242- out_path = save_path .replace (".csv" , "_predictions.csv" )
243- out_df .to_csv (out_path , index = False )
244- return send_file (
245- out_path , as_attachment = True , download_name = os .path .basename (out_path )
191+ # returns dict with DataFrame in res['predictions'] and explanation_files
192+ res = wrapper .predict_batch (df )
193+ # Return JSON: n_rows, mean_probability, predictions (records), and explanation_files
194+ # Convert DataFrame -> dict (records) for JSON serialization
195+ preds_df = res .get ("predictions" )
196+ preds_records = (
197+ preds_df .to_dict (orient = "records" ) if preds_df is not None else []
198+ )
199+ return jsonify (
200+ {
201+ "ok" : True ,
202+ "result" : {
203+ "n_rows" : res .get ("n_rows" ),
204+ "mean_probability" : res .get ("mean_probability" ),
205+ "predictions" : preds_records ,
206+ "explanation_files" : res .get ("explanation_files" ),
207+ },
208+ "model_info" : wrapper .get_model_info (),
209+ }
246210 )
247211 except Exception as e :
248212 app .logger .exception ("Batch prediction failed" )
@@ -253,30 +217,6 @@ def predict_batch():
253217 return redirect (url_for ("index" ))
254218
255219
256- @app .route ("/api/metrics" , methods = ["GET" ])
257- def api_metrics ():
258- model = request .args .get ("model" )
259- metrics_path = None
260- if model :
261- metrics_path = os .path .join ("reports" , f"{ model } _metrics.json" )
262- else :
263- candidates = (
264- [p for p in os .listdir ("reports" ) if p .endswith ("_metrics.json" )]
265- if os .path .isdir ("reports" )
266- else []
267- )
268- metrics_path = os .path .join ("reports" , candidates [0 ]) if candidates else None
269-
270- if metrics_path and os .path .exists (metrics_path ):
271- try :
272- with open (metrics_path ) as fh :
273- return jsonify ({"ok" : True , "metrics" : json .load (fh )})
274- except Exception as e :
275- app .logger .exception ("Failed reading metrics" )
276- return jsonify ({"ok" : False , "error" : str (e )}), 500
277- return jsonify ({"ok" : False , "error" : "No metrics file found" }), 404
278-
279-
280220if __name__ == "__main__" :
281221 parser = argparse .ArgumentParser ()
282222 parser .add_argument ("--host" , default = "127.0.0.1" )
0 commit comments