import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict


class AdaptiveWindowAnalyzer:
    def __init__(self, t_max: float = 20.0):
        # t_max: maximum window size, used to normalize the efficiency plot.
        # The original code hardcoded 20.0; it is exposed here as a parameter.
        self.t_max = t_max
        self.history = []

    def collect_metrics(self, all_metrics: List[Dict]):
        """Collect metrics from all transformer layers.

        Each entry of `all_metrics` is a per-layer dict with float keys
        'T_mean', 'T_std', 'reg_loss' and a torch.Tensor 'adaptive_windows'.
        """
        batch_data = {
            'T_means': [m['T_mean'] for m in all_metrics],
            'T_stds': [m['T_std'] for m in all_metrics],
            'reg_losses': [m['reg_loss'] for m in all_metrics],
            'adaptive_windows': [m['adaptive_windows'].cpu().numpy()
                                 for m in all_metrics]
        }
        self.history.append(batch_data)

    def plot_layer_comparison(self, save_path=None):
        """Compare adaptive windows across transformer layers."""
        if not self.history:
            return

        num_layers = len(self.history[0]['T_means'])

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Extract data across all batches
        layer_means = [[] for _ in range(num_layers)]
        for batch in self.history:
            for layer_idx, mean_val in enumerate(batch['T_means']):
                layer_means[layer_idx].append(mean_val)

        # Plot 1: Average window size per layer over time
        for layer_idx in range(num_layers):
            axes[0, 0].plot(layer_means[layer_idx],
                            label=f'Layer {layer_idx + 1}', alpha=0.8)
        axes[0, 0].set_title('Average Window Size Over Training')
        axes[0, 0].set_xlabel('Training Step')
        axes[0, 0].set_ylabel('Mean T_i')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Plot 2: Final window size distribution per layer
        final_means = [layer_means[i][-1] for i in range(num_layers)]
        axes[0, 1].bar(range(1, num_layers + 1), final_means, alpha=0.7)
        axes[0, 1].set_title('Final Average Window Size by Layer')
        axes[0, 1].set_xlabel('Layer')
        axes[0, 1].set_ylabel('Mean T_i')
        axes[0, 1].grid(True, alpha=0.3)

        # Plot 3: Regularization loss per layer
        layer_reg_losses = [[] for _ in range(num_layers)]
        for batch in self.history:
            for layer_idx, reg_loss in enumerate(batch['reg_losses']):
                layer_reg_losses[layer_idx].append(reg_loss)

        for layer_idx in range(num_layers):
            axes[1, 0].plot(layer_reg_losses[layer_idx],
                            label=f'Layer {layer_idx + 1}', alpha=0.8)
        axes[1, 0].set_title('Regularization Loss Over Training')
        axes[1, 0].set_xlabel('Training Step')
        axes[1, 0].set_ylabel('Reg Loss')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        # Plot 4: Efficiency metrics (mean window size relative to t_max)
        efficiencies = []
        for batch in self.history:
            batch_eff = [mean / self.t_max for mean in batch['T_means']]
            efficiencies.append(np.mean(batch_eff))

        axes[1, 1].plot(efficiencies, 'g-', linewidth=2)
        axes[1, 1].set_title('Model Efficiency Over Training')
        axes[1, 1].set_xlabel('Training Step')
        axes[1, 1].set_ylabel('Efficiency (avg T_i / T_max)')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

        # Print summary statistics
        print("\n📊 Multi-Layer Adaptive Windows Analysis:")
        print(f"   Total training steps: {len(self.history)}")
        print(f"   Number of layers: {num_layers}")
        for i in range(num_layers):
            final_mean = layer_means[i][-1] if layer_means[i] else 0
            print(f"   Layer {i + 1} final avg window: {final_mean:.2f}")
        print(f"   Final model efficiency: {efficiencies[-1] * 100:.1f}%")