import time
from functools import partial

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from tether.data import SpikingDatasetWrapper, rate_encoding
from tether.nn import LIF
from tether.utils.monitor import Monitor
10+
class SpikingCIFARModel(nn.Module):
    """Convolutional spiking network for CIFAR-10 classification.

    The input is time-major — shape (Time, Batch, C, H, W) — and each
    timestep is pushed through the conv stack and the fully-connected
    head; the per-step logits are averaged over time.
    """

    def __init__(self, n_steps=10):
        super().__init__()
        # Number of simulation timesteps unrolled in forward().
        self.n_steps = n_steps

        # Convolutional front-end. The LIF sizes appear to correspond to
        # the flattened feature maps (32x32x32 after conv1, 64x16x16
        # after conv2 + pooling) — confirm against the tether.nn.LIF API.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            LIF(32 * 32 * 32),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            LIF(64 * 16 * 16),
            nn.MaxPool2d(2),
        )

        # Classifier head: flatten 64x8x8 features down to 10 class logits.
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 256),
            LIF(256),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        """Run the network over all timesteps.

        Args:
            x: input tensor of shape (Time, Batch, C, H, W).

        Returns:
            (Batch, 10) tensor of logits averaged over the time axis.
        """
        step_logits = [
            self.fc_layers(self.conv_layers(x[step]))
            for step in range(self.n_steps)
        ]
        return torch.stack(step_logits).mean(0)
42+
def evaluate(model, loader, device):
    """Compute the classification accuracy (in percent) of `model` on `loader`.

    Args:
        model: module mapping a time-major (Time, Batch, ...) tensor to
            (Batch, n_classes) logits.
        loader: iterable yielding (data, target) batches, where data is
            batch-major (Batch, Time, ...) and target holds class indices.
        device: torch device (string or object) to move tensors to.

    Returns:
        Accuracy as a float percentage in [0, 100]. Returns 0.0 for an
        empty loader instead of raising ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            # Loader yields batch-major data; the model expects time-major.
            data = data.to(device).transpose(0, 1)
            target = target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    # Guard: an empty loader previously crashed with ZeroDivisionError.
    return 100. * correct / total if total else 0.0
56+
def main():
    """Train a spiking CNN on CIFAR-10 and print per-epoch metrics.

    Downloads CIFAR-10 into ./data on first run, wraps it with a
    rate-encoding spiking dataset, trains for a fixed number of epochs,
    and reports loss, accuracy, epoch time, and mean LIF firing rates
    collected by the tether Monitor.
    """
    # --- Configuration ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    n_steps = 10
    batch_size = 64
    epochs = 10
    lr = 1e-3

    # --- Data Preparation ---
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_raw = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    test_raw = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

    # Use functools.partial rather than a lambda: DataLoader worker
    # processes pickle the dataset, and lambdas are not picklable, so
    # num_workers > 0 would crash on spawn-start platforms (Windows/macOS).
    encode_fn = partial(rate_encoding, n_steps=n_steps)
    train_ds = SpikingDatasetWrapper(train_raw, encode_fn=encode_fn)
    test_ds = SpikingDatasetWrapper(test_raw, encode_fn=encode_fn)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=4)

    # --- Model & Monitoring ---
    model = SpikingCIFARModel(n_steps=n_steps).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Initialize the Tether Monitor (hooks into the model's LIF layers;
    # see tether.utils.monitor for the exact attachment semantics).
    monitor = Monitor(model)

    print(f"Starting Tether CIFAR-10 Training on {device}")
    print("-" * 60)

    # --- Training Loop ---
    for epoch in range(epochs):
        model.train()
        start_time = time.time()
        running_loss = 0.0
        train_correct = 0
        train_total = 0

        for data, target in train_loader:
            # Dataset yields batch-major (Batch, Time, ...); the model
            # consumes time-major (Time, Batch, ...).
            data = data.to(device).transpose(0, 1)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Track training metrics
            running_loss += loss.item()
            pred = output.argmax(dim=1)
            train_correct += (pred == target).sum().item()
            train_total += target.size(0)

        # --- End of Epoch Monitoring ---
        epoch_time = time.time() - start_time
        avg_loss = running_loss / len(train_loader)
        train_acc = 100. * train_correct / train_total
        val_acc = evaluate(model, test_loader, device)

        # Retrieve firing rates from all LIF layers via Monitor
        firing_rates = monitor.get_firing_rates()
        # Mean firing rate across all monitored layers (0.0 if none reported).
        mean_fr = sum(firing_rates.values()) / len(firing_rates) if firing_rates else 0.0

        # Print detailed report once per epoch
        print(f"Epoch {epoch + 1}/{epochs} | Time: {epoch_time:.1f}s")
        print(f" > Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")
        print(f" > Mean Firing Rate: {mean_fr:.4f} (Sparsity: {(1 - mean_fr) * 100:.1f}%)")
        print("-" * 60)

if __name__ == "__main__":
    main()