
Commit 99697a1

feat: plotting and SVI convergence analysis
1 parent fc2e160 commit 99697a1

File tree

1 file changed (+249, -26)


src/bayesian_regression.py

Lines changed: 249 additions & 26 deletions
@@ -45,13 +45,13 @@
 # ---------------------------------------------------------- PLOTTING SETTINGS ------------------------------------------------------ #
 plt.style.use("seaborn-v0_8-whitegrid")                                                       #
 plt.rcParams.update({                                                                         #
-    "font.size"       : 14,   # Base font size for all text in the plot.                      #
-    "axes.titlesize"  : 16,   # Title size for axes.                                          #
-    "axes.labelsize"  : 16,   # Axis labels size.                                             #
-    "xtick.labelsize" : 16,   # X-axis tick labels size.                                      #
-    "ytick.labelsize" : 16,   # Y-axis tick labels size.                                      #
-    "legend.fontsize" : 18,   # Legend font size for all text in the legend.                  #
-    "figure.titlesize": 20,   # Overall figure title size for all text in the figure.         #
+    "font.size"       : 16,   # Base font size for all text in the plot.                      #
+    "axes.titlesize"  : 18,   # Title size for axes.                                          #
+    "axes.labelsize"  : 18,   # Axis labels size.                                             #
+    "xtick.labelsize" : 18,   # X-axis tick labels size.                                      #
+    "ytick.labelsize" : 18,   # Y-axis tick labels size.                                      #
+    "legend.fontsize" : 20,   # Legend font size for all text in the legend.                  #
+    "figure.titlesize": 24,   # Overall figure title size for all text in the figure.         #
 })                                                                                            #
 sns.set_theme(style="whitegrid", context="paper")   # Set seaborn theme for additional aesthetics and context. #
 plt.rcParams["figure.figsize"] = (6, 4)             # Set default figure size for all plots to 6x4. #
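[Note] One ordering caveat in this hunk: sns.set_theme() re-applies seaborn's own rcParams (its context="paper" setting scales font sizes), so calling it after the plt.rcParams.update(...) block can silently override the sizes chosen above. A minimal sketch of the safer ordering, reusing values from this commit:

    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.set_theme(style="whitegrid", context="paper")  # theme first: seaborn resets font rcParams here
    plt.rcParams.update({                              # ...so explicit overrides applied after it stick
        "font.size"       : 16,
        "figure.titlesize": 24,
    })
    plt.rcParams["figure.figsize"] = (6, 4)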
@@ -797,18 +797,19 @@ def visualize_prior_vs_posterior(self, output_dir, experiment_name, num_samples=
         # For 'flat' prior or any other prior not in prior_metrics, use a default value
         prior_std = 1.0  # Default value
         print(f"Using default prior std={prior_std} for {self.prior_type} prior")
-        # Calculate errors to identify good and poor examples
+        # Calculate errors to identify examples
         if self.prior_type in prior_test:
             prior_mean = prior_test[self.prior_type]
             prior_errors = np.abs(prior_mean - y_test)
             posterior_errors = np.abs(y_pred_mean - y_test)
             # Calculate improvement (positive means posterior is better)
             improvements = prior_errors - posterior_errors
-            # Find one good example (high improvement) and one poor example (low improvement)
+            # Find a good example (high improvement)
             good_idx = np.argmax(improvements)
-            poor_idx = np.argmin(improvements)
+            # Look for cases where posterior uncertainty is high
+            worse_idx = np.argmax(y_pred_std)
             # Use these two specific indices
-            indices = [good_idx, poor_idx]
+            indices = [good_idx, worse_idx]
         else:
             # If we don't have prior data, just pick first and last examples
             indices = [0, len(y_test)-1]
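[Note] The new selection logic boils down to two argmax calls: the "good" panel is the test point where the posterior most improves on the prior, and the second panel is now the point with the widest posterior spread rather than the worst regression. A toy sketch (array values invented for illustration; names mirror the diff):

    import numpy as np

    prior_errors     = np.array([4.0, 9.0, 2.0, 7.0])   # |prior mean - truth|
    posterior_errors = np.array([1.0, 8.0, 2.5, 3.0])   # |posterior mean - truth|
    y_pred_std       = np.array([0.5, 2.0, 0.4, 1.1])   # posterior predictive std

    improvements = prior_errors - posterior_errors      # positive => posterior better
    good_idx  = np.argmax(improvements)                 # largest gain over the prior
    worse_idx = np.argmax(y_pred_std)                   # most uncertain prediction
    print(good_idx, worse_idx)                          # -> 3 1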
@@ -847,7 +848,7 @@ def visualize_prior_vs_posterior(self, output_dir, experiment_name, num_samples=
         if i == 0:
             ax.set_title(f'Good Example: Prior vs Posterior (True Angle: {true_angle:.2f}°)')
         else:
-            ax.set_title(f'Poor Example: Prior vs Posterior (True Angle: {true_angle:.2f}°)')
+            ax.set_title(f'High Uncertainty Example: Prior vs Posterior (True Angle: {true_angle:.2f}°)')
         ax.grid(True, alpha=0.3)
         ax.legend()
         # Annotate statistics
@@ -907,12 +908,12 @@ def plot_posterior_predictive(self, output_dir, experiment_name):
                          label='90% Credible Interval')
         plt.fill_between(range(len(y_test)), y_pred_25, y_pred_75, alpha=0.5, color='blue',
                          label='50% Credible Interval')
-        plt.plot(range(len(y_test)), y_pred_50, 'orange', linewidth=3, label='Median Prediction')
-        plt.plot(range(len(y_test)), y_test_sorted, 'ro', makersize=10, label='True Angles')
-        plt.xlabel('Test Point Index (sorted by true angle)')
-        plt.ylabel('Angle (degrees)')
-        plt.title('Posterior Predictive Distribution')
-        plt.legend()
+        plt.plot(range(len(y_test)), y_pred_50, 'orange', linewidth=2.5, label='Median Prediction')
+        plt.plot(range(len(y_test)), y_test_sorted, 'ro', markersize=8, label='True Angles')
+        plt.xlabel('Test Point Index (sorted by true angle)', fontsize=16)
+        plt.ylabel('Angle (degrees)', fontsize=16)
+        plt.title('Posterior Predictive Distribution', fontsize=20)
+        plt.legend(fontsize=16)
         plt.grid(True, alpha=0.3)
         plt.tight_layout()
         plt.savefig(os.path.join(vis_dir, "posterior_predictive.png"), dpi=300, bbox_inches='tight')
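[Note] The credible bands come from percentiles of posterior predictive draws, stacked via two fill_between calls (a wide band under a narrower 25-75% band). A self-contained sketch with synthetic draws standing in for the model's samples, assuming the 90% band uses the 5th/95th percentiles, consistent with its label:

    import numpy as np
    import matplotlib.pyplot as plt

    rng = np.random.default_rng(0)
    samples = rng.normal(loc=np.linspace(-60, 60, 50), scale=5.0,
                         size=(500, 50))                # toy draws: (num_draws, num_test_points)

    y_pred_05, y_pred_25, y_pred_50, y_pred_75, y_pred_95 = np.percentile(
        samples, [5, 25, 50, 75, 95], axis=0)

    x = range(samples.shape[1])
    plt.fill_between(x, y_pred_05, y_pred_95, alpha=0.25, color='blue',
                     label='90% Credible Interval')
    plt.fill_between(x, y_pred_25, y_pred_75, alpha=0.5, color='blue',
                     label='50% Credible Interval')
    plt.plot(x, y_pred_50, 'orange', linewidth=2.5, label='Median Prediction')
    plt.legend()
    plt.show()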
@@ -1313,6 +1314,225 @@ def _extract_features_for_sample(self, phasor1, phasor2, rssi1, rssi2, D, W, wav
         features = np.append(features, D)
         return features

+def analyze_svi_convergence_metrics(models_dict, results_dir):
+    """
+    Analyze SVI convergence metrics across different models/priors and generate a tabular report.
+
+    Parameters:
+    - models_dict [dict] : Dictionary of trained BayesianAoARegressor models
+    - results_dir [str]  : Directory to save the convergence analysis results
+
+    Returns:
+    - convergence_metrics [dict]: Dictionary with convergence metrics by model
+    """
+    print("\n=== ANALYZING SVI CONVERGENCE ACROSS PRIORS ===")
+
+    # Output directory
+    os.makedirs(results_dir, exist_ok=True)
+
+    # Store convergence metrics
+    convergence_metrics = {}
+
+    # Define convergence criteria
+    def detect_convergence(losses, window_size=100, threshold_pct=0.5):
+        """Detect when loss has converged (stabilized)"""
+        if len(losses) <= window_size:
+            return len(losses) - 1  # Not enough data, return last index
+
+        # Calculate rolling standard deviation as percentage of mean
+        rolling_std_pct = []
+        for i in range(window_size, len(losses)):
+            window = losses[i-window_size:i]
+            std_pct = (np.std(window) / abs(np.mean(window))) * 100
+            rolling_std_pct.append(std_pct)
+
+        # Find first point where std_pct is below threshold
+        for i, std_pct in enumerate(rolling_std_pct):
+            if std_pct < threshold_pct:
+                return i + window_size  # Add window_size to get the actual index
+
+        return len(losses) - 1  # No convergence detected, return last index
+
+    # Define function to fit exponential decay
+    def fit_convergence_rate(iterations, losses):
+        """Fit exponential decay model to estimate convergence rate"""
+        from scipy.optimize import curve_fit
+
+        def exp_decay(x, a, b, c):
+            return a * np.exp(-b * x) + c
+
+        try:
+            # Normalize iterations to [0, 1] for numerical stability
+            x_norm = iterations / np.max(iterations)
+            params, _ = curve_fit(exp_decay, x_norm, losses,
+                                  p0=[losses[0]-losses[-1], 5, losses[-1]],
+                                  bounds=([0, 0, -np.inf], [np.inf, 100, np.inf]))
+
+            convergence_rate = params[1]
+            return convergence_rate
+        except Exception as e:
+            print(f"Could not fit convergence rate: {e}")
+            return None
+
+    # Analyze each model
+    for model_name, model in models_dict.items():
+        if model.train_summary is None or 'losses' not in model.train_summary:
+            print(f"Skipping {model_name} - no loss data available")
+            continue
+
+        losses = np.array(model.train_summary['losses'])
+        iterations = np.arange(1, len(losses) + 1)
+
+        # Calculate key convergence metrics
+
+        # 1. Iterations to convergence (using our detection function)
+        iter_to_converge = detect_convergence(losses)
+
+        # 2. Loss reduction (initial vs. final)
+        initial_loss = losses[0]
+        final_loss = losses[-1]
+        abs_reduction = initial_loss - final_loss
+        pct_reduction = (abs_reduction / abs(initial_loss)) * 100 if initial_loss != 0 else 0
+
+        # 3. Convergence rate (from exponential fit)
+        conv_rate = fit_convergence_rate(iterations, losses)
+
+        # 4. Stability of final 10% of training
+        final_window = losses[-int(0.1*len(losses)):]
+        final_stability = (np.std(final_window) / abs(np.mean(final_window))) * 100
+
+        # 5. Early-stage (0-20%) vs late-stage (80-100%) loss reduction rate
+        early_loss_drop = losses[0] - losses[int(0.2*len(losses))]
+        early_rate = early_loss_drop / (0.2*len(losses))
+
+        late_window_start = int(0.8*len(losses))
+        late_loss_drop = losses[late_window_start] - losses[-1]
+        late_rate = late_loss_drop / (len(losses) - late_window_start)
+
+        # Store metrics
+        convergence_metrics[model_name] = {
+            'prior_type': model.prior_type,
+            'feature_mode': model.feature_mode,
+            'iterations': len(losses),
+            'iter_to_converge': iter_to_converge,
+            'pct_to_converge': (iter_to_converge / len(losses)) * 100,
+            'initial_loss': initial_loss,
+            'final_loss': final_loss,
+            'abs_reduction': abs_reduction,
+            'pct_reduction': pct_reduction,
+            'convergence_rate': conv_rate,
+            'final_stability': final_stability,
+            'early_reduction_rate': early_rate,
+            'late_reduction_rate': late_rate
+        }
+
+    # Create a tabular summary for paper
+    columns = ['prior_type', 'feature_mode', 'iter_to_converge', 'pct_to_converge',
+               'pct_reduction', 'convergence_rate', 'final_stability']
+
+    # 1. Group by prior type
+    prior_groups = {}
+    for name, metrics in convergence_metrics.items():
+        prior = metrics['prior_type']
+        if prior not in prior_groups:
+            prior_groups[prior] = []
+        prior_groups[prior].append((name, metrics))
+
+    # Write summary tables
+    with open(os.path.join(results_dir, "svi_convergence_summary.txt"), 'w') as f:
+        f.write("SVI CONVERGENCE ANALYSIS BY PRIOR TYPE\n")
+        f.write("=====================================\n\n")
+
+        # Table 1: Iterations to convergence by prior type (averaged across feature modes)
+        f.write("TABLE 1: CONVERGENCE METRICS BY PRIOR TYPE\n")
+        f.write("Prior Type | Iterations to Converge | % of Total | Convergence Rate | Loss Reduction (%)\n")
+        f.write("----------|----------------------|-----------|-----------------|------------------\n")
+
+        for prior in sorted(prior_groups.keys()):
+            metrics_list = [m for _, m in prior_groups[prior]]
+            avg_iter = np.mean([m['iter_to_converge'] for m in metrics_list])
+            avg_pct = np.mean([m['pct_to_converge'] for m in metrics_list])
+            avg_rate = np.mean([m['convergence_rate'] for m in metrics_list if m['convergence_rate'] is not None])
+            avg_reduction = np.mean([m['pct_reduction'] for m in metrics_list])
+
+            f.write(f"{prior.upper():10} | {avg_iter:.1f} | {avg_pct:.1f}% | {avg_rate:.3f} | {avg_reduction:.1f}%\n")
+
+        f.write("\n\n")
+
+        # Table 2: Detailed metrics for each configuration
+        f.write("TABLE 2: DETAILED CONVERGENCE METRICS BY MODEL CONFIGURATION\n")
+        f.write("Model | Prior | Features | Iter to Conv | % of Total | Conv Rate | Reduction | Stability\n")
+        f.write("------|-------|----------|-------------|-----------|----------|-----------|----------\n")
+
+        for name, metrics in sorted(convergence_metrics.items()):
+            prior = metrics['prior_type'].upper()
+            features = metrics['feature_mode']
+            iter_conv = metrics['iter_to_converge']
+            pct_conv = metrics['pct_to_converge']
+            conv_rate = metrics['convergence_rate'] if metrics['convergence_rate'] is not None else "N/A"
+            reduction = metrics['pct_reduction']
+            stability = metrics['final_stability']
+
+            if isinstance(conv_rate, float):
+                conv_rate_str = f"{conv_rate:.3f}"
+            else:
+                conv_rate_str = "N/A"
+
+            f.write(f"{name:6} | {prior:5} | {features:8} | {iter_conv:11.0f} | {pct_conv:9.1f}% | "
+                    f"{conv_rate_str:8} | {reduction:8.1f}% | {stability:8.2f}%\n")
+
+        f.write("\n\n")
+
+        # Interpretation and recommendations
+        f.write("INTERPRETATION OF METRICS:\n")
+        f.write("-------------------------\n")
+        f.write("- Iterations to Converge: Lower is better, faster convergence\n")
+        f.write("- % of Total: Percentage of total training iterations needed to converge\n")
+        f.write("- Convergence Rate: Higher is better, faster exponential decay of loss\n")
+        f.write("- Loss Reduction: Higher is better, more improvement from initial loss\n")
+        f.write("- Stability: Lower is better, less variation in final 10% of training\n\n")
+
+        # Summary of findings
+        f.write("SUMMARY OF FINDINGS:\n")
+        f.write("------------------\n")
+
+        # Find fastest converging prior
+        avg_conv_by_prior = {}
+        for prior, models in prior_groups.items():
+            avg_conv_by_prior[prior] = np.mean([m['pct_to_converge'] for _, m in models])
+
+        fastest_prior = min(avg_conv_by_prior.items(), key=lambda x: x[1])[0]
+        slowest_prior = max(avg_conv_by_prior.items(), key=lambda x: x[1])[0]
+
+        f.write(f"1. Fastest converging prior: {fastest_prior.upper()} "
+                f"({avg_conv_by_prior[fastest_prior]:.1f}% of iterations)\n")
+        f.write(f"2. Slowest converging prior: {slowest_prior.upper()} "
+                f"({avg_conv_by_prior[slowest_prior]:.1f}% of iterations)\n")
+
+        # Prior with highest convergence rate
+        avg_rate_by_prior = {}
+        for prior, models in prior_groups.items():
+            rates = [m['convergence_rate'] for _, m in models if m['convergence_rate'] is not None]
+            if rates:
+                avg_rate_by_prior[prior] = np.mean(rates)
+
+        if avg_rate_by_prior:
+            highest_rate_prior = max(avg_rate_by_prior.items(), key=lambda x: x[1])[0]
+            f.write(f"3. Prior with highest convergence rate: {highest_rate_prior.upper()} "
+                    f"(rate = {avg_rate_by_prior[highest_rate_prior]:.3f})\n")
+
+        # Prior with most stable convergence
+        avg_stability_by_prior = {}
+        for prior, models in prior_groups.items():
+            avg_stability_by_prior[prior] = np.mean([m['final_stability'] for _, m in models])
+
+        most_stable_prior = min(avg_stability_by_prior.items(), key=lambda x: x[1])[0]
+        f.write(f"4. Most stable convergence: {most_stable_prior.upper()} "
+                f"(final stability = {avg_stability_by_prior[most_stable_prior]:.2f}%)\n")
+
+    print(f"SVI convergence analysis completed and saved to: {os.path.join(results_dir, 'svi_convergence_summary.txt')}")
+    return convergence_metrics
+
 def train_bayesian_models(data_manager, results_dir, num_epochs=10000):
     """
     Train multiple Bayesian AoA regression models with different priors and feature sets.
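[Note] detect_convergence flags the first iteration where a 100-step rolling window's standard deviation drops below 0.5% of its mean, and fit_convergence_rate reads the decay constant b out of an a*exp(-b*x)+c fit over normalized iterations. A quick self-check on a synthetic ELBO-style trace (all numbers invented):

    import numpy as np
    from scipy.optimize import curve_fit

    rng = np.random.default_rng(1)
    iters = np.arange(2000)
    losses = 500.0 * np.exp(-iters / 150.0) + 50.0 + rng.normal(0, 0.1, iters.size)

    # Rolling-std criterion, as in detect_convergence
    window_size, threshold_pct = 100, 0.5
    conv_idx = len(losses) - 1
    for i in range(window_size, len(losses)):
        window = losses[i - window_size:i]
        if (np.std(window) / abs(np.mean(window))) * 100 < threshold_pct:
            conv_idx = i
            break
    print(f"converged around iteration {conv_idx} of {len(losses)}")

    # Exponential fit, as in fit_convergence_rate; b is the reported rate
    def exp_decay(x, a, b, c):
        return a * np.exp(-b * x) + c

    x_norm = iters / np.max(iters)                      # normalize to [0, 1]
    params, _ = curve_fit(exp_decay, x_norm, losses,
                          p0=[losses[0] - losses[-1], 5, losses[-1]],
                          bounds=([0, 0, -np.inf], [np.inf, 100, np.inf]))
    print(f"convergence rate b = {params[1]:.3f}")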
@@ -1347,6 +1567,9 @@ def train_bayesian_models(data_manager, results_dir, num_epochs=10000):
     # Dictionary to store results
     models = {}
     results = {}
+    # Create convergence analysis directory
+    convergence_dir = os.path.join(results_dir, "svi_convergence_analysis")
+    os.makedirs(convergence_dir, exist_ok=True)
     # Train models for each configuration
     for config in configs:
         print(f"\n--- Training Bayesian model with {config['prior']} prior, {config['features']} features ---")
@@ -1363,18 +1586,18 @@ def train_bayesian_models(data_manager, results_dir, num_epochs=10000):
         results[config['name']] = train_results
         # Visualize results
         model.visualize_results(results_dir, config['name'])
-        # Generate new visualizations
         model.render_model_and_guide(results_dir, config['name'])
         model.plot_posterior_predictive(results_dir, config['name'])
         model.plot_uncertainty_calibration(results_dir, config['name'])
-        # Visualize prior vs posterior
         models[config['name']].visualize_prior_vs_posterior(results_dir, config['name'])
         model.visualize_weight_distributions(results_dir, config['name'])
         model.analyze_with_posterior_weights(data_manager, results_dir, config['name'])
         print(f"Completed training {config['name']}")
+    # Perform SVI convergence analysis
+    convergence_metrics = analyze_svi_convergence_metrics(models, convergence_dir)
     # Create comparison visualizations
     compare_bayesian_models(models, results, results_dir)
-    return {"models": models, "results": results}
+    return {"models": models, "results": results, "convergence": convergence_metrics}

 def compare_bayesian_models(models, results, output_dir):
     """
@@ -1461,11 +1684,11 @@ def compare_bayesian_models(models, results, output_dir):
     rects1 = ax.bar(x - width/2, sorted_maes, width, label='MAE')
     rects2 = ax.bar(x + width/2, sorted_rmses, width, label='RMSE')
     # Add labels and title
-    ax.set_ylabel('Error (degrees)')
-    ax.set_title('Bayesian AoA Model Performance Comparison')
+    ax.set_ylabel('Error (degrees)', fontsize=16)
+    ax.set_title('Bayesian AoA Model Performance Comparison', fontsize=20)
     ax.set_xticks(x)
-    ax.set_xticklabels(display_names, rotation=45, ha='right')
-    ax.legend()
+    ax.set_xticklabels(display_names, rotation=45, ha='right', fontsize=14)
+    ax.legend(fontsize=14)
     # Add value labels
     def autolabel(rects):
         for rect in rects:
@@ -1474,7 +1697,7 @@ def autolabel(rects):
                         xy=(rect.get_x() + rect.get_width()/2, height),
                         xytext=(0, 3),  # 3 points vertical offset
                         textcoords="offset points",
-                        ha='center', va='bottom')
+                        ha='center', va='bottom', fontsize=12)
     autolabel(rects1)
     autolabel(rects2)
     fig.tight_layout()
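[Note] analyze_svi_convergence_metrics only touches three attributes of each model (prior_type, feature_mode, train_summary['losses']), so it can be smoke-tested without training anything. A standalone sketch with a stub class (FakeModel is hypothetical, not part of the repo; assumes src/bayesian_regression.py is importable as shown):

    import numpy as np
    from bayesian_regression import analyze_svi_convergence_metrics

    class FakeModel:
        """Stub exposing only the attributes the analysis reads."""
        def __init__(self, prior_type, feature_mode, losses):
            self.prior_type = prior_type
            self.feature_mode = feature_mode
            self.train_summary = {'losses': losses}

    rng = np.random.default_rng(2)
    t = np.arange(5000)
    models = {
        "m1": FakeModel("normal",  "full",    800*np.exp(-t/400) + 90 + rng.normal(0, 0.2, t.size)),
        "m2": FakeModel("laplace", "reduced", 800*np.exp(-t/900) + 95 + rng.normal(0, 0.2, t.size)),
    }
    metrics = analyze_svi_convergence_metrics(models, "tmp_convergence")
    print(metrics["m1"]["pct_to_converge"], metrics["m2"]["pct_to_converge"])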
