Skip to content

Commit 27327a4

Browse files
authored
[example] add palm pytorch version (#2172)
1 parent 12e7bcd commit 27327a4

File tree

7 files changed

+454
-0
lines changed

7 files changed

+454
-0
lines changed

examples/language/palm/README.md

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
<img src="./palm.gif" width="450px"></img>
2+
3+
## PaLM - Pytorch
4+
5+
Implementation of the specific Transformer architecture from <a href="https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html">PaLM - Scaling Language Modeling with Pathways</a>, in less than 200 lines of code.
6+
7+
This model is pretty much SOTA on everything language.
8+
9+
It obviously will not scale, but it is just for educational purposes. To elucidate the public how simple it all really is.
10+
11+
## Install
12+
```bash
13+
$ pip install PaLM-pytorch
14+
```
15+
16+
## Usage
17+
18+
```python
19+
import torch
20+
from palm_pytorch import PaLM
21+
22+
palm = PaLM(
23+
num_tokens = 20000,
24+
dim = 512,
25+
depth = 12,
26+
heads = 8,
27+
dim_head = 64,
28+
)
29+
30+
tokens = torch.randint(0, 20000, (1, 2048))
31+
logits = palm(tokens) # (1, 2048, 20000)
32+
```
33+
34+
The PaLM 540B in the paper would be
35+
36+
```python
37+
palm = PaLM(
38+
num_tokens = 256000,
39+
dim = 18432,
40+
depth = 118,
41+
heads = 48,
42+
dim_head = 256
43+
)
44+
```
45+
46+
## Test on Enwik8
47+
48+
```bash
49+
$ python train.py
50+
```
51+
52+
## Todo
53+
54+
- [ ] offer a Triton optimized version of PaLM, bringing in https://github.com/lucidrains/triton-transformer
55+
56+
## Citations
57+
58+
```bibtex
59+
@article{chowdhery2022PaLM,
60+
title = {PaLM: Scaling Language Modeling with Pathways},
61+
author = {Chowdhery, Aakanksha et al},
62+
year = {2022}
63+
}
64+
```
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Data source
2+
3+
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from palm_pytorch.palm_pytorch import PaLM
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from einops import rearrange
4+
from torch import nn
5+
6+
# helper function
7+
8+
9+
def exists(val):
10+
return val is not None
11+
12+
13+
def eval_decorator(fn):
14+
15+
def inner(model, *args, **kwargs):
16+
was_training = model.training
17+
model.eval()
18+
out = fn(model, *args, **kwargs)
19+
model.train(was_training)
20+
return out
21+
22+
return inner
23+
24+
25+
# top k filtering
26+
27+
28+
def top_k(logits, thres=0.9):
29+
k = int((1 - thres) * logits.shape[-1])
30+
val, ind = torch.topk(logits, k)
31+
probs = torch.full_like(logits, float("-inf"))
32+
probs.scatter_(1, ind, val)
33+
return probs
34+
35+
36+
class AutoregressiveWrapper(nn.Module):
37+
38+
def __init__(self, net, max_seq_len=2048, pad_value=0):
39+
super().__init__()
40+
self.max_seq_len = max_seq_len
41+
self.pad_value = pad_value
42+
self.net = net
43+
44+
@torch.no_grad()
45+
@eval_decorator
46+
def generate(self, start_tokens, seq_len, eos_token=None, temperature=1.0, filter_thres=0.9, **kwargs):
47+
b, t, device = *start_tokens.shape, start_tokens.device
48+
49+
out = start_tokens
50+
51+
for _ in range(seq_len):
52+
logits = self.net(out, **kwargs)[:, -1, :]
53+
54+
filtered_logits = top_k(logits, thres=filter_thres)
55+
probs = F.softmax(filtered_logits / temperature, dim=-1)
56+
57+
sample = torch.multinomial(probs, 1)
58+
59+
out = torch.cat((out, sample), dim=-1)
60+
61+
if exists(eos_token):
62+
is_eos_token = out == eos_token
63+
64+
if is_eos_token.any(dim=-1).all():
65+
# mask out everything after the eos tokens
66+
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
67+
mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
68+
out = out.masked_fill(mask, self.pad_value)
69+
break
70+
71+
out = out[:, t:]
72+
return out
73+
74+
def forward(self, x, **kwargs):
75+
x_inp, x_labels = x[:, :-1], x[:, 1:]
76+
logits = self.net(x_inp, **kwargs)
77+
return F.cross_entropy(rearrange(logits, "b c n -> b n c"), x_labels)
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from einops import rearrange
4+
from torch import einsum, nn
5+
6+
# normalization
7+
# they use layernorm without bias, something that pytorch does not offer
8+
9+
10+
class LayerNorm(nn.Module):
11+
12+
def __init__(self, dim, eps=1e-5):
13+
super().__init__()
14+
self.eps = eps
15+
self.gamma = nn.Parameter(torch.ones(dim))
16+
self.register_buffer("beta", torch.zeros(dim))
17+
18+
def forward(self, x):
19+
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
20+
21+
22+
# parallel with residual
23+
# discovered by Wang et al + EleutherAI from GPT-J fame
24+
25+
26+
class ParallelResidual(nn.Module):
27+
28+
def __init__(self, *fns):
29+
super().__init__()
30+
self.fns = nn.ModuleList(fns)
31+
32+
def forward(self, x):
33+
return x + sum([fn(x) for fn in self.fns])
34+
35+
36+
# rotary positional embedding
37+
# https://arxiv.org/abs/2104.09864
38+
39+
40+
class RotaryEmbedding(nn.Module):
41+
42+
def __init__(self, dim):
43+
super().__init__()
44+
inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
45+
self.register_buffer("inv_freq", inv_freq)
46+
47+
def forward(self, max_seq_len, *, device):
48+
seq = torch.arange(max_seq_len, device=device)
49+
freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
50+
return torch.cat((freqs, freqs), dim=-1)
51+
52+
53+
def rotate_half(x):
54+
x = rearrange(x, "... (j d) -> ... j d", j=2)
55+
x1, x2 = x.unbind(dim=-2)
56+
return torch.cat((-x2, x1), dim=-1)
57+
58+
59+
def apply_rotary_pos_emb(pos, t):
60+
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
61+
62+
63+
# feedforward
64+
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU
65+
# https://arxiv.org/abs/2002.05202
66+
67+
68+
class SwiGLU(nn.Module):
69+
70+
def forward(self, x):
71+
x, gate = x.chunk(2, dim=-1)
72+
return F.silu(gate) * x
73+
74+
75+
def FeedForward(dim, mult=4):
76+
inner_dim = int(dim * mult)
77+
return nn.Sequential(
78+
LayerNorm(dim),
79+
nn.Linear(dim, inner_dim * 2, bias=False),
80+
SwiGLU(),
81+
nn.Linear(inner_dim, dim, bias=False),
82+
)
83+
84+
85+
# attention
86+
87+
88+
class Attention(nn.Module):
89+
90+
def __init__(self, dim, dim_head=64, heads=8):
91+
super().__init__()
92+
inner_dim = dim_head * heads
93+
self.norm = LayerNorm(dim)
94+
self.heads = heads
95+
self.scale = dim_head**-0.5
96+
self.rotary_emb = RotaryEmbedding(dim_head)
97+
98+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
99+
self.to_kv = nn.Linear(dim, dim_head * 2, bias=False)
100+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
101+
102+
# for caching causal mask and rotary embeddings
103+
104+
self.register_buffer("mask", None, persistent=False)
105+
self.register_buffer("pos_emb", None, persistent=False)
106+
107+
def get_mask(self, n, device):
108+
if self.mask is not None and self.mask.shape[-1] >= n:
109+
return self.mask[:n, :n]
110+
111+
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
112+
self.register_buffer("mask", mask, persistent=False)
113+
return mask
114+
115+
def get_rotary_embedding(self, n, device):
116+
if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
117+
return self.pos_emb[:n]
118+
119+
pos_emb = self.rotary_emb(n, device=device)
120+
self.register_buffer("position", pos_emb, persistent=False)
121+
return pos_emb
122+
123+
def forward(self, x):
124+
"""
125+
einstein notation
126+
b - batch
127+
h - heads
128+
n, i, j - sequence length (base sequence length, source, target)
129+
d - feature dimension
130+
"""
131+
132+
n, device, h = x.shape[1], x.device, self.heads
133+
134+
# pre layernorm
135+
136+
x = self.norm(x)
137+
138+
# queries, keys, values
139+
140+
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
141+
142+
# split heads
143+
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
144+
# they found no performance loss past a certain scale, and more efficient decoding obviously
145+
# https://arxiv.org/abs/1911.02150
146+
147+
q = rearrange(q, "b n (h d) -> b h n d", h=h)
148+
149+
# rotary embeddings
150+
151+
positions = self.get_rotary_embedding(n, device)
152+
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
153+
154+
# scale
155+
156+
q = q * self.scale
157+
158+
# similarity
159+
160+
sim = einsum("b h i d, b j d -> b h i j", q, k)
161+
162+
# causal mask
163+
164+
causal_mask = self.get_mask(n, device)
165+
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
166+
167+
# attention
168+
169+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
170+
attn = sim.softmax(dim=-1)
171+
172+
# aggregate values
173+
174+
out = einsum("b h i j, b j d -> b h i d", attn, v)
175+
176+
# merge heads
177+
178+
out = rearrange(out, "b h n d -> b n (h d)")
179+
return self.to_out(out)
180+
181+
182+
# transformer
183+
184+
185+
def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
186+
net = nn.Sequential(
187+
nn.Embedding(num_tokens, dim), *[
188+
ParallelResidual(
189+
Attention(dim=dim, dim_head=dim_head, heads=heads),
190+
FeedForward(dim=dim, mult=ff_mult),
191+
) for _ in range(depth)
192+
], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False))
193+
194+
# they used embedding weight tied projection out to logits, not common, but works
195+
net[-1].weight = net[0].weight
196+
197+
nn.init.normal_(net[0].weight, std=0.02)
198+
return net

0 commit comments

Comments
 (0)