1+ import os
2+ import numpy as np
13from flask import Flask , request , jsonify
24from flask_cors import CORS
3- import numpy as np
45from sklearn .datasets import load_breast_cancer
56from sklearn .ensemble import GradientBoostingClassifier
7+ from pydantic import BaseModel , ValidationError , conlist
8+ from typing import List
9+ from waitress import serve
610
711app = Flask (__name__ )
8- CORS (app , resources = {r"/*" : {"origins" : [
9- "http://localhost:5173" ,
10- "http://127.0.0.1:5173" ,
11- "http://localhost:3000" ,
12- "http://127.0.0.1:3000" ,
13- ]}})
1412
13+ # CORS Configuration via Env Vars
14+ ALLOWED_ORIGINS = os .getenv ("FRONTEND_URL" , "http://localhost:5173" ).split ("," )
15+ CORS (app , resources = {r"/*" : {"origins" : ALLOWED_ORIGINS }})
16+
17+ # Train Model on Startup
1518data = load_breast_cancer ()
1619X , y = data .data , data .target
1720model = GradientBoostingClassifier (n_estimators = 100 , random_state = 42 )
1821model .fit (X , y )
1922
23+ # Input Validation Schema
24+ class PredictionInput (BaseModel ):
25+ data : List [float ]
2026
21- @app .route ("/predict" , methods = ["POST" , "OPTIONS" ])
27+ @app .route ("/predict" , methods = ["POST" ])
2228def predict ():
23- body = request .get_json (force = True )
24- features = np .array (body ["data" ]).reshape (1 , - 1 )
25- prediction = model .predict (features )
26- result = "Benign" if prediction [0 ] == 1 else "Malignant"
27- return jsonify ({"result" : result })
29+ try :
30+ # 1. Parse & Validate Input
31+ body = request .get_json (force = True )
32+ if not body :
33+ return jsonify ({"error" : "Empty request body" }), 400
34+
35+ # Pydantic validation
36+ input_data = PredictionInput (** body )
37+
38+ # 2. Check Feature Count
39+ features = np .array (input_data .data ).reshape (1 , - 1 )
40+ if features .shape [1 ] != X .shape [1 ]:
41+ return jsonify ({
42+ "error" : f"Invalid feature count. Expected { X .shape [1 ]} , got { features .shape [1 ]} "
43+ }), 400
44+
45+ # 3. Prediction
46+ prediction = model .predict (features )
47+ result = "Benign" if prediction [0 ] == 1 else "Malignant"
48+
49+ return jsonify ({"result" : result })
50+
51+ except ValidationError as e :
52+ return jsonify ({"error" : "Validation Error" , "details" : e .errors ()}), 400
53+ except Exception as e :
54+ # Log error in production (print for now)
55+ print (f"Prediction Error: { e } " )
56+ return jsonify ({"error" : "Internal Server Error" }), 500
2857
58+ @app .route ("/health" , methods = ["GET" ])
59+ def health ():
60+ return jsonify ({"status" : "healthy" }), 200
2961
3062if __name__ == "__main__" :
31- app .run (debug = True , host = "127.0.0.1" , port = 5000 )
63+ # Production Server
64+ port = int (os .environ .get ("PORT" , 5000 ))
65+ print (f"Starting production server on port { port } ..." )
66+ serve (app , host = "0.0.0.0" , port = port )
0 commit comments