1- # import os
2- # from functools import lru_cache
3-
4- # import mlflow
5- # import pandas as pd
6- # import uvicorn
7- # from fastapi import FastAPI, HTTPException
8- # from pydantic import BaseModel
9-
10- # mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
11-
12- # model_uri = "models:/taxi_fare_prediction.taxi_fare_model@latest-model"
13-
14-
15- # # @lru_cache
16- # # def get_model(model_uri):
17- # # return mlflow.sklearn.load_model(model_uri)
18-
19-
20- # app = FastAPI()
21-
22-
23- # class InferenceInput(BaseModel):
24- # passenger_count: int
25- # trip_type: int
26- # congestion_surcharge: float
27- # mean_distance: float
28- # mean_duration: float
29- # rush_hour: int
30- # vendor_id: str
31-
32-
33- # @app.get("/")
34- # async def read_root():
35- # return {"message": "Welcome to the Taxi Fare Prediction API!", "docs": "/docs"}
36-
37-
38- # @app.post("/predict")
39- # async def predict_one(data: InferenceInput) -> dict[str, float]:
40- # try:
41- # df_input = pd.DataFrame([data.dict()]) # pd df might be overkill
42- # # prediction = get_model(model_uri).predict(df_input)
43- # pipe = mlflow.sklearn.load_model(model_uri)
44- # prediction = pipe.predict(df_input)
45- # return {"prediction": prediction[0]}
46- # except Exception as e:
47- # raise HTTPException(status_code=400, detail=f"Prediction failed: {str(e)}") from e
48-
49-
50- # if __name__ == "__main__":
51- # uvicorn.run(app, host="0.0.0.0", port=8000)
521import os
2+ from functools import lru_cache
533
544import mlflow
555import pandas as pd
566import uvicorn
57- from fastapi import FastAPI
7+ from fastapi import FastAPI, HTTPException
588from pydantic import BaseModel
599
6010mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
61- print(os.getenv("ASD"))
11+
6212model_uri = "models:/taxi_fare_prediction.taxi_fare_model@latest-model"
6313
14+
15+ @lru_cache
16+ def get_model(model_uri):
17+ return mlflow.sklearn.load_model(model_uri)
18+
19+
6420app = FastAPI()
6521
6622
@@ -76,17 +32,18 @@ class InferenceInput(BaseModel):
7632
7733@app.get("/")
7834async def read_root():
79- return {"Hello ": "Visit /docs for API documentation "}
35+ return {"message ": "Welcome to the Taxi Fare Prediction API!", "docs": "/docs "}
8036
8137
8238@app.post("/predict")
8339async def predict_one(data: InferenceInput) -> dict[str, float]:
84- pipe = mlflow.sklearn.load_model(model_uri)
85- df_input = pd.DataFrame([data.dict()])
86-
87- prediction = pipe.predict(df_input)
40+ try:
41+ df_input = pd.DataFrame([data.dict()]) # pd df might be overkill
42+ prediction = get_model(model_uri).predict(df_input)
8843
89- return {"prediction": prediction[0]}
44+ return {"prediction": prediction[0]}
45+ except Exception as e:
46+ raise HTTPException(status_code=400, detail=f"Prediction failed: {str(e)}") from e
9047
9148
9249if __name__ == "__main__":
0 commit comments