-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtransformer_block.py
More file actions
112 lines (94 loc) · 3.36 KB
/
transformer_block.py
File metadata and controls
112 lines (94 loc) · 3.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import math
from dataclasses import dataclass
from typing import Optional, Callable, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from multihead_self_attention import MultiHeadSelfAttention
class FeedForward(nn.Module):
    """Position-wise two-layer MLP used inside a transformer block.

    Computation: Linear(d_model -> hidden_dim) -> activation -> dropout
    -> Linear(hidden_dim -> d_model) -> dropout. The last dimension of
    the input must be ``d_model``; output shape matches input shape.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, d_model: int, hidden_dim: int, activation: str = 'gelu', dropout: float = 0.1):
        super().__init__()
        # Up-projection, down-projection, and a shared dropout module.
        self.fc1 = nn.Linear(d_model, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        # Resolve the activation by name; unknown names fail fast at
        # construction time rather than on the first forward pass.
        supported = {
            'gelu': F.gelu,
            'relu': F.relu,
            'silu': F.silu,
            # Custom activation functions can be added here
        }
        if activation in supported:
            self.activation_fn = supported[activation]
        else:
            raise ValueError(f"unsupported activation: {activation}, add custom act function in transformer_block.py")
        # GPT-style weight init: N(0, 0.02) on both projections
        # (biases keep nn.Linear's default initialization).
        for proj in (self.fc1, self.fc2):
            nn.init.normal_(proj.weight, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP independently at every position; shape is preserved."""
        hidden = self.dropout(self.activation_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
@dataclass
class TransformerBlockConfig:
    """Hyperparameters for one pre-LN transformer block."""

    d_model: int  # width of the embedding / residual stream
    n_heads: int  # number of attention heads; presumably must divide d_model — enforced by MultiHeadSelfAttention, TODO confirm
    block_size: int  # passed to MultiHeadSelfAttention; presumably the max sequence length for its mask — TODO confirm
    ffn_hidden_mult: float = 4.0  # hidden dim = ffn_hidden_mult * d_model
    ffn_activation: str = 'gelu'  # one of the names accepted by FeedForward ('gelu', 'relu', 'silu')
    attn_dropout: float = 0.1  # dropout rate handed to the attention module
    resid_dropout: float = 0.1  # dropout rate used inside the FFN sublayer
    layer_norm_eps: float = 1e-5  # eps for both LayerNorms in the block
class TransformerBlock(nn.Module):
    """Pre-LN transformer block: x + Attn(LN1(x)), then x + FFN(LN2(x)).

    Args:
        cfg: hyperparameters; see TransformerBlockConfig.
    """

    def __init__(self, cfg: TransformerBlockConfig):
        super().__init__()
        self.cfg = cfg
        # One LayerNorm in front of each sublayer (pre-LN ordering).
        self.ln1 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.ln2 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        # Self-attention sublayer.
        self.attn = MultiHeadSelfAttention(
            d_model=cfg.d_model,
            n_heads=cfg.n_heads,
            block_size=cfg.block_size,
            dropout=cfg.attn_dropout
        )
        # Feed-forward sublayer; hidden width scales with d_model.
        self.ffn = FeedForward(
            cfg.d_model,
            int(cfg.ffn_hidden_mult * cfg.d_model),
            cfg.ffn_activation,
            cfg.resid_dropout
        )

    def forward(self, x: torch.Tensor, return_attn: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one block.

        Returns (output, attention_map) when return_attn is True,
        otherwise (output, None). Output has the same shape as x.
        """
        attn_out, attn_map = self.attn(self.ln1(x), return_attn=return_attn)
        x = x + attn_out  # first residual connection
        x = x + self.ffn(self.ln2(x))  # second residual connection
        return (x, attn_map) if return_attn else (x, None)
# Quick unit test: build one block, run a forward pass, and confirm
# gradients reach the input.
if __name__ == "__main__":
    torch.manual_seed(0)  # deterministic weights and input
    batch, seq_len, d_model = 2, 16, 64
    cfg = TransformerBlockConfig(
        d_model=d_model,
        n_heads=8,
        block_size=32,
        ffn_hidden_mult=4.0,
        ffn_activation="gelu",
        attn_dropout=0.0,   # dropout off so the smoke test is reproducible
        resid_dropout=0.0
    )
    block = TransformerBlock(cfg)
    x = torch.randn(batch, seq_len, d_model, requires_grad=True)
    out, att = block(x, return_attn=True)
    print("out.shape:", out.shape) # (2, 16, 64)
    if att is not None:
        print("att.shape:", att.shape) # (2, 8, 16, 16)
    # gradient check: backprop a scalar and verify the input received a grad
    out.sum().backward()
    print("grads on input present:", x.grad is not None)