-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: multihead_self_attention.py
More file actions
85 lines (66 loc) · 2.71 KB
/
multihead_self_attention.py
File metadata and controls
85 lines (66 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    The model dimension is split evenly across ``n_heads`` heads. A
    lower-triangular (causal) mask restricts each position t to attend
    only to positions <= t, so the layer is suitable for autoregressive
    decoding. The concatenated head outputs are projected back to
    ``d_model``.
    """

    def __init__(self, d_model: int, n_heads: int, block_size: int, dropout: float = 0.1):
        """
        Args:
            d_model: Embedding dimension; must be divisible by ``n_heads``.
            n_heads: Number of attention heads.
            block_size: Maximum sequence length the causal mask supports.
            dropout: Probability for both the attention-weight dropout and
                the residual (output) dropout.

        Raises:
            ValueError: If ``d_model`` is not divisible by ``n_heads``.
        """
        super().__init__()
        # Raise (not assert) so the check survives `python -O`.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.block_size = block_size
        # One fused linear layer for Q, K, V projections (single matmul
        # instead of three separate ones).
        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_out = nn.Linear(d_model, d_model, bias=False)
        # Dropouts
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        # Register the causal mask as a buffer so it follows .to(device)
        # and lands in the state_dict alongside the parameters.
        mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", mask.view(1, 1, block_size, block_size))

    def forward(self, x, return_attn=False):
        """Apply causal multi-head self-attention.

        Args:
            x: Input tensor of shape (B, T, d_model) with T <= block_size.
            return_attn: If True, also return the (post-dropout) attention
                weights — useful for visualization.

        Returns:
            A tuple ``(y, att)`` where ``y`` has shape (B, T, d_model) and
            ``att`` is the (B, n_heads, T, T) attention tensor when
            ``return_attn`` is True, otherwise ``None``.

        Raises:
            ValueError: If the sequence length exceeds ``block_size``.
        """
        B, T, C = x.shape
        H, D = self.n_heads, self.head_dim
        # Raise (not assert) so the check survives `python -O`.
        if T > self.block_size:
            raise ValueError(
                f"Sequence length {T} exceeds configured block_size {self.block_size}"
            )
        # Compute Q, K, V via one fused projection, then split.
        qkv = self.W_qkv(x)
        q, k, v = qkv.split(C, dim=2)  # Each (B, T, d_model)
        # Reshape for multi-head: (B, T, d_model) -> (B, H, T, head_dim)
        q = q.view(B, T, H, D).transpose(1, 2)
        k = k.view(B, T, H, D).transpose(1, 2)
        v = v.view(B, T, H, D).transpose(1, 2)
        # Scaled dot-product scores: (B, H, T, T); scale by sqrt(head_dim)
        # to keep logits in a range where softmax has useful gradients.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(D)
        # Apply causal mask: future positions get -inf so softmax zeroes them.
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        # Softmax over the key dimension.
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        # Weighted sum of values: (B, H, T, head_dim)
        y = att @ v
        # Merge heads: (B, H, T, head_dim) -> (B, T, d_model)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        # Output projection + residual dropout.
        y = self.W_out(y)
        y = self.resid_dropout(y)
        if return_attn:
            return y, att
        return y, None
if __name__ == "__main__":
    # Smoke test: push a small random batch through the layer and report
    # the output and attention shapes.
    torch.manual_seed(42)
    batch, seq_len, dim = 2, 8, 64
    layer = MultiHeadSelfAttention(d_model=dim, n_heads=4, block_size=16, dropout=0.1)
    inputs = torch.randn(batch, seq_len, dim)
    out, att = layer(inputs, return_attn=True)
    print("Output shape:", out.shape)  # (2, 8, 64)
    print("Attention shape:", att.shape)  # (2, 4, 8, 8)