4545# ---------------------------------------------------------- PLOTTING SETTINGS ------------------------------------------------------ #
4646plt .style .use ("seaborn-v0_8-whitegrid" ) #
4747plt .rcParams .update ({ #
48- "font.size" : 14 , # Base font size for all text in the plot. #
49- "axes.titlesize" : 16 , # Title size for axes. #
50- "axes.labelsize" : 16 , # Axis labels size. #
51- "xtick.labelsize" : 16 , # X-axis tick labels size. #
52- "ytick.labelsize" : 16 , # Y-axis tick labels size. #
53- "legend.fontsize" : 18 , # Legend font size for all text in the legend. #
54- "figure.titlesize" : 20 , # Overall figure title size for all text in the figure. #
48+ "font.size" : 16 , # Base font size for all text in the plot. #
49+ "axes.titlesize" : 18 , # Title size for axes. #
50+ "axes.labelsize" : 18 , # Axis labels size. #
51+ "xtick.labelsize" : 18 , # X-axis tick labels size. #
52+ "ytick.labelsize" : 18 , # Y-axis tick labels size. #
53+ "legend.fontsize" : 20 , # Legend font size for all text in the legend. #
54+ "figure.titlesize" : 24 , # Overall figure title size for all text in the figure. #
5555}) #
5656sns .set_theme (style = "whitegrid" , context = "paper" ) # Set seaborn theme for additional aesthetics and context. #
5757plt .rcParams ["figure.figsize" ] = (6 , 4 ) # Set default figure size for all plots to 6x4. #
@@ -797,18 +797,19 @@ def visualize_prior_vs_posterior(self, output_dir, experiment_name, num_samples=
797797 # For 'flat' prior or any other prior not in prior_metrics, use a default value
798798 prior_std = 1.0 # Default value
799799 print (f"Using default prior std={ prior_std } for { self .prior_type } prior" )
800- # Calculate errors to identify good and poor examples
800+ # Calculate errors to identify examples
801801 if self .prior_type in prior_test :
802802 prior_mean = prior_test [self .prior_type ]
803803 prior_errors = np .abs (prior_mean - y_test )
804804 posterior_errors = np .abs (y_pred_mean - y_test )
805805 # Calculate improvement (positive means posterior is better)
806806 improvements = prior_errors - posterior_errors
807- # Find one good example (high improvement) and one poor example (low improvement)
807+ # Find a good example (high improvement)
808808 good_idx = np .argmax (improvements )
809- poor_idx = np .argmin (improvements )
809+ # Look for cases where posterior uncertainty is high
810+ worse_idx = np .argmax (y_pred_std )
810811 # Use these two specific indices
811- indices = [good_idx , poor_idx ]
812+ indices = [good_idx , worse_idx ]
812813 else :
813814 # If we don't have prior data, just pick first and last examples
814815 indices = [0 , len (y_test )- 1 ]
@@ -847,7 +848,7 @@ def visualize_prior_vs_posterior(self, output_dir, experiment_name, num_samples=
847848 if i == 0 :
848849 ax .set_title (f'Good Example: Prior vs Posterior (True Angle: { true_angle :.2f} °)' )
849850 else :
850- ax .set_title (f'Poor Example: Prior vs Posterior (True Angle: { true_angle :.2f} °)' )
851+ ax .set_title (f'High Uncertainty Example: Prior vs Posterior (True Angle: { true_angle :.2f} °)' )
851852 ax .grid (True , alpha = 0.3 )
852853 ax .legend ()
853854 # Annotate statistics
@@ -907,12 +908,12 @@ def plot_posterior_predictive(self, output_dir, experiment_name):
907908 label = '90% Credible Interval' )
908909 plt .fill_between (range (len (y_test )), y_pred_25 , y_pred_75 , alpha = 0.5 , color = 'blue' ,
909910 label = '50% Credible Interval' )
910- plt .plot (range (len (y_test )), y_pred_50 , 'orange' , linewidth = 3 , label = 'Median Prediction' )
911- plt .plot (range (len (y_test )), y_test_sorted , 'ro' , makersize = 10 , label = 'True Angles' )
912- plt .xlabel ('Test Point Index (sorted by true angle)' )
913- plt .ylabel ('Angle (degrees)' )
914- plt .title ('Posterior Predictive Distribution' )
915- plt .legend ()
911+ plt .plot (range (len (y_test )), y_pred_50 , 'orange' , linewidth = 2.5 , label = 'Median Prediction' )
912+ plt .plot (range (len (y_test )), y_test_sorted , 'ro' , markersize = 8 , label = 'True Angles' )
913+ plt .xlabel ('Test Point Index (sorted by true angle)' , fontsize = 16 )
914+ plt .ylabel ('Angle (degrees)' , fontsize = 16 )
915+ plt .title ('Posterior Predictive Distribution' , fontsize = 20 )
916+ plt .legend (fontsize = 16 )
916917 plt .grid (True , alpha = 0.3 )
917918 plt .tight_layout ()
918919 plt .savefig (os .path .join (vis_dir , "posterior_predictive.png" ), dpi = 300 , bbox_inches = 'tight' )
@@ -1313,6 +1314,225 @@ def _extract_features_for_sample(self, phasor1, phasor2, rssi1, rssi2, D, W, wav
13131314 features = np .append (features , D )
13141315 return features
13151316
1317+ def analyze_svi_convergence_metrics (models_dict , results_dir ):
1318+ """
1319+ Analyze SVI convergence metrics across different models/priors and generate a tabular report.
1320+
1321+ Parameters:
1322+ - models_dict [dict] : Dictionary of trained BayesianAoARegressor models
1323+ - results_dir [str] : Directory to save the convergence analysis results
1324+
1325+ Returns:
1326+ - convergence_metrics [dict]: Dictionary with convergence metrics by model
1327+ """
1328+ print ("\n === ANALYZING SVI CONVERGENCE ACROSS PRIORS ===" )
1329+
1330+ # Output directory
1331+ os .makedirs (results_dir , exist_ok = True )
1332+
1333+ # Store convergence metrics
1334+ convergence_metrics = {}
1335+
1336+ # Define convergence criteria
1337+ def detect_convergence (losses , window_size = 100 , threshold_pct = 0.5 ):
1338+ """Detect when loss has converged (stabilized)"""
1339+ if len (losses ) <= window_size :
1340+ return len (losses ) - 1 # Not enough data, return last index
1341+
1342+ # Calculate rolling standard deviation as percentage of mean
1343+ rolling_std_pct = []
1344+ for i in range (window_size , len (losses )):
1345+ window = losses [i - window_size :i ]
1346+ std_pct = (np .std (window ) / abs (np .mean (window ))) * 100
1347+ rolling_std_pct .append (std_pct )
1348+
1349+ # Find first point where std_pct is below threshold
1350+ for i , std_pct in enumerate (rolling_std_pct ):
1351+ if std_pct < threshold_pct :
1352+ return i + window_size # Add window_size to get the actual index
1353+
1354+ return len (losses ) - 1 # No convergence detected, return last index
1355+
1356+ # Define function to fit exponential decay
1357+ def fit_convergence_rate (iterations , losses ):
1358+ """Fit exponential decay model to estimate convergence rate"""
1359+ from scipy .optimize import curve_fit
1360+
1361+ def exp_decay (x , a , b , c ):
1362+ return a * np .exp (- b * x ) + c
1363+
1364+ try :
1365+ # Normalize iterations to [0, 1] for numerical stability
1366+ x_norm = iterations / np .max (iterations )
1367+ params , _ = curve_fit (exp_decay , x_norm , losses ,
1368+ p0 = [losses [0 ]- losses [- 1 ], 5 , losses [- 1 ]],
1369+ bounds = ([0 , 0 , - np .inf ], [np .inf , 100 , np .inf ]))
1370+
1371+ convergence_rate = params [1 ]
1372+ return convergence_rate
1373+ except Exception as e :
1374+ print (f"Could not fit convergence rate: { e } " )
1375+ return None
1376+
1377+ # Analyze each model
1378+ for model_name , model in models_dict .items ():
1379+ if model .train_summary is None or 'losses' not in model .train_summary :
1380+ print (f"Skipping { model_name } - no loss data available" )
1381+ continue
1382+
1383+ losses = np .array (model .train_summary ['losses' ])
1384+ iterations = np .arange (1 , len (losses ) + 1 )
1385+
1386+ # Calculate key convergence metrics
1387+
1388+ # 1. Iterations to convergence (using our detection function)
1389+ iter_to_converge = detect_convergence (losses )
1390+
1391+ # 2. Loss reduction (initial vs. final)
1392+ initial_loss = losses [0 ]
1393+ final_loss = losses [- 1 ]
1394+ abs_reduction = initial_loss - final_loss
1395+ pct_reduction = (abs_reduction / abs (initial_loss )) * 100 if initial_loss != 0 else 0
1396+
1397+ # 3. Convergence rate (from exponential fit)
1398+ conv_rate = fit_convergence_rate (iterations , losses )
1399+
1400+ # 4. Stability of final 10% of training
1401+ final_window = losses [- int (0.1 * len (losses )):]
1402+ final_stability = (np .std (final_window ) / abs (np .mean (final_window ))) * 100
1403+
1404+ # 5. Early-stage (0-20%) vs late-stage (80-100%) loss reduction rate
1405+ early_loss_drop = losses [0 ] - losses [int (0.2 * len (losses ))]
1406+ early_rate = early_loss_drop / (0.2 * len (losses ))
1407+
1408+ late_window_start = int (0.8 * len (losses ))
1409+ late_loss_drop = losses [late_window_start ] - losses [- 1 ]
1410+ late_rate = late_loss_drop / (len (losses ) - late_window_start )
1411+
1412+ # Store metrics
1413+ convergence_metrics [model_name ] = {
1414+ 'prior_type' : model .prior_type ,
1415+ 'feature_mode' : model .feature_mode ,
1416+ 'iterations' : len (losses ),
1417+ 'iter_to_converge' : iter_to_converge ,
1418+ 'pct_to_converge' : (iter_to_converge / len (losses )) * 100 ,
1419+ 'initial_loss' : initial_loss ,
1420+ 'final_loss' : final_loss ,
1421+ 'abs_reduction' : abs_reduction ,
1422+ 'pct_reduction' : pct_reduction ,
1423+ 'convergence_rate' : conv_rate ,
1424+ 'final_stability' : final_stability ,
1425+ 'early_reduction_rate' : early_rate ,
1426+ 'late_reduction_rate' : late_rate
1427+ }
1428+
1429+ # Create a tabular summary for paper
1430+ columns = ['prior_type' , 'feature_mode' , 'iter_to_converge' , 'pct_to_converge' ,
1431+ 'pct_reduction' , 'convergence_rate' , 'final_stability' ]
1432+
1433+ # 1. Group by prior type
1434+ prior_groups = {}
1435+ for name , metrics in convergence_metrics .items ():
1436+ prior = metrics ['prior_type' ]
1437+ if prior not in prior_groups :
1438+ prior_groups [prior ] = []
1439+ prior_groups [prior ].append ((name , metrics ))
1440+
1441+ # Write summary tables
1442+ with open (os .path .join (results_dir , "svi_convergence_summary.txt" ), 'w' ) as f :
1443+ f .write ("SVI CONVERGENCE ANALYSIS BY PRIOR TYPE\n " )
1444+ f .write ("=====================================\n \n " )
1445+
1446+ # Table 1: Iterations to convergence by prior type (averaged across feature modes)
1447+ f .write ("TABLE 1: CONVERGENCE METRICS BY PRIOR TYPE\n " )
1448+ f .write ("Prior Type | Iterations to Converge | % of Total | Convergence Rate | Loss Reduction (%)\n " )
1449+ f .write ("----------|----------------------|-----------|-----------------|------------------\n " )
1450+
1451+ for prior in sorted (prior_groups .keys ()):
1452+ metrics_list = [m for _ , m in prior_groups [prior ]]
1453+ avg_iter = np .mean ([m ['iter_to_converge' ] for m in metrics_list ])
1454+ avg_pct = np .mean ([m ['pct_to_converge' ] for m in metrics_list ])
1455+ avg_rate = np .mean ([m ['convergence_rate' ] for m in metrics_list if m ['convergence_rate' ] is not None ])
1456+ avg_reduction = np .mean ([m ['pct_reduction' ] for m in metrics_list ])
1457+
1458+ f .write (f"{ prior .upper ():10} | { avg_iter :.1f} | { avg_pct :.1f} % | { avg_rate :.3f} | { avg_reduction :.1f} %\n " )
1459+
1460+ f .write ("\n \n " )
1461+
1462+ # Table 2: Detailed metrics for each configuration
1463+ f .write ("TABLE 2: DETAILED CONVERGENCE METRICS BY MODEL CONFIGURATION\n " )
1464+ f .write ("Model | Prior | Features | Iter to Conv | % of Total | Conv Rate | Reduction | Stability\n " )
1465+ f .write ("------|-------|----------|-------------|-----------|----------|-----------|----------\n " )
1466+
1467+ for name , metrics in sorted (convergence_metrics .items ()):
1468+ prior = metrics ['prior_type' ].upper ()
1469+ features = metrics ['feature_mode' ]
1470+ iter_conv = metrics ['iter_to_converge' ]
1471+ pct_conv = metrics ['pct_to_converge' ]
1472+ conv_rate = metrics ['convergence_rate' ] if metrics ['convergence_rate' ] is not None else "N/A"
1473+ reduction = metrics ['pct_reduction' ]
1474+ stability = metrics ['final_stability' ]
1475+
1476+ if isinstance (conv_rate , float ):
1477+ conv_rate_str = f"{ conv_rate :.3f} "
1478+ else :
1479+ conv_rate_str = "N/A"
1480+
1481+ f .write (f"{ name :6} | { prior :5} | { features :8} | { iter_conv :11.0f} | { pct_conv :9.1f} % | "
1482+ f"{ conv_rate_str :8} | { reduction :8.1f} % | { stability :8.2f} %\n " )
1483+
1484+ f .write ("\n \n " )
1485+
1486+ # Interpretation and recommendations
1487+ f .write ("INTERPRETATION OF METRICS:\n " )
1488+ f .write ("-------------------------\n " )
1489+ f .write ("- Iterations to Converge: Lower is better, faster convergence\n " )
1490+ f .write ("- % of Total: Percentage of total training iterations needed to converge\n " )
1491+ f .write ("- Convergence Rate: Higher is better, faster exponential decay of loss\n " )
1492+ f .write ("- Loss Reduction: Higher is better, more improvement from initial loss\n " )
1493+ f .write ("- Stability: Lower is better, less variation in final 10% of training\n \n " )
1494+
1495+ # Summary of findings
1496+ f .write ("SUMMARY OF FINDINGS:\n " )
1497+ f .write ("------------------\n " )
1498+
1499+ # Find fastest converging prior
1500+ avg_conv_by_prior = {}
1501+ for prior , models in prior_groups .items ():
1502+ avg_conv_by_prior [prior ] = np .mean ([m ['pct_to_converge' ] for _ , m in models ])
1503+
1504+ fastest_prior = min (avg_conv_by_prior .items (), key = lambda x : x [1 ])[0 ]
1505+ slowest_prior = max (avg_conv_by_prior .items (), key = lambda x : x [1 ])[0 ]
1506+
1507+ f .write (f"1. Fastest converging prior: { fastest_prior .upper ()} "
1508+ f"({ avg_conv_by_prior [fastest_prior ]:.1f} % of iterations)\n " )
1509+ f .write (f"2. Slowest converging prior: { slowest_prior .upper ()} "
1510+ f"({ avg_conv_by_prior [slowest_prior ]:.1f} % of iterations)\n " )
1511+
1512+ # Prior with highest convergence rate
1513+ avg_rate_by_prior = {}
1514+ for prior , models in prior_groups .items ():
1515+ rates = [m ['convergence_rate' ] for _ , m in models if m ['convergence_rate' ] is not None ]
1516+ if rates :
1517+ avg_rate_by_prior [prior ] = np .mean (rates )
1518+
1519+ if avg_rate_by_prior :
1520+ highest_rate_prior = max (avg_rate_by_prior .items (), key = lambda x : x [1 ])[0 ]
1521+ f .write (f"3. Prior with highest convergence rate: { highest_rate_prior .upper ()} "
1522+ f"(rate = { avg_rate_by_prior [highest_rate_prior ]:.3f} )\n " )
1523+
1524+ # Prior with most stable convergence
1525+ avg_stability_by_prior = {}
1526+ for prior , models in prior_groups .items ():
1527+ avg_stability_by_prior [prior ] = np .mean ([m ['final_stability' ] for _ , m in models ])
1528+
1529+ most_stable_prior = min (avg_stability_by_prior .items (), key = lambda x : x [1 ])[0 ]
1530+ f .write (f"4. Most stable convergence: { most_stable_prior .upper ()} "
1531+ f"(final stability = { avg_stability_by_prior [most_stable_prior ]:.2f} %)\n " )
1532+
1533+ print (f"SVI convergence analysis completed and saved to: { os .path .join (results_dir , 'svi_convergence_summary.txt' )} " )
1534+ return convergence_metrics
1535+
13161536def train_bayesian_models (data_manager , results_dir , num_epochs = 10000 ):
13171537 """
13181538 Train multiple Bayesian AoA regression models with different priors and feature sets.
@@ -1347,6 +1567,9 @@ def train_bayesian_models(data_manager, results_dir, num_epochs=10000):
13471567 # Dictionary to store results
13481568 models = {}
13491569 results = {}
1570+ # Create convergence analysis directory
1571+ convergence_dir = os .path .join (results_dir , "svi_convergence_analysis" )
1572+ os .makedirs (convergence_dir , exist_ok = True )
13501573 # Train models for each configuration
13511574 for config in configs :
13521575 print (f"\n --- Training Bayesian model with { config ['prior' ]} prior, { config ['features' ]} features ---" )
@@ -1363,18 +1586,18 @@ def train_bayesian_models(data_manager, results_dir, num_epochs=10000):
13631586 results [config ['name' ]] = train_results
13641587 # Visualize results
13651588 model .visualize_results (results_dir , config ['name' ])
1366- # Generate new visualizations
13671589 model .render_model_and_guide (results_dir , config ['name' ])
13681590 model .plot_posterior_predictive (results_dir , config ['name' ])
13691591 model .plot_uncertainty_calibration (results_dir , config ['name' ])
1370- # Visualize prior vs posterior
13711592 models [config ['name' ]].visualize_prior_vs_posterior (results_dir , config ['name' ])
13721593 model .visualize_weight_distributions (results_dir , config ['name' ])
13731594 model .analyze_with_posterior_weights (data_manager , results_dir , config ['name' ])
13741595 print (f"Completed training { config ['name' ]} " )
1596+ # Perofrm SVI convergence analysis
1597+ convergence_metrics = analyze_svi_convergence_metrics (models , convergence_dir )
13751598 # Create comparison visualizations
13761599 compare_bayesian_models (models , results , results_dir )
1377- return {"models" : models , "results" : results }
1600+ return {"models" : models , "results" : results , "convergence" : convergence_metrics }
13781601
13791602def compare_bayesian_models (models , results , output_dir ):
13801603 """
@@ -1461,11 +1684,11 @@ def compare_bayesian_models(models, results, output_dir):
14611684 rects1 = ax .bar (x - width / 2 , sorted_maes , width , label = 'MAE' )
14621685 rects2 = ax .bar (x + width / 2 , sorted_rmses , width , label = 'RMSE' )
14631686 # Add labels and title
1464- ax .set_ylabel ('Error (degrees)' )
1465- ax .set_title ('Bayesian AoA Model Performance Comparison' )
1687+ ax .set_ylabel ('Error (degrees)' , fontsize = 16 )
1688+ ax .set_title ('Bayesian AoA Model Performance Comparison' , fontsize = 20 )
14661689 ax .set_xticks (x )
1467- ax .set_xticklabels (display_names , rotation = 45 , ha = 'right' )
1468- ax .legend ()
1690+ ax .set_xticklabels (display_names , rotation = 45 , ha = 'right' , fontsize = 14 )
1691+ ax .legend (fontsize = 14 )
14691692 # Add value labels
14701693 def autolabel (rects ):
14711694 for rect in rects :
@@ -1474,7 +1697,7 @@ def autolabel(rects):
14741697 xy = (rect .get_x () + rect .get_width ()/ 2 , height ),
14751698 xytext = (0 , 3 ), # 3 points vertical offset
14761699 textcoords = "offset points" ,
1477- ha = 'center' , va = 'bottom' )
1700+ ha = 'center' , va = 'bottom' , fontsize = 12 )
14781701 autolabel (rects1 )
14791702 autolabel (rects2 )
14801703 fig .tight_layout ()
0 commit comments