Commit 9d345e8: Create analysis.py
Parent: 6126f2c

1 file changed (+94 −0)

src/utils/analysis.py
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict


class AdaptiveWindowAnalyzer:
    def __init__(self, t_max: float = 20.0):
        # Upper bound on the adaptive window size; used to normalize the
        # efficiency metric (previously hardcoded as T_max=20 in plot 4).
        self.t_max = t_max
        self.history = []

    def collect_metrics(self, all_metrics: List[Dict]):
        """Collect metrics from all transformer layers."""
        batch_data = {
            'T_means': [m['T_mean'] for m in all_metrics],
            'T_stds': [m['T_std'] for m in all_metrics],
            'reg_losses': [m['reg_loss'] for m in all_metrics],
            # 'adaptive_windows' is expected to be a torch tensor per layer.
            'adaptive_windows': [m['adaptive_windows'].cpu().numpy()
                                 for m in all_metrics]
        }
        self.history.append(batch_data)

    def plot_layer_comparison(self, save_path=None):
        """Compare adaptive windows across transformer layers."""
        if not self.history:
            return

        num_layers = len(self.history[0]['T_means'])

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Extract per-layer mean window sizes across all batches
        layer_means = [[] for _ in range(num_layers)]
        for batch in self.history:
            for layer_idx, mean_val in enumerate(batch['T_means']):
                layer_means[layer_idx].append(mean_val)

        # Plot 1: Average window size per layer over time
        for layer_idx in range(num_layers):
            axes[0, 0].plot(layer_means[layer_idx],
                            label=f'Layer {layer_idx + 1}', alpha=0.8)
        axes[0, 0].set_title('Average Window Size Over Training')
        axes[0, 0].set_xlabel('Training Step')
        axes[0, 0].set_ylabel('Mean T_i')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Plot 2: Final window size distribution per layer
        final_means = [layer_means[i][-1] for i in range(num_layers)]
        axes[0, 1].bar(range(1, num_layers + 1), final_means, alpha=0.7)
        axes[0, 1].set_title('Final Average Window Size by Layer')
        axes[0, 1].set_xlabel('Layer')
        axes[0, 1].set_ylabel('Mean T_i')
        axes[0, 1].grid(True, alpha=0.3)

        # Plot 3: Regularization loss per layer
        layer_reg_losses = [[] for _ in range(num_layers)]
        for batch in self.history:
            for layer_idx, reg_loss in enumerate(batch['reg_losses']):
                layer_reg_losses[layer_idx].append(reg_loss)

        for layer_idx in range(num_layers):
            axes[1, 0].plot(layer_reg_losses[layer_idx],
                            label=f'Layer {layer_idx + 1}', alpha=0.8)
        axes[1, 0].set_title('Regularization Loss Over Training')
        axes[1, 0].set_xlabel('Training Step')
        axes[1, 0].set_ylabel('Reg Loss')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        # Plot 4: Efficiency metric (mean window size relative to the cap)
        efficiencies = []
        for batch in self.history:
            batch_eff = [mean / self.t_max for mean in batch['T_means']]
            efficiencies.append(np.mean(batch_eff))

        axes[1, 1].plot(efficiencies, 'g-', linewidth=2)
        axes[1, 1].set_title('Model Efficiency Over Training')
        axes[1, 1].set_xlabel('Training Step')
        axes[1, 1].set_ylabel('Efficiency (avg T_i / T_max)')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

        # Print summary statistics
        print("\n📊 Multi-Layer Adaptive Windows Analysis:")
        print(f"   Total training steps: {len(self.history)}")
        print(f"   Number of layers: {num_layers}")
        for i in range(num_layers):
            final_mean = layer_means[i][-1] if layer_means[i] else 0
            print(f"   Layer {i + 1} final avg window: {final_mean:.2f}")
        print(f"   Final model efficiency: {efficiencies[-1] * 100:.1f}%")
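For context, here is a minimal usage sketch, not part of this commit. It fabricates per-layer metrics dicts carrying the keys collect_metrics() reads ('T_mean', 'T_std', 'reg_loss', and a torch tensor under 'adaptive_windows'); in real training these would come from the model's attention layers, and the shapes and values below are placeholders:

# Hypothetical usage sketch: the metrics dicts mimic what collect_metrics()
# expects; real values would come from the model's layers during training.
import torch

analyzer = AdaptiveWindowAnalyzer(t_max=20.0)

num_layers, num_steps = 4, 100
for step in range(num_steps):
    all_metrics = []
    for layer in range(num_layers):
        # Stand-in per-token window sizes for a batch of 32 tokens.
        windows = torch.rand(32) * 20.0
        all_metrics.append({
            'T_mean': windows.mean().item(),
            'T_std': windows.std().item(),
            'reg_loss': float(windows.mean() / 20.0),  # placeholder value
            'adaptive_windows': windows,
        })
    analyzer.collect_metrics(all_metrics)

analyzer.plot_layer_comparison(save_path='adaptive_windows.png')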
