-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path main.py
More file actions
99 lines (81 loc) · 3.66 KB
/
main.py
File metadata and controls
99 lines (81 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
from typing import Optional

import joblib
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Initialize FastAPI app.
# The title/description/version feed the auto-generated OpenAPI schema
# served at /docs and /redoc.
app = FastAPI(
    title="Outlet Performance Predictor API",
    description="API for predicting outlet performance based on item details.",
    version="1.0.0",
)
# --- 1. Load the Trained Model ---
# The serialized pipeline is loaded exactly once at startup. Endpoints check
# `model_pipeline is None` and refuse to predict when loading failed, so a
# missing or corrupt model file degrades the service instead of crashing it.
MODEL_PATH = 'AAPOutletFinalfinal.joblib'
model_pipeline = None
if not os.path.exists(MODEL_PATH):
    print(f"Error: Model file not found at {MODEL_PATH}. Prediction service will be unavailable.")
else:
    try:
        model_pipeline = joblib.load(MODEL_PATH)
    except Exception as e:
        print(f"Error loading model pipeline: {e}. Prediction service will be unavailable.")
    else:
        print("Model pipeline loaded successfully!")
# --- 2. Define Input Data Schema ---
class ItemFeatures(BaseModel):
    """Request body for POST /predict/outlet_performance.

    Mirrors the features used at training time. ``Item_Identifier`` is
    accepted for traceability but dropped before the data reaches the
    model pipeline.
    """

    Item_Identifier: str  # removed before passing to the model
    Item_Weight: float
    Item_Fat_Content: str
    Item_Visibility: float
    Item_Type: str
    Item_MRP: float
    Outlet_Identifier: str
    Outlet_Establishment_Year: int
    # Fix: the original `Outlet_Size: str = None` paired a non-optional
    # annotation with a None default, so pydantic v2 rejects an explicit
    # null. Optional[str] keeps the None default and accepts null/omitted.
    Outlet_Size: Optional[str] = None
    Outlet_Location_Type: str
    Outlet_Type: str
# --- 3. Define the Prediction Endpoint ---
# Snapshot year of the training data: Outlet_Age was engineered in the
# notebook as (2010 - establishment year), so inference must use the same
# reference year — NOT the current year.
REFERENCE_YEAR = 2010


@app.post("/predict/outlet_performance")
async def predict_outlet_performance(item: ItemFeatures):
    """Predict an outlet's performance class ('Poor'/'Medium'/'Good') for one item.

    Raises:
        HTTPException 503: the model failed to load at startup.
        HTTPException 400: a required feature is missing or misnamed.
        HTTPException 422: the pipeline rejected a value or category.
        HTTPException 500: any other unexpected failure.
    """
    if model_pipeline is None:
        # Fix: 503 Service Unavailable is the correct status when a startup
        # dependency is missing; the original returned a generic 500.
        raise HTTPException(status_code=503, detail="Model pipeline could not be loaded on server startup. Service unavailable.")
    try:
        input_data_dict = item.model_dump()
        # Item_Identifier was not a training feature; drop it before
        # building the model's input frame.
        input_data_dict.pop('Item_Identifier')
        input_df = pd.DataFrame([input_data_dict])

        # Feature engineering must mirror the training notebook exactly.
        input_df['Outlet_Age'] = REFERENCE_YEAR - input_df['Outlet_Establishment_Year']
        input_df['Item_Sales_Deviation'] = 0.0  # placeholder for new, single items

        # (Removed leftover DataFrame diagnostic prints from debugging.)
        prediction_numeric = model_pipeline.predict(input_df)
        label_map = {0: 'Poor', 1: 'Medium', 2: 'Good'}
        # Fix: cast to a plain int so the lookup succeeds regardless of the
        # numpy scalar type (int64/float64) the pipeline returns.
        predicted_label = label_map.get(int(prediction_numeric[0]), "Unknown")
        return {"predicted_outlet_performance": predicted_label}
    except KeyError as e:
        print(f"Prediction input data error: Missing expected feature - {e}")
        raise HTTPException(status_code=400, detail=f"Invalid input data: A required feature is missing or misnamed - {e}")
    except ValueError as e:
        print(f"Prediction pipeline error: {e}")
        raise HTTPException(status_code=422, detail=f"Data validation or transformation error during prediction: {e}. Check input values and categories.")
    except Exception as e:
        print(f"An unexpected error occurred during prediction: {e}")
        raise HTTPException(status_code=500, detail=f"Internal Server Error during prediction: {e}")
# --- Optional: Root endpoint for API health check ---
@app.get("/")
async def root():
    """Health check: confirms the service is up and points callers to /docs."""
    payload = {"message": "Outlet Performance Predictor API is running! Access docs at /docs"}
    return payload
# Commands to run
# 1. pip install "fastapi[standard]"   (the [standard] extras syntax is pip-specific;
#    with conda, install fastapi and uvicorn as separate packages)
# 2. fastapi dev main.py
# The interactive docs will then be at http://127.0.0.1:8000/docs
@app.get("/api/get_result")
def predict(postal_code: str):
    """Return a placeholder greeting for the given postal code.

    Fix: the original declared an untyped ``request`` parameter, which
    FastAPI interprets as a required *string query parameter*, and then
    called ``request.json()`` on it — an AttributeError (HTTP 500) on
    every request. ``postal_code`` is now a proper required query
    parameter, matching the key the original tried to read.
    """
    # postal_code is validated but not yet used; this stub stands in for a
    # planned per-postal-code lookup.
    return {"response": "Hello!"}