"""
Adaptive Spiking Windows Implementation
Phase 1: Token-wise Temporal Allocation for Spiking Transformers
+ vectorized masked attention (einsum)
+ unit test for S=4, T=5
+ speed/memory benchmarking
"""

import time
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Plotting is optional — keep a missing matplotlib from breaking import of
# the model/test code below.
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None
class LIFNeuron(nn.Module):
    """Leaky Integrate-and-Fire neuron with learnable decay."""

    def __init__(self, tau_mem=20.0, tau_syn=5.0, v_threshold=1.0, v_reset=0.0):
        """tau_mem / tau_syn: membrane and synaptic time constants (in steps);
        v_threshold / v_reset: firing threshold and post-spike reset value."""
        super().__init__()
        # Decay factors exp(-1/tau) stored as learnable parameters.
        # np.exp (not torch.exp) because the argument is a plain Python float.
        self.beta = nn.Parameter(torch.tensor(np.exp(-1.0 / tau_mem)))
        self.alpha = nn.Parameter(torch.tensor(np.exp(-1.0 / tau_syn)))
        self.v_threshold = v_threshold
        self.v_reset = v_reset
2725
2826 def forward (self , x , state = None ):
27+ # x: [B=1 or B, D], state: (v_mem, i_syn)
2928 if state is None :
3029 v_mem = torch .zeros_like (x )
3130 i_syn = torch .zeros_like (x )
@@ -38,169 +37,152 @@ def forward(self, x, state=None):
3837 return spikes , (v_mem , i_syn )
class AdaptiveSpikingAttention(nn.Module):
    """Multi-head spiking attention with per-token adaptive time windows.

    A learned gate assigns each token its own number of spiking time steps
    T_i in [1, T_max]; attention scores are then computed over the spike
    trains with a single vectorized, masked einsum instead of per-pair loops.
    """

    def __init__(self, embedding_dim, num_heads=4, T_max=20, lambda_reg=1e-3, dropout=0.1):
        super().__init__()
        assert embedding_dim % num_heads == 0
        self.D = embedding_dim                 # model width
        self.H = num_heads                     # number of attention heads
        self.Dh = embedding_dim // num_heads   # per-head width
        self.T_max = T_max                     # upper bound on steps per token
        self.lambda_reg = lambda_reg           # weight of the window-size penalty
        # Q/K/V projections (bias-free, as in standard attention).
        self.q_proj = nn.Linear(self.D, self.D, bias=False)
        self.k_proj = nn.Linear(self.D, self.D, bias=False)
        self.v_proj = nn.Linear(self.D, self.D, bias=False)
        self.out_proj = nn.Linear(self.D, self.D)
        # Independent LIF neurons for each projection stream.
        self.lif_q = LIFNeuron()
        self.lif_k = LIFNeuron()
        self.lif_v = LIFNeuron()
        # Gate in (0, 1) that scales T_max into a per-token window length.
        self.window_gate = nn.Sequential(
            nn.Linear(self.D, 32), nn.ReLU(),
            nn.Linear(32, 1), nn.Sigmoid(),
        )
        self.dropout = nn.Dropout(dropout)
        self.scale = self.Dh ** -0.5
        self.T_history = []  # window sizes recorded per forward pass (training only)

    def get_windows(self, x):
        """Per-token window lengths.

        x: [B, S, D] -> long tensor [B, S] with values in [1, T_max].
        """
        gate = self.window_gate(x).squeeze(-1)  # [B, S], each in (0, 1)
        return (gate * self.T_max).ceil().clamp(1, self.T_max).long()

    def generate_spikes(self, x, Ti, lif):
        """Roll out `lif` for Ti[b, s] steps on each token.

        x: [B, S, D] -> spikes [B, S, T_max, D]; steps beyond Ti stay zero.
        """
        B, S, D = x.shape
        out = x.new_zeros(B, S, self.T_max, D)
        # NOTE(review): Python-level loop over batch/tokens/steps — fine for
        # small S, but a batched LIF rollout would be needed at scale.
        for b in range(B):
            for s in range(S):
                state = None
                for t in range(int(Ti[b, s])):
                    spk, state = lif(x[b:b + 1, s], state)
                    out[b, s, t] = spk
        return out

    def vectorized_attention(self, q_spk, k_spk, v_spk, Ti):
        """Masked multi-head attention over spike trains.

        q_spk / k_spk / v_spk: [B, S, T, D]; Ti: [B, S] valid-step counts.
        Returns (output [B, S, D], attention weights [B, H, S, S]).
        """
        B, S, T, D = q_spk.shape
        H, Dh = self.H, self.Dh
        q = q_spk.view(B, S, T, H, Dh)
        k = k_spk.view(B, S, T, H, Dh)
        v = v_spk.view(B, S, T, H, Dh)
        # mask[b, s, t] = 1 while t < Ti[b, s]; zeroes out padded steps.
        mask = (torch.arange(T, device=Ti.device)[None, None, :] < Ti[:, :, None]).float()
        mask4 = mask[:, :, :, None, None]  # broadcast over heads and channels
        q = q * mask4
        k = k * mask4
        v = v * mask4
        # Pairwise scores, summed over time and head-dim: [B, H, S, S].
        Sraw = torch.einsum('bithd,bjthd->bhij', q, k) * self.scale
        W = self.dropout(F.softmax(Sraw, dim=-1))
        # Values: mean over the full T axis (masked steps contribute zeros, so
        # tokens with short windows are implicitly down-weighted), then attend.
        v_mean = v.mean(dim=2).transpose(1, 2)            # [B, H, S, Dh]
        out = torch.einsum('bhij,bhjd->bhid', W, v_mean)  # [B, H, S, Dh]
        out = out.transpose(1, 2).reshape(B, S, D)
        return self.out_proj(out), W

    def forward(self, x):
        """x: [B, S, D] -> (output [B, S, D], dict with reg_loss / Ti / W)."""
        B, S, D = x.shape
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        Ti = self.get_windows(x)  # [B, S]
        q_spk = self.generate_spikes(q, Ti, self.lif_q)
        k_spk = self.generate_spikes(k, Ti, self.lif_k)
        v_spk = self.generate_spikes(v, Ti, self.lif_v)
        out, W = self.vectorized_attention(q_spk, k_spk, v_spk, Ti)
        # Penalize long windows so the gate learns to spend steps sparingly.
        reg = self.lambda_reg * Ti.float().mean()
        if self.training:
            self.T_history.append(Ti.cpu().numpy())
        return out, {'reg_loss': reg, 'Ti': Ti, 'W': W}

# --- Unit Test & Benchmark --------------------------------------------------

def brute_force(q, k, Ti):
    """Reference O(S^2 * T) computation of masked attention scores.

    q, k: [B, S, T, H, Dh] spike tensors; Ti: [B, S] valid-step counts.
    Returns raw (unscaled) scores [B, H, S, S] where pair (i, j) sums q.k
    over the first min(Ti[i], Ti[j]) steps — equivalent to the vectorized
    masked einsum, since beyond that limit one factor is masked to zero.
    """
    B, S, T, H, Dh = q.shape
    S1 = torch.zeros(B, H, S, S)
    for b in range(B):
        for h in range(H):
            for i in range(S):
                for j in range(S):
                    tlim = min(Ti[b, i], Ti[b, j]).item()
                    val = 0.0
                    for t in range(tlim):
                        val += (q[b, i, t, h] * k[b, j, t, h]).sum()
                    S1[b, h, i, j] = val
    return S1

if __name__ == "__main__":
    # --- Unit test: masked einsum vs. brute force (S=4, T=5) ---
    torch.manual_seed(0)  # reproducible test data
    B, S, T, H, Dh = 1, 4, 5, 2, 3
    D = H * Dh

    q = torch.randn(B, S, T, H, Dh)
    k = torch.randn_like(q)
    Ti = torch.randint(1, T + 1, (B, S))

    def raw_vec(q_spk, k_spk, windows):
        """Raw masked scores [B, H, S, S] — same math as
        vectorized_attention before scaling/softmax."""
        steps = q_spk.shape[2]
        mask = (torch.arange(steps)[None, None, :] < windows[:, :, None]).float()
        qm = q_spk * mask[:, :, :, None, None]
        km = k_spk * mask[:, :, :, None, None]
        return torch.einsum('bithd,bjthd->bhij', qm, km)

    bf = brute_force(q, k, Ti)
    rv = raw_vec(q, k, Ti)
    assert torch.allclose(bf, rv, atol=1e-5)
    print("✅ Unit test passed (S=4, T=5)")

    # --- Smoke test: full module forward (replaces the previous dead
    # vectorized_attention call whose result was discarded) ---
    model = AdaptiveSpikingAttention(D, num_heads=H, T_max=T)
    x = torch.randn(B, S, D)
    out, aux = model(x)
    assert out.shape == (B, S, D)

    # --- Benchmark: brute force vs. vectorized score computation ---
    reps = 100
    start = time.perf_counter()
    for _ in range(reps):
        brute_force(q, k, Ti)
    t1 = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(reps):
        raw_vec(q, k, Ti)
    t2 = time.perf_counter() - start

    print(f"Brute force: {t1:.4f}s for {reps} runs")
    print(f"Vectorized (raw): {t2:.4f}s for {reps} runs")