Skip to content

Commit 45ea2b4

Browse files
committed
chore: refactor PLIF and LIF modules for improved readability and performance
- Enhanced the PLIF and LIF classes by restructuring the initialization parameters for better clarity. - Updated the forward methods in both classes to improve code readability and maintainability. - Added support for vectorized decay and threshold parameters in PLIF. - Improved the handling of surrogate gradients in both LIF and PLIF. - Refactored the attention mechanism in SpikingSelfAttention to streamline operations. - Updated the Monitor utility to enhance voltage trace monitoring capabilities. - Added comprehensive tests for new features and ensured backward compatibility. - Cleaned up code formatting across multiple files for consistency.
1 parent 1b92921 commit 45ea2b4

27 files changed

+797
-666
lines changed

.github/workflows/docs.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ jobs:
3838
run: |
3939
uv sync --extra docs
4040
41+
- name: Ruff Check
42+
run: |
43+
uv run ruff check .
44+
4145
- name: Build Sphinx Documentation
4246
run: |
4347
# Generate RST files if they are not already committed

benchmarks/benchmark_lif.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import torch
2-
import triton
32
import time
43
import sys
54
import os
65

76
# Add src to path so we can import tether
8-
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
7+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
98

109
from tether.functional.lif import LIFSubFunction
1110

11+
1212
def lif_pytorch(x_seq, v_init, decay, threshold):
1313
"""
1414
PyTorch reference implementation of LIF.
@@ -31,25 +31,26 @@ def lif_pytorch(x_seq, v_init, decay, threshold):
3131
"""
3232
# x_seq: (Time, Neurons)
3333
# v_init: (Neurons)
34-
34+
3535
n_steps, n_neurons = x_seq.shape
3636
v = v_init.clone()
3737
spikes_list = []
38-
38+
3939
# We use a simple loop as LIF is recurrent
4040
# This simulates "Vanilla PyTorch" without custom CUDA kernels or JIT
4141
for t in range(n_steps):
4242
x = x_seq[t]
4343
v = v * decay + x
44-
44+
4545
spike = (v >= threshold).float()
4646
spikes_list.append(spike)
47-
47+
4848
# Hard reset
4949
v = v * (1.0 - spike)
50-
50+
5151
return torch.stack(spikes_list)
5252

53+
5354
def benchmark():
5455
"""
5556
Benchmark PyTorch vs Triton implementations of LIF.
@@ -64,10 +65,10 @@ def benchmark():
6465
# Dimensions
6566
# Simulating a reasonable layer size for a Transformer
6667
batch_size = 32
67-
seq_len = 2048 # Longer context to emphasize the loop overhead vs kernel
68-
dim = 768 # Standard BERT-base dimension
68+
seq_len = 2048 # Longer context to emphasize the loop overhead vs kernel
69+
dim = 768 # Standard BERT-base dimension
6970
n_neurons = batch_size * dim
70-
71+
7172
# Inputs
7273
x_seq = torch.randn(seq_len, n_neurons, device=device)
7374
v_init = torch.zeros(n_neurons, device=device)
@@ -81,7 +82,7 @@ def benchmark():
8182
with torch.no_grad():
8283
_ = lif_pytorch(x_seq, v_init, decay, threshold)
8384
LIFSubFunction.apply(x_seq, v_init, decay, threshold, alpha, 0)
84-
85+
8586
# Benchmark PyTorch
8687
torch.cuda.synchronize()
8788
start_time = time.time()
@@ -103,8 +104,9 @@ def benchmark():
103104
torch.cuda.synchronize()
104105
triton_time = (time.time() - start_time) / iterations
105106
print(f"Triton Time: {triton_time * 1000:.3f} ms")
106-
107+
107108
print(f"Speedup: {pytorch_time / triton_time:.2f}x")
108109

110+
109111
if __name__ == "__main__":
110112
benchmark()

benchmarks/benchmark_plif.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import torch
2-
import triton
32
import time
43
import sys
54
import os
65

76
# Add src to path
8-
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
7+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
98

109
from tether.functional.lif import LIFSubFunction
1110
from tether.functional.plif import PLIFSubFunction
1211

12+
1313
def benchmark_plif():
1414
"""
1515
Benchmark LIF vs PLIF (Triton vs Triton).
@@ -25,15 +25,15 @@ def benchmark_plif():
2525
seq_len = 2048
2626
dim = 768
2727
n_neurons = batch_size * dim
28-
28+
2929
x_seq = torch.randn(seq_len, n_neurons, device=device)
3030
v_init = torch.zeros(n_neurons, device=device)
31-
31+
3232
# LIF Params (Scalar)
3333
decay_scalar = torch.tensor(0.9, device=device)
3434
threshold_scalar = torch.tensor(1.0, device=device)
3535
alpha = torch.tensor(2.0, device=device)
36-
36+
3737
# PLIF Params (Vector)
3838
decay_vector = torch.full((n_neurons,), 0.9, device=device)
3939
threshold_vector = torch.full((n_neurons,), 1.0, device=device)
@@ -42,16 +42,22 @@ def benchmark_plif():
4242
print("Warming up...")
4343
for _ in range(10):
4444
with torch.no_grad():
45-
LIFSubFunction.apply(x_seq, v_init, decay_scalar, threshold_scalar, alpha, 0)
46-
PLIFSubFunction.apply(x_seq, v_init, decay_vector, threshold_vector, alpha, 0)
47-
45+
LIFSubFunction.apply(
46+
x_seq, v_init, decay_scalar, threshold_scalar, alpha, 0
47+
)
48+
PLIFSubFunction.apply(
49+
x_seq, v_init, decay_vector, threshold_vector, alpha, 0
50+
)
51+
4852
# Benchmark LIF
4953
torch.cuda.synchronize()
5054
start_time = time.time()
5155
iterations = 50
5256
with torch.no_grad():
5357
for _ in range(iterations):
54-
LIFSubFunction.apply(x_seq, v_init, decay_scalar, threshold_scalar, alpha, 0)
58+
LIFSubFunction.apply(
59+
x_seq, v_init, decay_scalar, threshold_scalar, alpha, 0
60+
)
5561
torch.cuda.synchronize()
5662
lif_time = (time.time() - start_time) / iterations
5763
print(f"LIF Time: {lif_time * 1000:.3f} ms")
@@ -61,12 +67,15 @@ def benchmark_plif():
6167
start_time = time.time()
6268
with torch.no_grad():
6369
for _ in range(iterations):
64-
PLIFSubFunction.apply(x_seq, v_init, decay_vector, threshold_vector, alpha, 0)
70+
PLIFSubFunction.apply(
71+
x_seq, v_init, decay_vector, threshold_vector, alpha, 0
72+
)
6573
torch.cuda.synchronize()
6674
plif_time = (time.time() - start_time) / iterations
6775
print(f"PLIF Time: {plif_time * 1000:.3f} ms")
68-
76+
6977
print(f"Overhead: {plif_time / lif_time:.2f}x slower")
7078

79+
7180
if __name__ == "__main__":
7281
benchmark_plif()

docs/conf.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import sys
33
import re
4+
45
# Ensure Sphinx can find the source code
56
sys.path.insert(0, os.path.abspath("../src"))
67

@@ -37,22 +38,22 @@ def get_project_metadata():
3738

3839
meta = get_project_metadata()
3940

40-
project = 'Tether'
41-
copyright = '2025, Khushiyant'
41+
project = "Tether"
42+
copyright = "2025, Khushiyant"
4243
author = meta.get("author", "Khushiyant")
4344
release = meta.get("release", "0.1.0")
4445

4546
extensions = [
46-
'sphinx.ext.autodoc', # Core library for html generation from docstrings
47-
'sphinx.ext.autosummary', # Create neat summary tables
48-
'sphinx.ext.napoleon', # Support for NumPy/Google style docstrings
49-
'sphinx.ext.viewcode', # Add links to highlighted source code
50-
'myst_parser', # Support for Markdown files
51-
'sphinx_autodoc_typehints' # Show type hints in docs
47+
"sphinx.ext.autodoc", # Core library for html generation from docstrings
48+
"sphinx.ext.autosummary", # Create neat summary tables
49+
"sphinx.ext.napoleon", # Support for NumPy/Google style docstrings
50+
"sphinx.ext.viewcode", # Add links to highlighted source code
51+
"myst_parser", # Support for Markdown files
52+
"sphinx_autodoc_typehints", # Show type hints in docs
5253
]
5354

54-
templates_path = ['_templates']
55-
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
55+
templates_path = ["_templates"]
56+
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
5657

5758
html_static_path = ["_static"]
5859
html_theme = "furo"
@@ -65,5 +66,5 @@ def get_project_metadata():
6566
}
6667

6768
suppress_warnings = [
68-
'myst.xref_missing', # Suppress MyST cross-reference errors
69-
]
69+
"myst.xref_missing", # Suppress MyST cross-reference errors
70+
]

examples/train_cifar.py

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,26 @@
66
from torch.utils.data import DataLoader
77
from tether.nn import LIF
88
from tether.data import SpikingDatasetWrapper, rate_encoding
9-
from tether.utils.monitor import Monitor #
9+
from tether.utils.monitor import Monitor #
10+
1011

1112
class SpikingCIFARModel(nn.Module):
1213
def __init__(self, n_steps=10):
1314
super().__init__()
1415
self.n_steps = n_steps
15-
16+
1617
# Define layers
1718
self.conv_layers = nn.Sequential(
1819
nn.Conv2d(3, 32, kernel_size=3, padding=1),
1920
LIF(32 * 32 * 32),
2021
nn.MaxPool2d(2),
2122
nn.Conv2d(32, 64, kernel_size=3, padding=1),
2223
LIF(64 * 16 * 16),
23-
nn.MaxPool2d(2)
24+
nn.MaxPool2d(2),
2425
)
25-
26+
2627
self.fc_layers = nn.Sequential(
27-
nn.Flatten(),
28-
nn.Linear(64 * 8 * 8, 256),
29-
LIF(256),
30-
nn.Linear(256, 10)
28+
nn.Flatten(), nn.Linear(64 * 8 * 8, 256), LIF(256), nn.Linear(256, 10)
3129
)
3230

3331
def forward(self, x):
@@ -40,6 +38,7 @@ def forward(self, x):
4038
outputs.append(out)
4139
return torch.stack(outputs).mean(0)
4240

41+
4342
def evaluate(model, loader, device):
4443
model.eval()
4544
correct = 0
@@ -52,7 +51,8 @@ def evaluate(model, loader, device):
5251
pred = output.argmax(dim=1)
5352
correct += (pred == target).sum().item()
5453
total += target.size(0)
55-
return 100. * correct / total
54+
return 100.0 * correct / total
55+
5656

5757
def main():
5858
# --- Configuration ---
@@ -63,25 +63,36 @@ def main():
6363
lr = 1e-3
6464

6565
# --- Data Preparation ---
66-
transform = transforms.Compose([
67-
transforms.ToTensor(),
68-
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
69-
])
70-
71-
train_raw = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
72-
test_raw = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
73-
74-
train_ds = SpikingDatasetWrapper(train_raw, encode_fn=lambda x: rate_encoding(x, n_steps=n_steps))
75-
test_ds = SpikingDatasetWrapper(test_raw, encode_fn=lambda x: rate_encoding(x, n_steps=n_steps))
76-
77-
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
78-
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=4)
66+
transform = transforms.Compose(
67+
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
68+
)
69+
70+
train_raw = datasets.CIFAR10(
71+
root="./data", train=True, download=True, transform=transform
72+
)
73+
test_raw = datasets.CIFAR10(
74+
root="./data", train=False, download=True, transform=transform
75+
)
76+
77+
train_ds = SpikingDatasetWrapper(
78+
train_raw, encode_fn=lambda x: rate_encoding(x, n_steps=n_steps)
79+
)
80+
test_ds = SpikingDatasetWrapper(
81+
test_raw, encode_fn=lambda x: rate_encoding(x, n_steps=n_steps)
82+
)
83+
84+
train_loader = DataLoader(
85+
train_ds, batch_size=batch_size, shuffle=True, num_workers=4
86+
)
87+
test_loader = DataLoader(
88+
test_ds, batch_size=batch_size, shuffle=False, num_workers=4
89+
)
7990

8091
# --- Model & Monitoring ---
8192
model = SpikingCIFARModel(n_steps=n_steps).to(device)
8293
optimizer = optim.AdamW(model.parameters(), lr=lr)
8394
criterion = nn.CrossEntropyLoss()
84-
95+
8596
# Initialize the Tether Monitor
8697
monitor = Monitor(model)
8798

@@ -115,19 +126,26 @@ def main():
115126
# --- End of Epoch Monitoring ---
116127
epoch_time = time.time() - start_time
117128
avg_loss = running_loss / len(train_loader)
118-
train_acc = 100. * train_correct / train_total
129+
train_acc = 100.0 * train_correct / train_total
119130
val_acc = evaluate(model, test_loader, device)
120131

121132
# Retrieve firing rates from all LIF layers via Monitor
122133
firing_rates = monitor.get_firing_rates()
123134
# Calculate mean firing rate across the entire model
124-
mean_fr = sum(firing_rates.values()) / len(firing_rates) if firing_rates else 0.0
135+
mean_fr = (
136+
sum(firing_rates.values()) / len(firing_rates) if firing_rates else 0.0
137+
)
125138

126139
# Print detailed report once per epoch
127-
print(f"Epoch {epoch+1}/{epochs} | Time: {epoch_time:.1f}s")
128-
print(f" > Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")
129-
print(f" > Mean Firing Rate: {mean_fr:.4f} (Sparsity: {(1-mean_fr)*100:.1f}%)")
140+
print(f"Epoch {epoch + 1}/{epochs} | Time: {epoch_time:.1f}s")
141+
print(
142+
f" > Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%"
143+
)
144+
print(
145+
f" > Mean Firing Rate: {mean_fr:.4f} (Sparsity: {(1 - mean_fr) * 100:.1f}%)"
146+
)
130147
print("-" * 60)
131148

149+
132150
if __name__ == "__main__":
133-
main()
151+
main()

0 commit comments

Comments
 (0)