Skip to content

Commit 6126f2c

Browse files
Create trainer.py
1 parent b2d2a02 commit 6126f2c

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

src/training.py/trainer.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import torch
2+
import torch.nn as nn
3+
from ..models.full_model import AdaptiveSpikingTransformer
4+
5+
class AdaptiveSpikingTrainer:
6+
def __init__(self, model, optimizer, device='cuda'):
7+
self.model = model
8+
self.optimizer = optimizer
9+
self.device = device
10+
self.step_count = 0
11+
12+
def train_step(self, batch):
13+
self.model.train()
14+
15+
inputs, targets = batch
16+
inputs, targets = inputs.to(self.device), targets.to(self.device)
17+
18+
# Forward pass
19+
logits, all_metrics = self.model(inputs)
20+
21+
# Task loss (e.g., cross-entropy)
22+
task_loss = nn.CrossEntropyLoss()(
23+
logits.view(-1, logits.size(-1)),
24+
targets.view(-1)
25+
)
26+
27+
# 🔥 Collect regularization losses from all layers
28+
reg_loss = sum([metrics['reg_loss'] for metrics in all_metrics])
29+
30+
# Total loss
31+
total_loss = task_loss + reg_loss
32+
33+
# Backward pass
34+
self.optimizer.zero_grad()
35+
total_loss.backward()
36+
self.optimizer.step()
37+
38+
# 🔥 Log adaptive window statistics
39+
avg_T_mean = sum([m['T_mean'] for m in all_metrics]) / len(all_metrics)
40+
41+
self.step_count += 1
42+
43+
return {
44+
'total_loss': total_loss.item(),
45+
'task_loss': task_loss.item(),
46+
'reg_loss': reg_loss.item(),
47+
'avg_window_size': avg_T_mean,
48+
'step': self.step_count
49+
}

0 commit comments

Comments
 (0)