-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpositional_embeddings.py
More file actions
67 lines (54 loc) · 2.33 KB
/
positional_embeddings.py
File metadata and controls
67 lines (54 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import equinox as eqx
import jax
from jax import numpy as jnp
# From [modded-nanogpt](https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py#L306)
class Rotary(eqx.Module):
    """Rotary position embedding backed by precomputed cos/sin tables.

    The last axis of the input is split into two halves ``x1`` and ``x2``;
    element ``j`` of each half forms a 2-D pair rotated by the angle
    ``position * angular_freq[j]``. Only the first ``dim // 4`` frequencies
    are non-zero; the remaining ``dim // 4`` are zero, so those pairs pass
    through unrotated (cos = 1, sin = 0) at every position.

    Attributes:
        cos: Cosine table of shape ``[max_seq_len, dim // 2]`` (float32).
        sin: Sine table of shape ``[max_seq_len, dim // 2]`` (float32).
        dim: Size of the last axis of the inputs this module accepts.
        max_seq_len: Number of positions covered by the tables.
    """

    cos: jax.Array
    sin: jax.Array
    dim: int = eqx.field(static=True)
    max_seq_len: int = eqx.field(static=True)

    def __init__(self, dim: int, max_seq_len: int):
        """Precompute the rotation tables.

        Args:
            dim: Size of the feature axis to rotate. Must be a multiple of 4
                so that the two half-width frequency groups add up to
                exactly ``dim // 2`` columns.
            max_seq_len: Number of positions to precompute.

        Raises:
            ValueError: If ``dim`` is not a multiple of 4.
        """
        # With any other dim the table width 2 * (dim // 4) differs from the
        # dim // 2 columns produced by splitting the input, which either
        # fails at call time or silently broadcasts to a wrong shape.
        if dim % 4 != 0:
            raise ValueError(f"dim must be a multiple of 4, got {dim}.")

        self.dim = dim
        self.max_seq_len = max_seq_len

        quarter = dim // 4
        # Step 1: angular frequencies, geometrically spaced from 1 to 1/1024.
        angular_freq = (1 / 1024) ** jnp.linspace(0, 1, quarter, dtype=jnp.float32)
        # Step 2: pad with zeros so the second group of pairs is never rotated.
        angular_freq = jnp.concatenate(
            [angular_freq, jnp.zeros((quarter,), dtype=jnp.float32)]
        )  # [dim//2]
        # Step 3: angle for every (position, frequency) combination.
        t = jnp.arange(max_seq_len, dtype=jnp.float32)
        theta = jnp.outer(t, angular_freq)  # [max_seq_len, dim//2]
        # Step 4: precompute cos and sin once; __call__ only slices them.
        self.cos = jnp.cos(theta)  # [max_seq_len, dim//2]
        self.sin = jnp.sin(theta)  # [max_seq_len, dim//2]

    def __call__(self, x_BTHD: jax.Array, reverse: bool = False) -> jax.Array:
        """Apply the rotary embedding along the time axis.

        Args:
            x_BTHD: Tensor of shape ``[Batch, Time, Heads, Dim]`` whose last
                axis has size ``self.dim``.
            reverse: If ``True``, rotate by the negated angle, which undoes
                the rotation applied with ``reverse=False``.

        Returns:
            The rotated tensor, with the same shape and dtype as ``x_BTHD``.
            The arithmetic itself is carried out in float32.

        Raises:
            ValueError: If the input is not 4-D, is longer than the
                precomputed tables, or has a last axis different from
                ``self.dim``.
        """
        # Explicit raises instead of `assert` so the checks survive `python -O`.
        if x_BTHD.ndim != 4:
            raise ValueError(
                f"Expected a 4-D [Batch, Time, Heads, Dim] input, got shape {x_BTHD.shape}."
            )
        current_seq_len = x_BTHD.shape[1]
        if current_seq_len > self.cos.shape[0]:
            raise ValueError(
                f"Need {current_seq_len} positions, but only have {self.cos.shape[0]} precomputed."
            )
        # A mismatched last axis can broadcast against the tables (e.g. a
        # half-width of 1) and return a wrongly shaped result without error.
        if x_BTHD.shape[-1] != self.dim:
            raise ValueError(
                f"Expected last axis of size {self.dim}, got {x_BTHD.shape[-1]}."
            )

        # The tables are constants, not parameters: block gradients so that
        # differentiating w.r.t. the module does not produce updates for them.
        # Slice to the current length and add broadcast axes: [1, T, 1, dim//2].
        cos = jax.lax.stop_gradient(self.cos[:current_seq_len])[None, :, None, :]
        sin = jax.lax.stop_gradient(self.sin[:current_seq_len])[None, :, None, :]
        if reverse:
            # Rotating by -theta only flips the sign of the sine term.
            sin = -sin

        # Split the feature axis into the two halves that form each pair.
        x_float = x_BTHD.astype(jnp.float32)
        x1, x2 = jnp.split(x_float, 2, axis=-1)  # each: [B, T, H, dim//2]

        y1 = x1 * cos + x2 * sin
        y2 = x2 * cos - x1 * sin

        # Recombine and restore the caller's dtype.
        out = jnp.concatenate([y1, y2], axis=-1)  # [B, T, H, dim]
        return out.astype(x_BTHD.dtype)