|
 import streamlit as st
 import pandas as pd
+import plotly.graph_objects as go
 from Archives.data_processing import clean_data, split_data
 from Archives.model_training import get_models, train_and_evaluate_models
 from Archives.visualisation import plot_model_comparison, plot_confusion_matrices
|
     text_color = "#FFFFFF"
     button_color = "#444"
     success_text_color = "#FFFFFF"
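+    # Pick a matching Plotly template so charts follow the active theme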
+    plotly_template = "plotly_dark"
 else:
     bg_color = "#F8F4E1"
     text_color = "#000000"
     button_color = "#333333"
     success_text_color = "#000000"
+    plotly_template = "plotly_white"
 
|
-# Custom CSS
 st.markdown(f"""
     <style>
     .stApp {{
|
     div[data-testid="stAlertContainer"] p {{
         color: {success_text_color} !important;
     }}
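+    /* Force headings, markdown, and metric widgets to follow the theme's text color */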
+    h2, h3, h4, h5, h6, .stMarkdown {{
+        color: {text_color} !important;
+    }}
+    .css-10trblm, .stMetric label, .stMetric div {{
+        color: {text_color} !important;
+    }}
     </style>
 """, unsafe_allow_html=True)
 
|
-# Sidebar title
 st.sidebar.title("📁 Upload and Train")
 
|
-# File upload
 uploaded_file = st.sidebar.file_uploader("Upload a preprocessed dataset (CSV)", type=["csv"])
-
-# Default path
 default_path = "Loan-Creaditworthiness-classification-main/data/Preprocessed/final.csv"
 
|
-# Load dataset
 if uploaded_file:
     df = pd.read_csv(uploaded_file)
     st.sidebar.success("✅ Custom dataset loaded successfully!")
|
         st.sidebar.error(f"Failed to load default dataset: {e}")
         df = None
 
|
-# Training and output
-if df is not None and st.sidebar.button("🚀 Train Models"):
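+# Sidebar toggle: when checked, the top-accuracy model is promoted and announced after training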
+use_best_model = st.sidebar.checkbox("🚀 Use Best Model Automatically")
+
+if df is not None and st.sidebar.button("⚙️ Train Models"):
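+    # Full pipeline under one spinner: clean, split, fit, and score every candidate model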
     with st.spinner("Training in progress... Please wait ⏳"):
         df_cleaned = clean_data(df)
         X_train, X_test, y_train, y_test = split_data(df_cleaned, target_column='high_risk_applicant')
         models = get_models()
         results, predictions = train_and_evaluate_models(models, X_train, X_test, y_train, y_test)
+
         figs = plot_model_comparison(results)
         confusion_figs = plot_confusion_matrices(predictions)
 
|
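+    # Rank models by test accuracy; predictions[name][1] is assumed to hold each fitted estimator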
+    model_scores = {name: score['Accuracy'] for name, score in results.items()}
+    model_dict = {name: predictions[name][1] for name in predictions}
+    best_model_name = max(model_scores, key=model_scores.get)
+    best_model = model_dict[best_model_name]
+
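+    # Default to the top scorer; the sidebar checkbox gates only the announcement banner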
+    selected_model_name = best_model_name
+    selected_model = best_model
+    if use_best_model:
+        st.success(f"✅ Best model auto-selected: **{best_model_name}** (Accuracy: **{model_scores[best_model_name]:.2f}**)")
+
     st.success("🎉 Training completed!")
 
|
-    st.subheader("📊 Model Performance")
+    st.subheader("📊 Model Performance (Bar Graphs)")
     for i in range(0, len(figs), 2):
         cols = st.columns(2)
         for j, fig in enumerate(figs[i:i+2]):
|
         cols = st.columns(2)
         for j, fig in enumerate(confusion_figs[i:i+2]):
             cols[j].pyplot(fig)
+
+    st.subheader("📊 Visual Comparison of All Models")
+    metrics = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
+    data = {metric: [results[model].get(metric, 0) for model in results] for metric in metrics}
+    model_names = list(results.keys())
+
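+    # Grouped horizontal bars: one trace per metric, one row per model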
+    fig = go.Figure()
+    for metric in metrics:
+        fig.add_trace(go.Bar(
+            y=model_names,
+            x=data[metric],
+            name=metric,
+            orientation='h'
+        ))
+
+    fig.update_layout(
+        barmode='group',
+        title="Model Performance Comparison",
+        xaxis_title="Score",
+        yaxis_title="Models",
+        template=plotly_template,
+        height=500
+    )
+
+    st.plotly_chart(fig, use_container_width=True)
+
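+    # Per-model metric cards rendered as raw HTML so the theme's text_color is applied per cell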
+    st.markdown(f"<h3 style='color:{text_color};'>📈 Real-Time Performance Comparison</h3>", unsafe_allow_html=True)
+    for name, score_dict in results.items():
+        cols = st.columns(4)
+        cols[0].markdown(f"<div style='color:{text_color}; font-weight:bold;'>{name}</div>", unsafe_allow_html=True)
+        cols[1].markdown(f"<div style='color:{text_color};'>✅ Accuracy<br><span style='font-size:28px;'><b>{score_dict['Accuracy']:.3f}</b></span></div>", unsafe_allow_html=True)
+        cols[2].markdown(f"<div style='color:{text_color};'>🎯 Precision<br><span style='font-size:28px;'><b>{score_dict['Precision']:.3f}</b></span></div>", unsafe_allow_html=True)
+        cols[3].markdown(f"<div style='color:{text_color};'>📊 Recall<br><span style='font-size:28px;'><b>{score_dict['Recall']:.3f}</b></span></div>", unsafe_allow_html=True)