import torch
import torch.nn as nn

from ..models.full_model import AdaptiveSpikingTransformer


class AdaptiveSpikingTrainer:
    """Training-loop wrapper that combines the task loss with the per-layer
    regularization losses reported by the adaptive spiking layers."""

    def __init__(self, model, optimizer, device='cuda'):
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.step_count = 0

    def train_step(self, batch):
        self.model.train()

        inputs, targets = batch
        inputs, targets = inputs.to(self.device), targets.to(self.device)

        # Forward pass: the model returns logits plus a list of per-layer metrics
        logits, all_metrics = self.model(inputs)

        # Task loss (e.g., cross-entropy over the flattened sequence)
        task_loss = nn.CrossEntropyLoss()(
            logits.view(-1, logits.size(-1)),
            targets.view(-1)
        )

        # 🔥 Collect regularization losses from all layers
        reg_loss = sum(metrics['reg_loss'] for metrics in all_metrics)

        # Total loss
        total_loss = task_loss + reg_loss

        # Backward pass
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        # 🔥 Track the average adaptive window size across layers
        avg_T_mean = sum(m['T_mean'] for m in all_metrics) / len(all_metrics)

        self.step_count += 1

        return {
            'total_loss': total_loss.item(),
            'task_loss': task_loss.item(),
            'reg_loss': reg_loss.item(),
            'avg_window_size': avg_T_mean,
            'step': self.step_count
        }
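

# --- Usage sketch (illustrative; not part of the original module) ---
# The real AdaptiveSpikingTransformer constructor is not shown here, so this
# smoke test substitutes a tiny dummy model that only mimics the interface
# train_step() relies on: forward(inputs) -> (logits, all_metrics), where
# all_metrics is a list of per-layer dicts with a 'reg_loss' scalar tensor
# and a 'T_mean' float. Run it as a module (python -m ...) so the relative
# import above resolves.
if __name__ == '__main__':
    class _DummyAdaptiveModel(nn.Module):
        def __init__(self, vocab_size=100, d_model=32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, d_model)
            self.head = nn.Linear(d_model, vocab_size)

        def forward(self, x):
            h = self.embed(x)
            logits = self.head(h)
            # One metrics dict per "layer"; here a single fake layer.
            all_metrics = [{'reg_loss': 1e-3 * h.pow(2).mean(), 'T_mean': 4.0}]
            return logits, all_metrics

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = _DummyAdaptiveModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    trainer = AdaptiveSpikingTrainer(model, optimizer, device=device)

    # One synthetic (inputs, targets) batch of token ids, shape (batch, seq_len).
    batch = (torch.randint(0, 100, (2, 16)), torch.randint(0, 100, (2, 16)))
    stats = trainer.train_step(batch)
    print(stats)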