
Commit b0cd66e

Create transformer_blocks.py
1 parent c8a0084 commit b0cd66e

File tree

1 file changed (+36 −0 lines)

src/models/transformer_blocks.py

import torch
import torch.nn as nn

from .adaptive_attention import AdaptiveSpikingAttention  # sibling module in src/models/


class SpikingTransformerBlock(nn.Module):
    """Pre-norm transformer block with adaptive spiking attention in place of standard attention."""

    def __init__(self, embedding_dim, num_heads, T_max=20, lambda_reg=1e-3):
        super().__init__()

        # 🔥 Replace standard attention with adaptive spiking attention
        self.attention = AdaptiveSpikingAttention(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            T_max=T_max,
            lambda_reg=lambda_reg,
        )

        # Pre-norm LayerNorms for the attention and feed-forward sub-layers
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

        # Position-wise feed-forward network with the usual 4x hidden expansion
        self.ffn = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.GELU(),
            nn.Linear(4 * embedding_dim, embedding_dim),
            nn.Dropout(0.1),
        )

    def forward(self, x, mask=None):
        # Adaptive spiking attention with a pre-norm residual connection
        attn_out, metrics = self.attention(self.norm1(x), mask)
        x = x + attn_out

        # Feed-forward network with a pre-norm residual connection
        ffn_out = self.ffn(self.norm2(x))
        x = x + ffn_out

        # Surface the attention metrics so callers can log or regularize them
        return x, metrics
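
For context, a minimal usage sketch (not part of this commit). It assumes the package is importable as src.models, that AdaptiveSpikingAttention exists elsewhere in the repo and accepts (x, mask) while returning an (output, metrics) pair, as this block's forward implies, and that inputs are laid out as (batch, seq_len, embedding_dim):

import torch
from src.models.transformer_blocks import SpikingTransformerBlock

block = SpikingTransformerBlock(embedding_dim=256, num_heads=8, T_max=20, lambda_reg=1e-3)

x = torch.randn(4, 128, 256)  # (batch, seq_len, embedding_dim), assumed layout
out, metrics = block(x, mask=None)

print(out.shape)  # torch.Size([4, 128, 256]), residual connections preserve the input shape
# `metrics` is whatever AdaptiveSpikingAttention reports (e.g. spiking statistics);
# its exact contents are not visible in this diff.

Because both sub-layers are wrapped in pre-norm residual connections, the output shape always matches the input shape, so blocks of this kind can be stacked directly.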
