File tree Expand file tree Collapse file tree 1 file changed +17
-8
lines changed
Expand file tree Collapse file tree 1 file changed +17
-8
lines changed Original file line number Diff line number Diff line change 11import os
2+ from functools import lru_cache
23
34import mlflow
45import pandas as pd
56import uvicorn
6- from fastapi import FastAPI
7+ from fastapi import FastAPI , HTTPException
78from pydantic import BaseModel
89
910mlflow .set_tracking_uri (os .getenv ("MLFLOW_TRACKING_URI" ))
1011
1112model_uri = "models:/taxi_fare_prediction.taxi_fare_model@latest-model"
12- pipe = mlflow .sklearn .load_model (model_uri )
13+
14+
15+ @lru_cache
16+ def get_model ():
17+ return mlflow .sklearn .load_model (model_uri )
18+
19+
20+ pipe = get_model ()
1321
1422app = FastAPI ()
1523
@@ -26,16 +34,17 @@ class InferenceInput(BaseModel):
2634
2735@app .get ("/" )
2836async def read_root ():
29- return {"Hello " : "Visit /docs for API documentation " }
37+ return {"message " : "Welcome to the Taxi Fare Prediction API!" , "docs" : "/docs " }
3038
3139
3240@app .post ("/predict" )
3341async def predict_one (data : InferenceInput ) -> dict [str , float ]:
34- df_input = pd .DataFrame ([data .dict ()])
35-
36- prediction = pipe .predict (df_input )
37-
38- return {"prediction" : prediction [0 ]}
42+ try :
43+ df_input = pd .DataFrame ([data .dict ()]) # pd df might be overkill
44+ prediction = pipe .predict (df_input )
45+ return {"prediction" : prediction [0 ]}
46+ except Exception as e :
47+ raise HTTPException (status_code = 400 , detail = f"Prediction failed: { str (e )} " ) from e
3948
4049
4150if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments