Skip to content

Commit 042f891

Browse files
committed
feat: implement ALIF and PLIF modules with Triton kernels, add surrogate gradient classes, and enhance benchmark functionality
1 parent daf68ab commit 042f891

File tree

12 files changed

+747
-4
lines changed

12 files changed

+747
-4
lines changed

benchmarks/benchmark_plif.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import torch
2+
import triton
3+
import time
4+
import sys
5+
import os
6+
7+
# Add src to path
8+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
9+
10+
from tether.functional.lif import LIFSubFunction
11+
from tether.functional.plif import PLIFSubFunction
12+
13+
def benchmark_plif():
    """
    Benchmark LIF vs PLIF (Triton vs Triton).

    Runs both fused sub-functions forward-only on identical random input
    of shape (seq_len, batch_size * dim) and reports the mean wall-clock
    time per iteration, plus the relative overhead of PLIF (per-neuron
    decay/threshold vectors) over LIF (scalar decay/threshold).

    Requires a CUDA device; prints a message and returns early otherwise.
    """
    if not torch.cuda.is_available():
        print("Skipping benchmark on CPU.")
        return

    device = torch.device("cuda")
    print(f"Benchmarking LIF vs PLIF on {torch.cuda.get_device_name(0)}")

    batch_size = 32
    seq_len = 2048
    dim = 768
    n_neurons = batch_size * dim

    x_seq = torch.randn(seq_len, n_neurons, device=device)
    v_init = torch.zeros(n_neurons, device=device)

    # LIF Params (Scalar)
    decay_scalar = torch.tensor(0.9, device=device)
    threshold_scalar = torch.tensor(1.0, device=device)
    alpha = torch.tensor(2.0, device=device)

    # PLIF Params (Vector) — one decay/threshold value per neuron.
    decay_vector = torch.full((n_neurons,), 0.9, device=device)
    threshold_vector = torch.full((n_neurons,), 1.0, device=device)

    # Warmup: triggers Triton JIT compilation and CUDA context init so
    # the timed loops below measure steady-state kernel time only.
    print("Warming up...")
    for _ in range(10):
        with torch.no_grad():
            LIFSubFunction.apply(x_seq, v_init, decay_scalar, threshold_scalar, alpha, 0)
            PLIFSubFunction.apply(x_seq, v_init, decay_vector, threshold_vector, alpha, 0)

    # Hoisted out of the timed region (was previously assigned after the
    # LIF timer started, polluting the first measurement).
    iterations = 50

    # Benchmark LIF
    torch.cuda.synchronize()
    # perf_counter() is monotonic and higher-resolution than time.time(),
    # which can jump under NTP adjustments.
    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(iterations):
            LIFSubFunction.apply(x_seq, v_init, decay_scalar, threshold_scalar, alpha, 0)
    torch.cuda.synchronize()  # flush async kernel launches before stopping the clock
    lif_time = (time.perf_counter() - start_time) / iterations
    print(f"LIF Time: {lif_time * 1000:.3f} ms")

    # Benchmark PLIF
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(iterations):
            PLIFSubFunction.apply(x_seq, v_init, decay_vector, threshold_vector, alpha, 0)
    torch.cuda.synchronize()
    plif_time = (time.perf_counter() - start_time) / iterations
    print(f"PLIF Time: {plif_time * 1000:.3f} ms")

    print(f"Overhead: {plif_time / lif_time:.2f}x slower")
70+
71+
# Run the benchmark when executed directly as a script.
if __name__ == "__main__":
    benchmark_plif()

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def get_project_metadata():
4040
project = 'Tether'
4141
copyright = '2025, Khushiyant'
4242
author = meta.get("author", "Khushiyant")
43-
release = meta.get("release", "0.6.0")
43+
release = meta.get("release", "0.1.0")
4444

4545
extensions = [
4646
'sphinx.ext.autodoc', # Core library for html generation from docstrings

src/tether/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
# Public package API: neuron models (LIF/ALIF/PLIF), spiking transformer
# building blocks, and surrogate-gradient classes re-exported at top level.
from .nn.lif import LIF
from .nn.alif import ALIF
from .nn.plif import PLIF
from .nn.attention import SpikingSelfAttention
from .nn.block import SpikingTransformerBlock
from .nn.surrogates import Surrogate, Arctan, Sigmoid, FastSigmoid

__all__ = [
    "LIF",
    "ALIF",
    "PLIF",
    "SpikingSelfAttention",
    "SpikingTransformerBlock",
    "Surrogate",
    "Arctan",
    "Sigmoid",
    "FastSigmoid"
]

src/tether/data/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Re-export the dataset wrapper and spike-encoding helpers as the
# public API of tether.data.
from .encoding import SpikingDatasetWrapper, rate_encoding, latency_encoding

__all__ = ["SpikingDatasetWrapper", "rate_encoding", "latency_encoding"]

src/tether/functional/alif.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import torch
2+
import triton
3+
from ..kernels.alif import alif_fwd_kernel, alif_bwd_kernel
4+
5+
class ALIFSubFunction(torch.autograd.Function):
    """Autograd bridge for the fused Triton ALIF (adaptive LIF) kernels.

    Both passes launch a single kernel that iterates over all timesteps
    internally; input layout is (n_steps, n_neurons).
    """

    @staticmethod
    def forward(ctx, x_seq, v_init, a_init, decay_v, decay_a, threshold, beta, alpha):
        """
        Forward pass of the ALIF function.

        Args:
            x_seq: input currents, shape (n_steps, n_neurons).
            v_init: initial membrane potential, shape (n_neurons,).
            a_init: initial adaptation state, shape (n_neurons,).
            decay_v, decay_a, threshold, beta: scalar tensors, passed to
                the kernel by value via .item().
            alpha: scalar tensor, saved for the surrogate gradient in
                backward (not used by the forward kernel).

        Returns:
            (out_spikes, v_final, a_final, v_seq, a_seq); v_seq and a_seq
            are marked non-differentiable.
        """
        x_seq, v_init, a_init = x_seq.contiguous(), v_init.contiguous(), a_init.contiguous()
        n_steps, n_neurons = x_seq.shape

        out_spikes = torch.empty_like(x_seq)

        # Spikes are additionally bit-packed, 32 timesteps per int32, so
        # backward can re-read the spike train from a compact buffer.
        n_int32 = (n_steps + 31) // 32
        out_spikes_packed = torch.empty((n_int32, n_neurons), device=x_seq.device, dtype=torch.int32)

        # Full state trajectories are kept for backward; the kernel is
        # expected to fill all of these output buffers in place.
        v_seq = torch.empty_like(x_seq)
        a_seq = torch.empty_like(x_seq)
        v_final = torch.empty_like(v_init)
        a_final = torch.empty_like(a_init)

        # One program per 1024-neuron block; a single launch covers the
        # whole sequence.
        grid = (triton.cdiv(n_neurons, 1024),)
        alif_fwd_kernel[grid](
            x_seq, v_init, a_init,
            out_spikes, out_spikes_packed, v_seq, v_final, a_seq, a_final,
            n_neurons, n_steps, decay_v.item(), decay_a.item(), threshold.item(), beta.item(),
            BLOCK_SIZE=1024
        )

        ctx.save_for_backward(out_spikes_packed, v_seq, a_seq, v_init, a_init, decay_v, decay_a, threshold, beta, alpha)
        # State trajectories are diagnostics only; no gradient flows
        # through them (backward receives None for their grads).
        ctx.mark_non_differentiable(v_seq)
        ctx.mark_non_differentiable(a_seq)
        return out_spikes, v_final, a_final, v_seq, a_seq

    @staticmethod
    def backward(ctx, grad_spikes, grad_v_final, grad_a_final, grad_v_seq, grad_a_seq):
        """
        Backward pass of the ALIF function.

        grad_v_seq / grad_a_seq correspond to the non-differentiable
        outputs and are ignored. Returns one gradient per forward input:
        (grad_x, grad_v_init, grad_a_init, grad_decay_v, grad_decay_a,
        grad_threshold, grad_beta, grad_alpha).
        """
        out_spikes_packed, v_seq, a_seq, v_init, a_init, decay_v, decay_a, threshold, beta, alpha = ctx.saved_tensors
        n_steps, n_neurons = v_seq.shape

        grad_x = torch.empty_like(v_seq)

        # Scalar-parameter gradients: 1-element buffers the kernel
        # accumulates into (presumably via atomics — behavior defined in
        # alif_bwd_kernel, not visible here).
        grad_decay_v = torch.zeros(1, device=grad_spikes.device, dtype=torch.float32)
        grad_decay_a = torch.zeros(1, device=grad_spikes.device, dtype=torch.float32)
        grad_threshold = torch.zeros(1, device=grad_spikes.device, dtype=torch.float32)
        grad_beta = torch.zeros(1, device=grad_spikes.device, dtype=torch.float32)
        grad_alpha = torch.zeros(1, device=grad_spikes.device, dtype=torch.float32)

        # Autograd passes None when v_final/a_final were unused downstream;
        # the kernel needs real (zero) tensors to read from.
        if grad_v_final is None:
            grad_v_final = torch.zeros_like(v_init)
        if grad_a_final is None:
            grad_a_final = torch.zeros_like(a_init)

        grid = (triton.cdiv(n_neurons, 1024),)

        # NOTE(review): forward passes decay_v/decay_a/threshold/beta via
        # .item(), but here the tensors themselves are passed — confirm
        # alif_bwd_kernel expects tensor/pointer arguments for these.
        alif_bwd_kernel[grid](
            grad_spikes.contiguous(), out_spikes_packed,
            v_seq.contiguous(), a_seq.contiguous(),
            grad_x,
            grad_v_final.contiguous(), grad_a_final.contiguous(),
            v_init.contiguous(), a_init.contiguous(),
            n_neurons, n_steps,
            decay_v, decay_a, threshold, beta, alpha,
            grad_decay_v, grad_decay_a, grad_threshold, grad_beta, grad_alpha,
            BLOCK_SIZE=1024
        )

        # NOTE(review): grad_v_final/grad_a_final are returned as the
        # gradients w.r.t. v_init/a_init — this assumes the kernel rewrites
        # those buffers in place with the t=0 state gradients; verify.
        return grad_x, grad_v_final, grad_a_final, grad_decay_v, grad_decay_a, grad_threshold, grad_beta, grad_alpha

src/tether/functional/plif.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import torch
2+
import triton
3+
from ..kernels.plif import plif_fwd_kernel, plif_bwd_kernel
4+
5+
class PLIFSubFunction(torch.autograd.Function):
    """Autograd bridge for the fused Triton PLIF (parametric LIF) kernels.

    Unlike the scalar LIF, decay and threshold are learnable per-neuron
    vectors. Input layout is (n_steps, n_neurons).
    """

    @staticmethod
    def forward(ctx, x_seq, v_init, decay, threshold, alpha, surrogate_type):
        """
        Forward pass of the PLIF function.
        Decay and Threshold are vectors (n_neurons,).

        Args:
            x_seq: input currents, shape (n_steps, n_neurons).
            v_init: initial membrane potential, shape (n_neurons,).
            decay, threshold: per-neuron parameter vectors (n_neurons,),
                passed to the kernel as pointers.
            alpha: scalar tensor for the surrogate gradient (backward only).
            surrogate_type: int selector for the surrogate function;
                non-tensor, so stashed on ctx rather than save_for_backward.

        Returns:
            (out_spikes, v_final, v_seq); v_seq is non-differentiable.
        """
        x_seq, v_init = x_seq.contiguous(), v_init.contiguous()
        decay, threshold = decay.contiguous(), threshold.contiguous()

        n_steps, n_neurons = x_seq.shape

        out_spikes = torch.empty_like(x_seq)
        # Spikes also bit-packed 32 timesteps per int32 for a compact
        # backward-pass replay buffer.
        n_int32 = (n_steps + 31) // 32
        out_spikes_packed = torch.empty((n_int32, n_neurons), device=x_seq.device, dtype=torch.int32)

        v_seq = torch.empty_like(x_seq)
        v_final = torch.empty_like(v_init)

        # One program per 1024-neuron block; single launch for all steps.
        grid = (triton.cdiv(n_neurons, 1024),)
        plif_fwd_kernel[grid](
            x_seq, v_init, out_spikes, out_spikes_packed, v_seq, v_final,
            n_neurons, n_steps, decay, threshold,
            BLOCK_SIZE=1024
        )

        ctx.save_for_backward(out_spikes_packed, v_seq, v_init, decay, threshold, alpha)
        ctx.surrogate_type = surrogate_type
        # v_seq is a diagnostic trajectory; no gradient flows through it.
        ctx.mark_non_differentiable(v_seq)
        return out_spikes, v_final, v_seq

    @staticmethod
    def backward(ctx, grad_spikes, grad_v_final, grad_v_seq):
        # grad_v_seq belongs to the non-differentiable output and is
        # ignored. Returns one gradient per forward input: (grad_x,
        # grad_v_init, grad_decay, grad_threshold, grad_alpha, None) —
        # the trailing None is for the int surrogate_type.
        out_spikes_packed, v_seq, v_init, decay, threshold, alpha = ctx.saved_tensors
        surrogate_type = ctx.surrogate_type
        n_steps, n_neurons = v_seq.shape

        grad_x = torch.empty_like(v_seq)

        # Gradients for parameters (Vectors)
        grad_decay = torch.zeros_like(decay)
        grad_threshold = torch.zeros_like(threshold)
        grad_alpha = torch.zeros(1, device=grad_spikes.device, dtype=torch.float32)

        # Autograd passes None when v_final was unused downstream; the
        # kernel needs a real (zero) tensor to read from.
        if grad_v_final is None:
            grad_v_final = torch.zeros_like(v_init)

        grid = (triton.cdiv(n_neurons, 1024),)

        plif_bwd_kernel[grid](
            grad_spikes.contiguous(), out_spikes_packed,
            v_seq.contiguous(), grad_x,
            grad_v_final.contiguous(), v_init.contiguous(),
            n_neurons, n_steps, decay, threshold, alpha,
            grad_decay, grad_threshold, grad_alpha,
            surrogate_type,
            BLOCK_SIZE=1024
        )

        # NOTE(review): grad_v_final is returned as the gradient w.r.t.
        # v_init — this assumes the kernel rewrites that buffer in place
        # with the t=0 membrane gradient (and that .contiguous() was a
        # no-op view, not a copy); verify against plif_bwd_kernel.
        return grad_x, grad_v_final, grad_decay, grad_threshold, grad_alpha, None

0 commit comments

Comments
 (0)