Skip to content

Commit 6484daf

Browse files
committed
add checks in predict api
1 parent b243fa5 commit 6484daf

File tree

1 file changed

+16
-59
lines changed

1 file changed

+16
-59
lines changed

src/make_api/app/main.py

Lines changed: 16 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,22 @@
1-
# import os
2-
# from functools import lru_cache
3-
4-
# import mlflow
5-
# import pandas as pd
6-
# import uvicorn
7-
# from fastapi import FastAPI, HTTPException
8-
# from pydantic import BaseModel
9-
10-
# mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
11-
12-
# model_uri = "models:/taxi_fare_prediction.taxi_fare_model@latest-model"
13-
14-
15-
# # @lru_cache
16-
# # def get_model(model_uri):
17-
# # return mlflow.sklearn.load_model(model_uri)
18-
19-
20-
# app = FastAPI()
21-
22-
23-
# class InferenceInput(BaseModel):
24-
# passenger_count: int
25-
# trip_type: int
26-
# congestion_surcharge: float
27-
# mean_distance: float
28-
# mean_duration: float
29-
# rush_hour: int
30-
# vendor_id: str
31-
32-
33-
# @app.get("/")
34-
# async def read_root():
35-
# return {"message": "Welcome to the Taxi Fare Prediction API!", "docs": "/docs"}
36-
37-
38-
# @app.post("/predict")
39-
# async def predict_one(data: InferenceInput) -> dict[str, float]:
40-
# try:
41-
# df_input = pd.DataFrame([data.dict()]) # pd df might be overkill
42-
# # prediction = get_model(model_uri).predict(df_input)
43-
# pipe = mlflow.sklearn.load_model(model_uri)
44-
# prediction = pipe.predict(df_input)
45-
# return {"prediction": prediction[0]}
46-
# except Exception as e:
47-
# raise HTTPException(status_code=400, detail=f"Prediction failed: {str(e)}") from e
48-
49-
50-
# if __name__ == "__main__":
51-
# uvicorn.run(app, host="0.0.0.0", port=8000)
521
import os
2+
from functools import lru_cache
533

544
import mlflow
555
import pandas as pd
566
import uvicorn
57-
from fastapi import FastAPI
7+
from fastapi import FastAPI, HTTPException
588
from pydantic import BaseModel
599

6010
mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
61-
print(os.getenv("ASD"))
11+
6212
model_uri = "models:/taxi_fare_prediction.taxi_fare_model@latest-model"
6313

14+
15+
@lru_cache
16+
def get_model(model_uri):
17+
return mlflow.sklearn.load_model(model_uri)
18+
19+
6420
app = FastAPI()
6521

6622

@@ -76,17 +32,18 @@ class InferenceInput(BaseModel):
7632

7733
@app.get("/")
7834
async def read_root():
79-
return {"Hello": "Visit /docs for API documentation"}
35+
return {"message": "Welcome to the Taxi Fare Prediction API!", "docs": "/docs"}
8036

8137

8238
@app.post("/predict")
8339
async def predict_one(data: InferenceInput) -> dict[str, float]:
84-
pipe = mlflow.sklearn.load_model(model_uri)
85-
df_input = pd.DataFrame([data.dict()])
86-
87-
prediction = pipe.predict(df_input)
40+
try:
41+
df_input = pd.DataFrame([data.dict()]) # pd df might be overkill
42+
prediction = get_model(model_uri).predict(df_input)
8843

89-
return {"prediction": prediction[0]}
44+
return {"prediction": prediction[0]}
45+
except Exception as e:
46+
raise HTTPException(status_code=400, detail=f"Prediction failed: {str(e)}") from e
9047

9148

9249
if __name__ == "__main__":

0 commit comments

Comments
 (0)