1+ import logging
12import os
23from functools import lru_cache
4+ from typing import Literal
35
46import mlflow
57import pandas as pd
68import uvicorn
7- from fastapi import FastAPI , HTTPException
9+ from fastapi import FastAPI
810from pydantic import BaseModel
911
12+ logging .basicConfig (level = logging .INFO , format = "%(asctime)s - %(levelname)s - %(message)s" )
13+ logger = logging .getLogger (__name__ )
14+
1015mlflow .set_tracking_uri (os .getenv ("MLFLOW_TRACKING_URI" ))
1116
1217model_uri = "models:/taxi_fare_prediction.taxi_fare_model@latest-model"
@@ -20,7 +25,7 @@ def get_model(model_uri):
2025app = FastAPI ()
2126
2227
23- class InferenceInput (BaseModel ):
28+ class PredictionInput (BaseModel ):
2429 passenger_count : int
2530 trip_type : int
2631 congestion_surcharge : float
@@ -30,20 +35,44 @@ class InferenceInput(BaseModel):
3035 vendor_id : str
3136
3237
38+ class OutputItem (BaseModel ):
39+ prediction_input : PredictionInput
40+ prediction : float
41+ status : Literal ["success" , "warning" , "failure" ]
42+ message : str
43+
44+
3345@app .get ("/" )
3446async def read_root ():
3547 return {"message" : "Welcome to the Taxi Fare Prediction API!" , "docs" : "/docs" }
3648
3749
3850@app .post ("/predict" )
39- async def predict_one (data : InferenceInput ) -> dict [ str , float ] :
51+ async def predict_one (data : PredictionInput ) -> OutputItem :
4052 try :
4153 df_input = pd .DataFrame ([data .dict ()]) # pd df might be overkill
54+ logger .info (f"[Prediction Input] Received input: { df_input } " )
4255 prediction = get_model (model_uri ).predict (df_input )
4356
44- return {"prediction" : prediction [0 ]}
57+ if prediction [0 ] <= 0 :
58+ logger .error ("[Prediction Output] Prediction failed: Negative prediction" )
59+ return OutputItem (
60+ prediction_input = data ,
61+ prediction = 0.0 ,
62+ status = "warning" ,
63+ message = "Prediction failed: Negative prediction. Check your inputs." ,
64+ )
65+ else :
66+ logger .info (f"[Prediction Output] Prediction: { prediction [0 ]} " )
67+ return OutputItem (
68+ prediction_input = data , prediction = prediction [0 ], status = "success" , message = "Prediction successful"
69+ )
70+
4571 except Exception as e :
46- raise HTTPException (status_code = 400 , detail = f"Prediction failed: { str (e )} " ) from e
72+ logger .error (f"[Prediction Output] Prediction failed: { str (e )} " )
73+ return OutputItem (
74+ prediction_input = data , prediction = - 1.0 , status = "failure" , message = f"Prediction failed: { str (e )} "
75+ )
4776
4877
4978if __name__ == "__main__" :
0 commit comments