Skip to content

Commit b14d4e1

Browse files
Update app.py
1 parent 3559963 commit b14d4e1

File tree

1 file changed

+116
-72
lines changed

1 file changed

+116
-72
lines changed

backend/app.py

Lines changed: 116 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from flask import Flask, request, jsonify, send_file
99
from flask_cors import CORS
1010
from sklearn.linear_model import LinearRegression
11+
from sklearn.metrics import mean_squared_error, r2_score
1112
import numpy as np
1213
import openai
1314
from datetime import datetime
@@ -17,81 +18,87 @@
1718
# === CONFIG ===
1819
BACKEND_BASE_URL = "https://ai-dslab-backend-cpf2feachnetbbck.westus-01.azurewebsites.net"
1920

20-
# Initialize Flask app
2121
app = Flask(__name__)
22-
CORS(app) # Enable CORS for all routes
22+
CORS(app)
2323

24-
# Use a secure temporary directory
2524
TEMP_DIR = tempfile.TemporaryDirectory()
2625
UPLOAD_FOLDER = TEMP_DIR.name
2726
PLOT_PATH = os.path.join(UPLOAD_FOLDER, "plot.png")
27+
FORECAST_PLOT_PATH = os.path.join(UPLOAD_FOLDER, "forecast_plot.png")
28+
REPORT_PATH = os.path.join(UPLOAD_FOLDER, "summary_report.txt")
2829

29-
# Cleanup temp directory on shutdown
3030
@atexit.register
3131
def cleanup_temp_dir():
3232
TEMP_DIR.cleanup()
3333

34-
# Capture logs
3534
log_stream = io.StringIO()
3635

3736
def log_print(*args):
3837
print(*args, file=log_stream)
3938
sys.stdout.flush()
4039

41-
# Root route for health check
4240
@app.route("/")
4341
def index():
44-
return jsonify({"message": "AI DataScience Backend is running on Azure."})
42+
return jsonify({"message": "AI DataScience Backend is running on Azure."})
4543

46-
# Handle file upload and processing
4744
@app.route("/upload", methods=["POST"])
4845
def upload_file():
4946
log_stream.truncate(0)
5047
log_stream.seek(0)
5148

5249
file = request.files.get("file")
53-
if file is None:
54-
return jsonify({"error": "No file uploaded"}), 400
50+
x_col = request.form.get("x_column")
51+
y_col = request.form.get("y_column")
52+
model_choice = request.form.get("model", "linear")
53+
54+
if file is None or not x_col or not y_col:
55+
return jsonify({"error": "Missing file or column selection."}), 400
5556

5657
filepath = os.path.join(UPLOAD_FOLDER, file.filename)
5758
file.save(filepath)
58-
log_print("📦 File uploaded:", file.filename)
59-
6059
df = pd.read_csv(filepath)
61-
df.dropna(inplace=True)
60+
61+
if x_col not in df.columns or y_col not in df.columns:
62+
return jsonify({"error": f"'{x_col}' or '{y_col}' not in dataset."}), 400
63+
64+
df = df[[x_col, y_col]].dropna()
6265
df.columns = ['X', 'Y']
63-
log_print("🔍 Cleaned Data:", df.head())
66+
log_print("Cleaned Data:\n", df.head())
6467

65-
# Plot
68+
# Scatter plot
6669
plt.figure()
6770
plt.scatter(df['X'], df['Y'])
6871
plt.xlabel('X')
6972
plt.ylabel('Y')
7073
plt.title('Scatter Plot')
7174
plt.savefig(PLOT_PATH)
7275
plt.close()
73-
log_print("📊 Scatter plot saved.")
7476

75-
# Fit model
76-
df['X'] = pd.to_datetime(df['X'], errors='coerce')
77-
df.dropna(inplace=True)
78-
X = df['X'].map(pd.Timestamp.toordinal).values.reshape(-1, 1)
79-
y = df['Y'].values
80-
81-
if len(X) == 0:
82-
log_print("❌ No valid data to fit the model.")
83-
return jsonify({
84-
"summary": "No valid data found.",
85-
"log": log_stream.getvalue(),
86-
"forecast": "N/A",
87-
"plot_url": None
88-
})
77+
# Try parsing X as datetime
78+
df['X_date'] = pd.to_datetime(df['X'], errors='coerce')
79+
use_dates = df['X_date'].notna().sum() >= len(df) // 2
80+
try:
81+
if use_dates:
82+
df = df.dropna(subset=['X_date'])
83+
X = df['X_date'].map(pd.Timestamp.toordinal).values.reshape(-1, 1)
84+
else:
85+
X = df['X'].astype(float).values.reshape(-1, 1)
86+
y = df['Y'].astype(float).values
87+
except:
88+
return jsonify({"error": "Failed to parse X or Y as numeric or date."}), 400
89+
90+
if model_choice == "linear":
91+
model = LinearRegression()
92+
else:
93+
return jsonify({"error": "Only linear regression supported for now."}), 400
8994

90-
model = LinearRegression()
9195
model.fit(X, y)
92-
log_print("🤖 Model trained.")
96+
y_pred = model.predict(X)
97+
r2 = r2_score(y, y_pred)
98+
mse = mean_squared_error(y, y_pred)
99+
log_print(f"Model Trained. R² = {r2:.4f}, MSE = {mse:.4f}")
93100

94-
# OpenAI summary
101+
# OpenAI Summary
95102
try:
96103
openai.api_key = os.getenv("OPENAI_API_KEY")
97104
client = openai.OpenAI(api_key=openai.api_key)
@@ -103,81 +110,118 @@ def upload_file():
103110
]
104111
)
105112
summary = response.choices[0].message.content
106-
log_print("🧠 OpenAI Summary generated.")
113+
log_print("Summary generated by OpenAI.")
107114
except Exception as e:
108115
summary = "OpenAI summarization failed."
109-
log_print("OpenAI error:", str(e))
116+
log_print("OpenAI error:", str(e))
110117

111118
return jsonify({
112119
"summary": summary,
113120
"log": log_stream.getvalue(),
114-
"forecast": "Submit future x-values below to get predictions.",
121+
"forecast": "Submit future values below to get predictions.",
122+
"r2_score": round(r2, 4),
123+
"mse": round(mse, 4),
115124
"plot_url": f"{BACKEND_BASE_URL}/plot.png"
116125
})
117126

118-
# Serve the generated plot
119-
@app.route("/plot.png")
120-
def serve_plot():
121-
return send_file(PLOT_PATH, mimetype="image/png")
122-
123-
# Handle prediction requests
124127
@app.route("/predict", methods=["POST"])
125128
def predict():
126129
future_x = request.form.get("future_x")
127130
if not future_x:
128-
return jsonify({
129-
"forecast": "No future values provided.",
130-
"log": log_stream.getvalue(),
131-
"plot_url": None
132-
})
131+
return jsonify({"forecast": "No future values provided."}), 400
133132

134133
try:
135-
values = [datetime.strptime(x.strip(), "%Y-%m-%d").toordinal()
136-
for x in future_x.split(",")]
137-
except ValueError:
138-
log_print("❌ Invalid date format. Use YYYY-MM-DD.")
139-
return jsonify({
140-
"forecast": "Invalid date format. Use YYYY-MM-DD.",
141-
"log": log_stream.getvalue(),
142-
"plot_url": None
143-
})
134+
values = future_x.split(",")
135+
numeric_vals, date_vals = [], []
136+
for x in values:
137+
try:
138+
date_vals.append(datetime.strptime(x.strip(), "%Y-%m-%d").toordinal())
139+
except:
140+
numeric_vals.append(float(x.strip()))
141+
values_parsed = np.array(date_vals if date_vals else numeric_vals).reshape(-1, 1)
142+
except Exception as e:
143+
log_print("Parsing future_x failed:", str(e))
144+
return jsonify({"forecast": "Invalid format for future values."}), 400
144145

145146
try:
146147
files = os.listdir(UPLOAD_FOLDER)
147-
if not files:
148-
raise FileNotFoundError("No uploaded file found.")
149-
latest_file = max(
148+
csv_file = max(
150149
[os.path.join(UPLOAD_FOLDER, f) for f in files if f.endswith(".csv")],
151150
key=os.path.getctime
152151
)
153-
df = pd.read_csv(latest_file)
154-
df.dropna(inplace=True)
152+
df = pd.read_csv(csv_file)
153+
df = df.dropna()
155154
df.columns = ['X', 'Y']
156-
df['X'] = pd.to_datetime(df['X'], errors='coerce')
157-
df.dropna(inplace=True)
158-
X = df['X'].map(pd.Timestamp.toordinal).values.reshape(-1, 1)
159-
y = df['Y'].values
155+
df['X_date'] = pd.to_datetime(df['X'], errors='coerce')
156+
use_dates = df['X_date'].notna().sum() >= len(df) // 2
157+
158+
if use_dates:
159+
df = df.dropna(subset=['X_date'])
160+
X = df['X_date'].map(pd.Timestamp.toordinal).values.reshape(-1, 1)
161+
else:
162+
X = df['X'].astype(float).values.reshape(-1, 1)
163+
y = df['Y'].astype(float).values
160164

161165
model = LinearRegression()
162166
model.fit(X, y)
163167

164-
predicted = model.predict(np.array(values).reshape(-1, 1))
168+
y_future = model.predict(values_parsed)
165169
result = {
166-
datetime.fromordinal(v).strftime("%Y-%m-%d"): round(p, 2)
167-
for v, p in zip(values, predicted)
170+
(datetime.fromordinal(int(x)) if use_dates else float(x)): round(p, 2)
171+
for x, p in zip(values_parsed.flatten(), y_future)
168172
}
169173

170-
log_print("🔮 Forecast complete.")
174+
# Generate second plot
175+
X_all = np.concatenate((X, values_parsed))
176+
x_min, x_max = X_all.min(), X_all.max()
177+
x_plot = np.linspace(x_min, x_max, 200).reshape(-1, 1)
178+
y_plot = model.predict(x_plot)
179+
180+
plt.figure()
181+
plt.scatter(X, y, label='Training Data', alpha=0.6)
182+
plt.plot(x_plot, y_plot, color='blue', label='Regression Line')
183+
plt.scatter(values_parsed, y_future, color='red', label='Forecast', marker='x')
184+
plt.legend()
185+
plt.xlabel('X')
186+
plt.ylabel('Y')
187+
plt.title('Forecast with Regression Line')
188+
plt.savefig(FORECAST_PLOT_PATH)
189+
plt.close()
190+
191+
# Write report
192+
with open(REPORT_PATH, 'w') as f:
193+
f.write("AI Forecast Report\\n")
194+
f.write("=================\\n\\n")
195+
f.write("Model: Linear Regression\\n")
196+
f.write(f"R² Score: {r2_score(y, model.predict(X)):.4f}\\n")
197+
f.write(f"MSE: {mean_squared_error(y, model.predict(X)):.4f}\\n")
198+
f.write("\\nForecast Results:\\n")
199+
for k, v in result.items():
200+
f.write(f"{k}: {v}\\n")
201+
171202
return jsonify({
172-
"forecast": result,
203+
"forecast": {str(k): v for k, v in result.items()},
173204
"log": log_stream.getvalue(),
174-
"plot_url": f"{BACKEND_BASE_URL}/plot.png"
205+
"plot_url": f"{BACKEND_BASE_URL}/plot.png",
206+
"forecast_plot_url": f"{BACKEND_BASE_URL}/forecast_plot.png"
175207
})
176208

177209
except Exception as e:
178-
log_print("Prediction failed:", str(e))
210+
log_print("Prediction failed:", str(e))
179211
return jsonify({
180212
"forecast": "Prediction failed.",
181213
"log": log_stream.getvalue(),
182214
"plot_url": None
183215
})
216+
217+
@app.route("/plot.png")
218+
def serve_plot():
219+
return send_file(PLOT_PATH, mimetype="image/png")
220+
221+
@app.route("/forecast_plot.png")
222+
def serve_forecast_plot():
223+
return send_file(FORECAST_PLOT_PATH, mimetype="image/png")
224+
225+
@app.route("/download-report", methods=["GET"])
226+
def download_report():
227+
return send_file(REPORT_PATH, as_attachment=True, download_name="report.txt")

0 commit comments

Comments
 (0)