Commit 6dcda22

Update app.py
1 parent c977f35 commit 6dcda22

File tree: 1 file changed (+56 −9 lines changed)

app.py

Lines changed: 56 additions & 9 deletions
@@ -1,5 +1,6 @@
 import streamlit as st
 import pandas as pd
+import plotly.graph_objects as go
 from Archives.data_processing import clean_data, split_data
 from Archives.model_training import get_models, train_and_evaluate_models
 from Archives.visualisation import plot_model_comparison, plot_confusion_matrices
@@ -15,13 +16,14 @@
     text_color = "#FFFFFF"
     button_color = "#444"
     success_text_color = "#FFFFFF"
+    plotly_template = "plotly_dark"
 else:
     bg_color = "#F8F4E1"
     text_color = "#000000"
     button_color = "#333333"
     success_text_color = "#000000"
+    plotly_template = "plotly_white"
 
-# Custom CSS
 st.markdown(f"""
     <style>
     .stApp {{
@@ -43,19 +45,20 @@
     div[data-testid="stAlertContainer"] p {{
         color: {success_text_color} !important;
     }}
+    h2, h3, h4, h5, h6, .stMarkdown {{
+        color: {text_color} !important;
+    }}
+    .css-10trblm, .stMetric label, .stMetric div {{
+        color: {text_color} !important;
+    }}
     </style>
 """, unsafe_allow_html=True)
 
-# Sidebar title
 st.sidebar.title("📁 Upload and Train")
 
-# File upload
 uploaded_file = st.sidebar.file_uploader("Upload a preprocessed dataset (CSV)", type=["csv"])
-
-# Default path
 default_path = "Loan-Creaditworthiness-classification-main/data/Preprocessed/final.csv"
 
-# Load dataset
 if uploaded_file:
     df = pd.read_csv(uploaded_file)
     st.sidebar.success("✅ Custom dataset loaded successfully!")
@@ -67,19 +70,30 @@
         st.sidebar.error(f"Failed to load default dataset: {e}")
         df = None
 
-# Training and output
-if df is not None and st.sidebar.button("🚀 Train Models"):
+use_best_model = st.sidebar.checkbox("🚀 Use Best Model Automatically")
+
+if df is not None and st.sidebar.button("⚙️ Train Models"):
     with st.spinner("Training in progress... Please wait ⏳"):
         df_cleaned = clean_data(df)
         X_train, X_test, y_train, y_test = split_data(df_cleaned, target_column='high_risk_applicant')
         models = get_models()
        results, predictions = train_and_evaluate_models(models, X_train, X_test, y_train, y_test)
+
         figs = plot_model_comparison(results)
         confusion_figs = plot_confusion_matrices(predictions)
 
+        model_scores = {name: score['Accuracy'] for name, score in results.items()}
+        model_dict = {name: predictions[name][1] for name in predictions}
+        best_model_name = max(model_scores, key=model_scores.get)
+        best_model = model_dict[best_model_name]
+
+        selected_model_name = best_model_name
+        selected_model = best_model
+        st.success(f"✅ Best model auto-selected: **{best_model_name}** (Accuracy: **{model_scores[best_model_name]:.2f}**)")
+
     st.success("🎉 Training completed!")
 
-    st.subheader("📊 Model Performance")
+    st.subheader("📊 Model Performance (Bar Graphs)")
     for i in range(0, len(figs), 2):
         cols = st.columns(2)
         for j, fig in enumerate(figs[i:i+2]):
@@ -90,3 +104,36 @@
         cols = st.columns(2)
         for j, fig in enumerate(confusion_figs[i:i+2]):
             cols[j].pyplot(fig)
+
+    st.subheader("📊 Visual Comparison of All Models")
+    metrics = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
+    data = {metric: [results[model].get(metric, 0) for model in results] for metric in metrics}
+    model_names = list(results.keys())
+
+    fig = go.Figure()
+    for metric in metrics:
+        fig.add_trace(go.Bar(
+            y=model_names,
+            x=data[metric],
+            name=metric,
+            orientation='h'
+        ))
+
+    fig.update_layout(
+        barmode='group',
+        title="Model Performance Comparison",
+        xaxis_title="Score",
+        yaxis_title="Models",
+        template=plotly_template,
+        height=500
+    )
+
+    st.plotly_chart(fig, use_container_width=True)
+
+    st.markdown(f"<h3 style='color:{text_color};'>📈 Real-Time Performance Comparison</h3>", unsafe_allow_html=True)
+    for name, score_dict in results.items():
+        cols = st.columns(4)
+        cols[0].markdown(f"<div style='color:{text_color}; font-weight:bold;'>{name}</div>", unsafe_allow_html=True)
+        cols[1].markdown(f"<div style='color:{text_color};'>✅ Accuracy<br><span style='font-size:28px;'><b>{score_dict['Accuracy']:.3f}</b></span></div>", unsafe_allow_html=True)
+        cols[2].markdown(f"<div style='color:{text_color};'>🎯 Precision<br><span style='font-size:28px;'><b>{score_dict['Precision']:.3f}</b></span></div>", unsafe_allow_html=True)
+        cols[3].markdown(f"<div style='color:{text_color};'>📊 Recall<br><span style='font-size:28px;'><b>{score_dict['Recall']:.3f}</b></span></div>", unsafe_allow_html=True)
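Note on the two behavioural additions in this commit: the accuracy-based best-model pick and the grouped horizontal Plotly comparison chart. The following is a minimal standalone sketch of both that runs outside Streamlit. The model names and scores in `results` are made-up sample data (the app gets them from train_and_evaluate_models), and the template is hard-coded here rather than driven by the theme toggle.

    import plotly.graph_objects as go

    # Hypothetical sample scores; in app.py these come from train_and_evaluate_models().
    results = {
        "Logistic Regression": {"Accuracy": 0.86, "Precision": 0.84, "Recall": 0.81, "F1 Score": 0.82},
        "Random Forest": {"Accuracy": 0.91, "Precision": 0.90, "Recall": 0.88, "F1 Score": 0.89},
        "KNN": {"Accuracy": 0.83, "Precision": 0.80, "Recall": 0.79, "F1 Score": 0.79},
    }

    # Same selection rule as the commit: the model with the highest accuracy wins.
    model_scores = {name: score["Accuracy"] for name, score in results.items()}
    best_model_name = max(model_scores, key=model_scores.get)
    print(f"Best model: {best_model_name} (Accuracy: {model_scores[best_model_name]:.2f})")

    # Grouped horizontal bar chart: one Bar trace per metric, models on the
    # y-axis; barmode='group' draws the four bars for each model side by side.
    metrics = ["Accuracy", "Precision", "Recall", "F1 Score"]
    model_names = list(results.keys())

    fig = go.Figure()
    for metric in metrics:
        fig.add_trace(go.Bar(
            y=model_names,
            x=[results[m].get(metric, 0) for m in model_names],
            name=metric,
            orientation="h",
        ))
    fig.update_layout(
        barmode="group",
        title="Model Performance Comparison",
        xaxis_title="Score",
        yaxis_title="Models",
        template="plotly_dark",  # app.py selects plotly_dark/plotly_white from the theme toggle
        height=500,
    )
    fig.show()  # opens in a browser; the app renders via st.plotly_chart instead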
