Skip to content

Commit f44c482

Browse files
committed
add fastapi pydantic output item
1 parent ca02afc commit f44c482

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

src/make_api/app/main.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
import logging
12
import os
23
from functools import lru_cache
4+
from typing import Literal
35

46
import mlflow
57
import pandas as pd
68
import uvicorn
7-
from fastapi import FastAPI, HTTPException
9+
from fastapi import FastAPI
810
from pydantic import BaseModel
911

12+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
13+
logger = logging.getLogger(__name__)
14+
1015
mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
1116

1217
model_uri = "models:/taxi_fare_prediction.taxi_fare_model@latest-model"
@@ -20,7 +25,7 @@ def get_model(model_uri):
2025
app = FastAPI()
2126

2227

23-
class InferenceInput(BaseModel):
28+
class PredictionInput(BaseModel):
2429
passenger_count: int
2530
trip_type: int
2631
congestion_surcharge: float
@@ -30,20 +35,44 @@ class InferenceInput(BaseModel):
3035
vendor_id: str
3136

3237

38+
class OutputItem(BaseModel):
39+
prediction_input: PredictionInput
40+
prediction: float
41+
status: Literal["success", "warning", "failure"]
42+
message: str
43+
44+
3345
@app.get("/")
3446
async def read_root():
3547
return {"message": "Welcome to the Taxi Fare Prediction API!", "docs": "/docs"}
3648

3749

3850
@app.post("/predict")
39-
async def predict_one(data: InferenceInput) -> dict[str, float]:
51+
async def predict_one(data: PredictionInput) -> OutputItem:
4052
try:
4153
df_input = pd.DataFrame([data.dict()]) # pd df might be overkill
54+
logger.info(f"[Prediction Input] Received input: {df_input}")
4255
prediction = get_model(model_uri).predict(df_input)
4356

44-
return {"prediction": prediction[0]}
57+
if prediction[0] <= 0:
58+
logger.error("[Prediction Output] Prediction failed: Negative prediction")
59+
return OutputItem(
60+
prediction_input=data,
61+
prediction=0.0,
62+
status="warning",
63+
message="Prediction failed: Negative prediction. Check your inputs.",
64+
)
65+
else:
66+
logger.info(f"[Prediction Output] Prediction: {prediction[0]}")
67+
return OutputItem(
68+
prediction_input=data, prediction=prediction[0], status="success", message="Prediction successful"
69+
)
70+
4571
except Exception as e:
46-
raise HTTPException(status_code=400, detail=f"Prediction failed: {str(e)}") from e
72+
logger.error(f"[Prediction Output] Prediction failed: {str(e)}")
73+
return OutputItem(
74+
prediction_input=data, prediction=-1.0, status="failure", message=f"Prediction failed: {str(e)}"
75+
)
4776

4877

4978
if __name__ == "__main__":

0 commit comments

Comments
 (0)