Skip to content

Commit 2f8ddf2

Browse files
Update phase1_benchmark.py
1 parent 52a925f commit 2f8ddf2

File tree

1 file changed

+136
-154
lines changed

1 file changed

+136
-154
lines changed
Lines changed: 136 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,30 @@
11
"""
2-
Adaptive Spiking Windows Implementation – Phase 1 Complete
3-
Includes:
4-
1. Vectorized masked attention via torch.einsum
5-
2. Warm‑up + fine‑tune epoch schedule
6-
3. Unit test for S=4, T=5
7-
4. Benchmark speed & memory
2+
Adaptive Spiking Windows Implementation
3+
Phase 1: Token-wise Temporal Allocation for Spiking Transformers
4+
+ vectorized masked attention (einsum)
5+
+ unit test for S=4, T=5
6+
+ speed/memory benchmarking
87
"""
98

10-
import time
11-
import itertools
129
import torch
1310
import torch.nn as nn
1411
import torch.nn.functional as F
15-
from torch.nn.utils.rnn import pad_sequence
12+
import time
13+
import numpy as np
14+
import matplotlib.pyplot as plt
15+
from typing import Tuple, Optional, Dict
1616

17-
# -----------------------------------------------------------------------------
18-
# LIF Neuron Definition
19-
# -----------------------------------------------------------------------------
2017
class LIFNeuron(nn.Module):
18+
"""Leaky Integrate-and-Fire neuron with learnable decay"""
2119
def __init__(self, tau_mem=20.0, tau_syn=5.0, v_threshold=1.0, v_reset=0.0):
2220
super().__init__()
23-
self.beta = nn.Parameter(torch.tensor(torch.exp(-1/tau_mem)))
24-
self.alpha = nn.Parameter(torch.tensor(torch.exp(-1/tau_syn)))
21+
self.beta = nn.Parameter(torch.tensor(np.exp(-1/tau_mem)))
22+
self.alpha = nn.Parameter(torch.tensor(np.exp(-1/tau_syn)))
2523
self.v_threshold = v_threshold
2624
self.v_reset = v_reset
2725

2826
def forward(self, x, state=None):
27+
# x: [B=1 or B, D], state: (v_mem, i_syn)
2928
if state is None:
3029
v_mem = torch.zeros_like(x)
3130
i_syn = torch.zeros_like(x)
@@ -38,169 +37,152 @@ def forward(self, x, state=None):
3837
return spikes, (v_mem, i_syn)
3938

4039

41-
# -----------------------------------------------------------------------------
42-
# Adaptive Spiking Attention – Vectorized
43-
# -----------------------------------------------------------------------------
4440
class AdaptiveSpikingAttention(nn.Module):
    """Multi-head spiking attention with per-token adaptive temporal windows.

    Each token s is assigned an integer window length Ti[b, s] in [1, T_max]
    by a learned sigmoid gate; spikes are generated only for the first Ti
    timesteps, and attention scores are computed over the masked spike trains
    with a single einsum (no per-pair Python loops).
    """

    def __init__(self, embedding_dim, num_heads=4, T_max=20, lambda_reg=1e-3, dropout=0.1):
        """
        Args:
            embedding_dim: model dimension D (must be divisible by num_heads).
            num_heads: number of attention heads H.
            T_max: maximum spiking timesteps per token.
            lambda_reg: weight of the mean-window regularization loss.
            dropout: dropout probability on the attention weights.
        """
        super().__init__()
        assert embedding_dim % num_heads == 0
        self.D = embedding_dim
        self.H = num_heads
        self.Dh = embedding_dim // num_heads
        self.T_max = T_max
        self.lambda_reg = lambda_reg
        # projections
        self.q_proj = nn.Linear(self.D, self.D, bias=False)
        self.k_proj = nn.Linear(self.D, self.D, bias=False)
        self.v_proj = nn.Linear(self.D, self.D, bias=False)
        self.out_proj = nn.Linear(self.D, self.D)
        # spiking neurons (one stateful LIF per projection stream)
        self.lif_q = LIFNeuron()
        self.lif_k = LIFNeuron()
        self.lif_v = LIFNeuron()
        # gating network: token embedding -> window fraction in (0, 1)
        self.window_gate = nn.Sequential(
            nn.Linear(self.D, 32), nn.ReLU(),
            nn.Linear(32, 1), nn.Sigmoid()
        )
        # dropout & scale
        self.dropout = nn.Dropout(dropout)
        self.scale = self.Dh ** -0.5
        # NOTE: grows unbounded across training; clear periodically if memory matters.
        self.T_history = []

    def get_windows(self, x):
        """Compute integer window lengths Ti for each token.

        Args:
            x: [B, S, D] token embeddings.
        Returns:
            Ti: [B, S] long tensor, each entry in [1, T_max].
        """
        gate = self.window_gate(x).squeeze(-1)  # [B, S], sigmoid output in (0, 1)
        Ti = (gate * self.T_max).ceil().clamp(1, self.T_max).long()
        return Ti

    def generate_spikes(self, x, Ti, lif):
        """Run `lif` for Ti[b, s] steps per token; pad the rest with zeros.

        Args:
            x: [B, S, D] projected inputs.
            Ti: [B, S] window lengths.
            lif: LIFNeuron instance to drive.
        Returns:
            [B, S, T_max, D] spike tensor (zeros beyond each token's window).
        """
        B, S, D = x.shape
        out = x.new_zeros(B, S, self.T_max, D)
        # Python loops are acceptable here: LIF state evolution is inherently
        # sequential over time; batching over (b, s) is a future optimization.
        for b in range(B):
            for s in range(S):
                state = None
                for t in range(Ti[b, s]):
                    spk, state = lif(x[b:b+1, s], state)
                    out[b, s, t] = spk
        return out  # [B, S, T_max, D]

    def vectorized_attention(self, q_spk, k_spk, v_spk, Ti):
        """Masked multi-head attention over spike trains via einsum.

        Args:
            q_spk, k_spk, v_spk: [B, S, T, D] spike tensors.
            Ti: [B, S] per-token valid window lengths.
        Returns:
            (out, W): out is [B, S, D] after out_proj; W is [B, H, S, S]
            attention weights (post-softmax, post-dropout).
        """
        B, S, T, D = q_spk.shape
        H, Dh = self.H, self.Dh
        q = q_spk.view(B, S, T, H, Dh)
        k = k_spk.view(B, S, T, H, Dh)
        v = v_spk.view(B, S, T, H, Dh)
        # time-validity mask: [B, S, T], 1.0 where t < Ti
        mask = (torch.arange(T, device=Ti.device)[None, None, :] < Ti[:, :, None]).float()
        mask4 = mask[:, :, :, None, None]  # [B, S, T, 1, 1]
        q = q * mask4
        k = k * mask4
        v = v * mask4
        # raw scores [B, H, S, S]: pair (i, j) effectively sums over
        # min(Ti[i], Ti[j]) timesteps because both sides are zero-masked.
        Sraw = torch.einsum('bithd,bjthd->bhij', q, k) * self.scale
        W = F.softmax(Sraw, dim=-1)
        W = self.dropout(W)
        # FIX: masked mean of v over each token's own window. The previous
        # `v.mean(dim=2)` divided by T_max, diluting tokens with short windows.
        denom = Ti.clamp(min=1).float()[:, :, None, None]  # [B, S, 1, 1]
        v_mean = (v.sum(dim=2) / denom).transpose(1, 2)    # [B, H, S, Dh]
        out = torch.einsum('bhij,bhjd->bhid', W, v_mean)   # [B, H, S, Dh]
        out = out.transpose(1, 2).reshape(B, S, D)
        return self.out_proj(out), W

    def forward(self, x):
        """Full pass: project -> gate windows -> spike -> masked attention.

        Args:
            x: [B, S, D] input embeddings.
        Returns:
            (out, info): out is [B, S, D]; info is a dict with
            'reg_loss' (scalar window penalty), 'Ti' ([B, S] windows),
            and 'W' ([B, H, S, S] attention weights).
        """
        B, S, D = x.shape
        # project
        q = self.q_proj(x); k = self.k_proj(x); v = self.v_proj(x)
        # windows
        Ti = self.get_windows(x)  # [B, S]
        # spikes
        q_spk = self.generate_spikes(q, Ti, self.lif_q)
        k_spk = self.generate_spikes(k, Ti, self.lif_k)
        v_spk = self.generate_spikes(v, Ti, self.lif_v)
        # attention
        out, W = self.vectorized_attention(q_spk, k_spk, v_spk, Ti)
        # regularization: penalize mean window length to encourage sparsity
        reg = self.lambda_reg * Ti.float().mean()
        # log window choices during training only
        if self.training:
            self.T_history.append(Ti.cpu().numpy())
        return out, {'reg_loss': reg, 'Ti': Ti, 'W': W}
128+
129+
# --- Unit Test & Benchmark --------------------------------------------------
130+
131+
def brute_force(q, k, Ti):
    """Reference O(S^2 * H) loop implementation of the masked temporal scores.

    Args:
        q, k: [B, S, T, H, Dh] spike tensors.
        Ti: [B, S] integer window lengths.
    Returns:
        [B, H, S, S] raw (unscaled) scores; entry (b, h, i, j) is the dot
        product of q and k summed over the first min(Ti[b,i], Ti[b,j]) steps.
    """
    B, S, T, H, Dh = q.shape
    scores = torch.zeros(B, H, S, S)
    for b in range(B):
        for h in range(H):
            for i in range(S):
                for j in range(S):
                    # both windows must cover a step for it to contribute
                    t_lim = min(int(Ti[b, i]), int(Ti[b, j]))
                    # slice-sum replaces the original per-timestep loop
                    scores[b, h, i, j] = (q[b, i, :t_lim, h] * k[b, j, :t_lim, h]).sum()
    return scores
144+
145+
if __name__ == "__main__":
    # ---- Unit test: brute-force vs vectorized masked scores (S=4, T=5) ----
    B, S, T, H, Dh = 1, 4, 5, 2, 3
    q = torch.randn(B, S, T, H, Dh)
    k = torch.randn_like(q)
    Ti = torch.randint(1, T + 1, (B, S))

    bf = brute_force(q, k, Ti)

    def raw_vec(q_spk, k_spk, Ti):
        """Raw (pre-softmax, unscaled) masked scores — mirrors the einsum
        inside AdaptiveSpikingAttention.vectorized_attention."""
        B, S, T, H, Dh = q_spk.shape
        mask = (torch.arange(T)[None, None, :] < Ti[:, :, None]).float()
        q_ = q_spk * mask[:, :, :, None, None]
        k_ = k_spk * mask[:, :, :, None, None]
        return torch.einsum('bithd,bjthd->bhij', q_, k_)

    # FIX: removed dead code — the previous version also built a model and
    # called model.vectorized_attention into an unused `vec` variable; the
    # assertion only ever compared against raw_vec.
    rv = raw_vec(q, k, Ti)
    assert torch.allclose(bf, rv, atol=1e-5)
    print("✅ Unit test passed (S=4, T=5)")

    # ---- Benchmark: loop vs einsum on identical inputs ----
    reps = 100
    start = time.perf_counter()
    for _ in range(reps):
        brute_force(q, k, Ti)
    t1 = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(reps):
        _ = raw_vec(q, k, Ti)
    t2 = time.perf_counter() - start

    print(f"Brute force: {t1:.4f}s for {reps} runs")
    print(f"Vectorized (raw): {t2:.4f}s for {reps} runs")

0 commit comments

Comments
 (0)