
Commit daf68ab

chore: enhance documentation, add new encoding utilities, and improve LIF functionality with surrogate gradients

1 parent 2ba9444 · commit daf68ab
File tree

11 files changed: +573 −34 lines changed

README.md

Lines changed: 46 additions & 17 deletions

@@ -2,16 +2,20 @@
 <img width="1000" height="600" alt="Gemini_Generated_Image_xofloxxofloxxofl" src="https://github.com/user-attachments/assets/420a486f-1a09-4d72-a98b-22678abd0e75" />

-**Tether** is a Triton-powered framework for training and deploying **Spiking Transformers**.
+**Tether** is a Triton-powered framework for training and deploying **Spiking Transformers** and deep Spiking Neural Networks (SNNs).

-We’ve solved the non-differentiability of discrete spikes by implementing a custom **Arctan Surrogate Gradient** in the autograd backward pass.
+We’ve solved the non-differentiability of discrete spikes by implementing high-performance Triton kernels with modular **Surrogate Gradients**.

 ## Key Features

-- **Fused LIF Kernel**: Manages membrane potential statefulness across temporal windows without global memory stalls, utilizing Triton for high-performance GPU execution.
-- **Linear Spike-Driven Attention**: Eliminates the $O(N^2)$ Softmax bottleneck, allowing for massive context windows with significantly lower energy per inference (Joules/op).
-- **Bit-Packing** (In Progress): Optimization for memory-efficient spike storage.
-- **Triton-Powered**: Leverages OpenAI's Triton language for custom CUDA kernels.
+- **High-Performance Neurons**:
+  - **LIF (Leaky Integrate-and-Fire)**: Standard spiking neuron with fused Triton kernels.
+  - **ALIF (Adaptive LIF)**: Neurons with adaptive thresholds for better temporal dynamics.
+  - **PLIF (Parametric LIF)**: Neurons with learnable, per-channel decay and threshold parameters.
+- **Modular Surrogate Gradients**: Choose from `Arctan`, `Sigmoid`, or `FastSigmoid` to train your SNNs effectively.
+- **Linear Spike-Driven Attention**: Eliminates the $O(N^2)$ Softmax bottleneck, allowing for massive context windows with significantly lower energy per inference.
+- **Data Utilities**: `SpikingDatasetWrapper` and encoding functions (`rate_encoding`, `latency_encoding`) to convert static datasets to spike trains.
+- **Triton-Powered**: Leverages OpenAI's Triton language for custom CUDA kernels, enabling massive speedups (60x+) over vanilla PyTorch.

 ## Installation

@@ -29,6 +33,25 @@ pip install torch triton numpy

 ## Usage

+### Using PLIF with Sigmoid Surrogate
+
+```python
+import torch
+from tether import PLIF, Sigmoid
+
+# Create a Parametric LIF layer with Sigmoid surrogate
+# Decay and threshold are learnable vectors per neuron
+layer = PLIF(
+    n_neurons=128,
+    init_decay=0.9,
+    surrogate=Sigmoid(alpha=4.0)
+).cuda()
+
+# Input sequence: (Time, Batch, Neurons)
+x = torch.randn(32, 16, 128).cuda()
+spikes = layer(x)
+```
+
 ### Training a Spiking Language Model

 The `train_stories.py` script demonstrates training a **Spiking-LLM** on the TinyShakespeare dataset.

@@ -37,20 +60,26 @@ The `train_stories.py` script demonstrates training a **Spiking-LLM** on the TinyShakespeare dataset.
 python train_stories.py
 ```

-This will:
-1. Download the `input.txt` dataset.
-2. Initialize a Tether Spiking Transformer (4 layers, 8 heads).
-3. Train using the custom Arctan Surrogate Gradient.
-4. Generate sample text from the Spiking SNN.
+### Data Encoding
+
+```python
+from tether.data import SpikingDatasetWrapper, rate_encoding
+from torchvision.datasets import MNIST
+
+# Wrap MNIST to output spike trains
+spiking_mnist = SpikingDatasetWrapper(
+    MNIST(root="./data", download=True, train=True),
+    encode_fn=lambda x: rate_encoding(x, n_steps=10)
+)
+```

 ## Architecture

-- **`tether.kernels.lif`**: Custom Triton kernels for Leaky Integrate-and-Fire (LIF) forward and backward passes.
-- **`tether.functional.lif`**: PyTorch autograd function wrapping the Triton kernels.
-- **`tether.nn.attention`**: Linear Spike-Driven Attention mechanism.
-- **`tether.nn.block`**: Spiking Transformer Block implementation.
+- **`tether.kernels`**: Custom Triton kernels for LIF, ALIF, and PLIF.
+- **`tether.functional`**: PyTorch autograd functions wrapping the Triton kernels.
+- **`tether.nn`**: Neural network modules including `LIF`, `ALIF`, `PLIF`, `SpikingSelfAttention`.
+- **`tether.data`**: Utilities for spike encoding and dataset wrapping.

 ## License

-[Apache-2.0](https://github.com/Khushiyant/tether/blob/main/LICENSE)
+[Apache-2.0](https://github.com/Khushiyant/tether/blob/main/LICENSE)
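The `Arctan`, `Sigmoid`, and `FastSigmoid` options named in the README correspond to standard surrogate-gradient shapes from the SNN literature. The exact parameterizations tether uses are not shown in this diff, so the formulas below are assumptions based on common conventions — a framework-agnostic sketch, not the library's API:

```python
import numpy as np

def arctan_grad(u, alpha=2.0):
    # Common ATan surrogate derivative: alpha/2 / (1 + (pi/2 * alpha * u)^2)
    return alpha / 2.0 / (1.0 + (np.pi / 2.0 * alpha * u) ** 2)

def sigmoid_grad(u, alpha=4.0):
    # Derivative of sigmoid(alpha * u): alpha * s * (1 - s)
    s = 1.0 / (1.0 + np.exp(-alpha * u))
    return alpha * s * (1.0 - s)

def fast_sigmoid_grad(u, alpha=1.0):
    # SuperSpike-style fast sigmoid derivative: 1 / (1 + alpha*|u|)^2
    return 1.0 / (1.0 + alpha * np.abs(u)) ** 2

# All three peak at the threshold crossing (u = v - threshold = 0) and
# decay smoothly away from it, replacing the undefined Heaviside derivative.
u = np.linspace(-2.0, 2.0, 5)
print(arctan_grad(u))
```

Each function maps the distance from threshold to a smooth pseudo-derivative, which is what makes backpropagation through the spike nonlinearity possible.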

benchmarks/benchmark_lif.py

Lines changed: 2 additions & 2 deletions

@@ -80,7 +80,7 @@ def benchmark():
     for _ in range(10):
         with torch.no_grad():
             _ = lif_pytorch(x_seq, v_init, decay, threshold)
-        LIFSubFunction.apply(x_seq, v_init, decay, threshold, alpha)
+        LIFSubFunction.apply(x_seq, v_init, decay, threshold, alpha, 0)

     # Benchmark PyTorch
     torch.cuda.synchronize()

@@ -99,7 +99,7 @@ def benchmark():
     with torch.no_grad():
         for _ in range(iterations):
             # Note: We use apply but inside no_grad, so it just runs forward
-            LIFSubFunction.apply(x_seq, v_init, decay, threshold, alpha)
+            LIFSubFunction.apply(x_seq, v_init, decay, threshold, alpha, 0)
     torch.cuda.synchronize()
     triton_time = (time.time() - start_time) / iterations
     print(f"Triton Time: {triton_time * 1000:.3f} ms")
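The benchmark follows the usual GPU timing discipline: warm up first (to absorb kernel compilation and autotuning), synchronize, then average over many iterations. A device-agnostic sketch of that pattern, timing a toy CPU workload in place of the Triton kernel (the helper name is illustrative, not part of tether):

```python
import time

def benchmark_fn(fn, *args, warmup=10, iterations=100):
    """Time fn(*args): discard warmup runs, then average over iterations.
    On GPU you would call torch.cuda.synchronize() before reading the
    clock, as benchmark_lif.py does; this sketch times a CPU callable."""
    for _ in range(warmup):  # absorb one-time costs (compilation, caches)
        fn(*args)
    start = time.perf_counter()
    for _ in range(iterations):
        fn(*args)
    return (time.perf_counter() - start) / iterations

# Example: average time of a toy workload
avg = benchmark_fn(lambda: sum(i * i for i in range(1000)))
print(f"avg: {avg * 1000:.3f} ms")
```

Without the synchronize calls, GPU timings would measure only kernel launch latency, since CUDA kernels run asynchronously.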

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ requires = ["hatchling"]
 build-backend = "hatchling.build"

 [project.optional-dependencies]
-dev = ["pytest>=9.0.2"]
+dev = ["pytest>=9.0.2", "pytest-cov>=7.0.0"]
 docs = [
     "sphinx>=8.1.3",
     "sphinx-autodoc-typehints>=3.0.1",

src/tether/data/encoding.py

Lines changed: 86 additions & 0 deletions (new file)

```python
import torch
from torch.utils.data import Dataset


class SpikingDatasetWrapper(Dataset):
    """
    Wraps a standard dataset and applies an encoding function to the input.
    """

    def __init__(self, dataset: Dataset, encode_fn):
        self.dataset = dataset
        self.encode_fn = encode_fn

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        return self.encode_fn(x), y


def rate_encoding(x: torch.Tensor, n_steps: int, gain: float = 1.0) -> torch.Tensor:
    """
    Convert continuous values to spike trains using rate encoding (Bernoulli).

    Parameters
    ----------
    x : torch.Tensor
        Input tensor with continuous values (usually in [0, 1]).
    n_steps : int
        Number of time steps to simulate.
    gain : float
        Scaling factor for firing probability.

    Returns
    -------
    torch.Tensor
        Spike tensor with shape (n_steps, *x.shape).
    """
    shape = (n_steps,) + x.shape
    prob = torch.clamp(x * gain, 0.0, 1.0)
    # Expand the firing probability along the time dimension
    prob = prob.unsqueeze(0).expand(shape)

    # Sample Bernoulli spikes
    spikes = torch.rand(shape, device=x.device) < prob
    return spikes.float()


def latency_encoding(x: torch.Tensor, n_steps: int, tau: float = 1.0, threshold: float = 0.01) -> torch.Tensor:
    """
    Convert continuous values to spike trains using latency encoding.
    Higher values fire earlier.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    n_steps : int
        Number of time steps.
    tau : float
        Time constant (currently unused; the mapping below is linear).
    threshold : float
        Threshold below which no spike is generated.

    Returns
    -------
    torch.Tensor
        Spike tensor with shape (n_steps, *x.shape).
    """
    # Linear latency mapping: 1.0 fires at step 0, 0.0 at step n_steps - 1
    x = torch.clamp(x, 0.0, 1.0)
    fire_step = ((1.0 - x) * (n_steps - 1)).long()

    # Grid of time steps, broadcastable against fire_step
    time_grid = torch.arange(n_steps, device=x.device).reshape((n_steps,) + (1,) * x.ndim)

    # Spike where the time step matches fire_step and x exceeds threshold
    active = x > threshold
    spikes = (time_grid == fire_step) & active.unsqueeze(0)

    return spikes.float()
```
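The two encoders have simple invariants worth sanity-checking: with the probability clamped at 1, rate encoding fires a value of 1.0 on every step, while linear latency encoding fires exactly once per active input, earlier for larger values, and never for inputs at or below the threshold. A numpy re-sketch of the same logic (not the tether API, just the arithmetic) makes this concrete:

```python
import numpy as np

def rate_encoding_np(x, n_steps, gain=1.0, rng=None):
    # Bernoulli spikes: firing probability = clamp(x * gain, 0, 1) per step
    rng = rng or np.random.default_rng(0)
    prob = np.clip(x * gain, 0.0, 1.0)
    return (rng.random((n_steps,) + x.shape) < prob).astype(np.float32)

def latency_encoding_np(x, n_steps, threshold=0.01):
    # One spike per active input, at step (1 - x) * (n_steps - 1)
    x = np.clip(x, 0.0, 1.0)
    fire_step = ((1.0 - x) * (n_steps - 1)).astype(np.int64)
    time_grid = np.arange(n_steps).reshape((n_steps,) + (1,) * x.ndim)
    return ((time_grid == fire_step) & (x > threshold)).astype(np.float32)

x = np.array([1.0, 0.5, 0.0])
print(rate_encoding_np(x, 4))     # first column all ones, last column all zeros
print(latency_encoding_np(x, 5))  # spikes at steps 0 and 2; 0.0 never fires
```

Rate encoding preserves intensity as average firing frequency at the cost of stochasticity; latency encoding is deterministic and sparser (one spike per input), but discards information below the threshold.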

src/tether/functional/lif.py

Lines changed: 13 additions & 5 deletions

@@ -4,7 +4,7 @@

 class LIFSubFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x_seq, v_init, decay, threshold, alpha):
+    def forward(ctx, x_seq, v_init, decay, threshold, alpha, surrogate_type):
         """
         Forward pass of the LIF function.

@@ -22,6 +22,8 @@ def forward(ctx, x_seq, v_init, decay, threshold, alpha):
             Spiking threshold.
         alpha : torch.Tensor
             Surrogate gradient parameter.
+        surrogate_type : int
+            Type of surrogate gradient.

         Returns
         -------

@@ -49,10 +51,12 @@ def forward(ctx, x_seq, v_init, decay, threshold, alpha):

         # Save packed spikes for backward to save memory
         ctx.save_for_backward(out_spikes_packed, v_seq, v_init, decay, threshold, alpha)
-        return out_spikes, v_final
+        ctx.surrogate_type = surrogate_type
+        ctx.mark_non_differentiable(v_seq)
+        return out_spikes, v_final, v_seq

     @staticmethod
-    def backward(ctx, grad_spikes, grad_v_final):
+    def backward(ctx, grad_spikes, grad_v_final, grad_v_seq):
         """
         Backward pass of the LIF function.

@@ -64,13 +68,16 @@ def backward(ctx, grad_spikes, grad_v_final):
             Gradients w.r.t. spikes.
         grad_v_final : torch.Tensor
             Gradients w.r.t. final membrane potentials.
+        grad_v_seq : torch.Tensor
+            Gradients w.r.t. voltage sequence.

         Returns
         -------
         tuple
             Gradients w.r.t. inputs and parameters.
         """
         out_spikes_packed, v_seq, v_init, decay, threshold, alpha = ctx.saved_tensors
+        surrogate_type = ctx.surrogate_type
         n_steps, n_neurons = v_seq.shape

         grad_x = torch.empty_like(v_seq)

@@ -91,8 +98,9 @@ def backward(ctx, grad_spikes, grad_v_final):
             grad_v_final.contiguous(), v_init.contiguous(),
             n_neurons, n_steps, decay, threshold, alpha,
             grad_decay, grad_threshold, grad_alpha,
+            surrogate_type,
             BLOCK_SIZE=1024
         )

-        # Returns grads for (x_seq, v_init, decay, threshold, alpha)
-        return grad_x, grad_v_final, grad_decay, grad_threshold, grad_alpha
+        # Returns grads for (x_seq, v_init, decay, threshold, alpha, surrogate_type)
+        return grad_x, grad_v_final, grad_decay, grad_threshold, grad_alpha, None
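The signature change follows the standard `torch.autograd.Function` pattern: non-tensor arguments like `surrogate_type` are stashed as `ctx` attributes rather than via `save_for_backward`, and `backward` returns a trailing `None` for them. A minimal single-step sketch of that pattern — a plain Heaviside spike with assumed surrogate formulas, not tether's fused Triton kernel:

```python
import math
import torch

class SpikeFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, v, threshold, alpha, surrogate_type):
        ctx.save_for_backward(v)             # tensors go through save_for_backward
        ctx.threshold = threshold            # plain Python values live as
        ctx.alpha = alpha                    # ctx attributes, like
        ctx.surrogate_type = surrogate_type  # ctx.surrogate_type above
        return (v >= threshold).float()      # hard spike in the forward pass

    @staticmethod
    def backward(ctx, grad_out):
        (v,) = ctx.saved_tensors
        u = v - ctx.threshold
        if ctx.surrogate_type == 0:
            # arctan-style surrogate derivative (one common parameterization)
            sg = ctx.alpha / 2 / (1 + (math.pi / 2 * ctx.alpha * u) ** 2)
        else:
            # sigmoid surrogate derivative
            s = torch.sigmoid(ctx.alpha * u)
            sg = ctx.alpha * s * (1 - s)
        # One gradient slot per forward input; None for the non-tensor args,
        # mirroring the trailing None returned by LIFSubFunction.backward.
        return grad_out * sg, None, None, None

v = torch.tensor([0.5, 1.5], requires_grad=True)
spikes = SpikeFn.apply(v, 1.0, 2.0, 0)
spikes.sum().backward()
print(spikes)  # tensor([0., 1.])
print(v.grad)  # smooth surrogate gradient, peaked near the threshold
```

The `None` slots are required: PyTorch checks that `backward` returns exactly one value per `forward` argument, which is why the benchmark's extra `0` argument forced the matching `None` in this commit.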
