Skip to content

Commit f7fc205

Browse files
committed
update model serving api with extra checks
1 parent f77b54c commit f7fc205

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

src/make_api/app/main.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
import os
2+
from functools import lru_cache
23

34
import mlflow
45
import pandas as pd
56
import uvicorn
6-
from fastapi import FastAPI
7+
from fastapi import FastAPI, HTTPException
78
from pydantic import BaseModel
89

910
mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
1011

1112
model_uri = "models:/taxi_fare_prediction.taxi_fare_model@latest-model"
12-
pipe = mlflow.sklearn.load_model(model_uri)
13+
14+
15+
@lru_cache
16+
def get_model():
17+
return mlflow.sklearn.load_model(model_uri)
18+
19+
20+
pipe = get_model()
1321

1422
app = FastAPI()
1523

@@ -26,16 +34,17 @@ class InferenceInput(BaseModel):
2634

2735
@app.get("/")
2836
async def read_root():
29-
return {"Hello": "Visit /docs for API documentation"}
37+
return {"message": "Welcome to the Taxi Fare Prediction API!", "docs": "/docs"}
3038

3139

3240
@app.post("/predict")
3341
async def predict_one(data: InferenceInput) -> dict[str, float]:
34-
df_input = pd.DataFrame([data.dict()])
35-
36-
prediction = pipe.predict(df_input)
37-
38-
return {"prediction": prediction[0]}
42+
try:
43+
df_input = pd.DataFrame([data.dict()]) # pd df might be overkill
44+
prediction = pipe.predict(df_input)
45+
return {"prediction": prediction[0]}
46+
except Exception as e:
47+
raise HTTPException(status_code=400, detail=f"Prediction failed: {str(e)}") from e
3948

4049

4150
if __name__ == "__main__":

0 commit comments

Comments
 (0)