Skip to content

Commit 8d3144c

Browse files
committed
chore: add SpikingCIFARModel and TetherLM for CIFAR-10 training, include dataset download functionality, and update dependencies in pyproject.toml
1 parent ae7b023 commit 8d3144c

File tree

4 files changed

+142
-2
lines changed

4 files changed

+142
-2
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,6 @@ input.txt
1414
# Sphinx build artifacts
1515
docs/_build/
1616
docs/tether.*.rst
17-
docs/modules.rst
17+
docs/modules.rst
18+
19+
data/

examples/train_cifar.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.optim as optim
4+
import time
5+
from torchvision import datasets, transforms
6+
from torch.utils.data import DataLoader
7+
from tether.nn import LIF
8+
from tether.data import SpikingDatasetWrapper, rate_encoding
9+
from tether.utils.monitor import Monitor #
10+
11+
class SpikingCIFARModel(nn.Module):
    """Convolutional spiking network for CIFAR-10.

    Runs a spike-encoded input through conv + LIF stages for ``n_steps``
    discrete time steps and returns the per-step logits averaged over time.
    """

    def __init__(self, n_steps=10):
        super().__init__()
        # Number of simulation time steps consumed from the leading axis.
        self.n_steps = n_steps

        # Feature extractor: two conv/LIF stages, each followed by 2x pooling.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            LIF(32 * 32 * 32),  # 32 channels over a 32x32 grid
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            LIF(64 * 16 * 16),  # 64 channels over a 16x16 grid
            nn.MaxPool2d(2),
        )

        # Classifier head: flattened 8x8 feature map -> 10 class logits.
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 256),
            LIF(256),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        """Run the network over time.

        Args:
            x: spike tensor of shape (Time, Batch, C, H, W); only the first
               ``self.n_steps`` time slices are consumed.

        Returns:
            Tensor of shape (Batch, 10): logits averaged across time steps.
        """
        per_step = [
            self.fc_layers(self.conv_layers(x[t])) for t in range(self.n_steps)
        ]
        return torch.stack(per_step).mean(dim=0)
42+
43+
def evaluate(model, loader, device):
    """Return top-1 accuracy (in percent) of ``model`` over ``loader``.

    Each batch is moved to ``device`` and transposed from batch-major to
    time-major layout (dims 0 and 1 swapped) before the forward pass, since
    the model consumes the leading axis as time. Gradients are disabled and
    the model is switched to eval mode for the duration of the pass.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch, labels in loader:
            batch = batch.to(device).transpose(0, 1)
            labels = labels.to(device)
            predictions = model(batch).argmax(dim=1)
            n_correct += (predictions == labels).sum().item()
            n_seen += labels.size(0)
    return 100. * n_correct / n_seen
56+
57+
def main():
    """Train ``SpikingCIFARModel`` on CIFAR-10 and report per-epoch metrics.

    Downloads CIFAR-10 into ``./data`` if absent, wraps it with a spike
    (rate) encoder, trains with AdamW + cross-entropy, and after every epoch
    prints loss, train/validation accuracy, and the mean LIF firing rate
    collected by the Tether ``Monitor``.
    """
    # --- Configuration ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    n_steps = 10    # simulation time steps per sample
    batch_size = 64
    epochs = 10
    lr = 1e-3

    # --- Data Preparation ---
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    cifar_train = datasets.CIFAR10(root="./data", train=True, download=True, transform=preprocess)
    cifar_test = datasets.CIFAR10(root="./data", train=False, download=True, transform=preprocess)

    def encode(img):
        # Rate-encode one image into an n_steps-long spike train.
        return rate_encoding(img, n_steps=n_steps)

    train_loader = DataLoader(
        SpikingDatasetWrapper(cifar_train, encode_fn=encode),
        batch_size=batch_size, shuffle=True, num_workers=4,
    )
    test_loader = DataLoader(
        SpikingDatasetWrapper(cifar_test, encode_fn=encode),
        batch_size=batch_size, shuffle=False, num_workers=4,
    )

    # --- Model & Monitoring ---
    model = SpikingCIFARModel(n_steps=n_steps).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Tether Monitor hooks the model to collect spiking statistics.
    monitor = Monitor(model)

    print(f"Starting Tether CIFAR-10 Training on {device}")
    print("-" * 60)

    # --- Training Loop ---
    for epoch in range(epochs):
        model.train()
        start_time = time.time()
        running_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch, target in train_loader:
            # Loader yields batch-major data; model wants time-major.
            batch = batch.to(device).transpose(0, 1)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(batch)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Accumulate running training metrics.
            running_loss += loss.item()
            train_correct += (output.argmax(dim=1) == target).sum().item()
            train_total += target.size(0)

        # --- End of Epoch Monitoring ---
        epoch_time = time.time() - start_time
        avg_loss = running_loss / len(train_loader)
        train_acc = 100. * train_correct / train_total
        val_acc = evaluate(model, test_loader, device)

        # Mean firing rate across every LIF layer tracked by the Monitor
        # (0.0 when no rates were recorded).
        firing_rates = monitor.get_firing_rates()
        mean_fr = sum(firing_rates.values()) / len(firing_rates) if firing_rates else 0.0

        # One compact report per epoch.
        print(f"Epoch {epoch+1}/{epochs} | Time: {epoch_time:.1f}s")
        print(f"  > Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")
        print(f"  > Mean Firing Rate: {mean_fr:.4f} (Sparsity: {(1-mean_fr)*100:.1f}%)")
        print("-" * 60)


if __name__ == "__main__":
    main()

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@ version = "0.1.1"
44
description = "Triton-powered framework for training and deploying Spiking Transformers."
55
readme = "README.md"
66
requires-python = ">=3.10"
7-
dependencies = ["numpy", "torch", "triton"]
7+
dependencies = [
8+
"numpy",
9+
"torch",
10+
"torchvision>=0.24.1",
11+
"triton",
12+
]
813

914

1015
[build-system]

0 commit comments

Comments
 (0)