Skip to content

Commit 82d1a79

Browse files
authored
Grug moe replicated weights (#3064)
1 parent 3e37677 commit 82d1a79

File tree

1 file changed

+388
-0
lines changed

1 file changed

+388
-0
lines changed
Lines changed: 388 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,388 @@
1+
# Copyright 2025 The Levanter Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import dataclasses
5+
from dataclasses import dataclass
6+
from functools import partial
7+
8+
import equinox as eqx
9+
import jax
10+
import jax.numpy as jnp
11+
import jax.scipy as jsp
12+
import levanter.tracker
13+
from einops import rearrange
14+
from haliax.jax_utils import named_call
15+
from haliax.partitioning import _get_mesh
16+
from jax import random
17+
from jax.experimental.shard_map import shard_map
18+
from jax.sharding import PartitionSpec as P
19+
from jaxtyping import Array, Float, Int, PRNGKeyArray
20+
21+
from .attention import AttentionMask, RotaryConfig, apply_rotary_embedding, attention
22+
from .loss import fused_linear_softmax_cross_entropy_loss
23+
from .sharding import Pbatch, unshard
24+
25+
26+
#### Conventions
27+
28+
# Mesh meanings:
29+
# - "data": data parallel sharding axis.
30+
# All model weights (including expert weights) are fully replicated across chips.
31+
32+
# Dim names:
33+
# - B = batch
34+
# - D = embedding / hidden dim
35+
# - S = sequence length
36+
# - N = num heads
37+
# - M = num kv heads
38+
# - H = head dim
39+
# - I = intermediate dim
40+
# - T = tokens (B * S, flattened batch)
41+
# - K = num_experts_per_tok
42+
# - TR = T * K (tokens repeated per expert, sorted by expert)
43+
# - E = n_routed_experts
44+
45+
46+
@dataclass(frozen=True)
class GrugModelConfig:
    """Hyperparameters for the Grug Mixtral MoE style transformer."""

    vocab_size: int
    hidden_dim: int = 1536
    intermediate_dim: int = 4608
    num_layers: int = 12
    num_heads: int = 12
    num_kv_heads: int = 12
    head_dim: int | None = None
    max_seq_len: int = 2048
    layer_norm_eps: float = 1e-5
    initializer_std: float = 0.02

    # MoE routing: each token is dispatched to `num_experts_per_tok` (K) of the
    # `n_routed_experts` (E) experts.
    num_experts_per_tok: int = 2
    n_routed_experts: int = 8

    # Aux-loss coefficients; None disables the corresponding loss term.
    lbl_coef: float | None = 0.01
    rzl_coef: float | None = 0.001

    rope: RotaryConfig = dataclasses.field(default_factory=RotaryConfig)

    def __post_init__(self) -> None:
        """Validate the configuration eagerly so bad configs fail at construction."""
        _ = self.inferred_head_dim  # force the hidden_dim/num_heads divisibility check early
        if self.num_heads % self.num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads for grouped-query attention")
        if self.vocab_size <= 0:
            raise ValueError("vocab_size must be positive")
        if self.max_seq_len <= 0:
            raise ValueError("max_seq_len must be positive")
        # Validate MoE routing fields here rather than failing obscurely inside
        # jax.lax.top_k / jnp.bincount during the forward pass.
        if self.n_routed_experts <= 0:
            raise ValueError("n_routed_experts must be positive")
        if not (0 < self.num_experts_per_tok <= self.n_routed_experts):
            raise ValueError(
                f"num_experts_per_tok={self.num_experts_per_tok} must be in [1, n_routed_experts={self.n_routed_experts}]"
            )

    @property
    def inferred_head_dim(self) -> int:
        """Per-head dim: explicit `head_dim` if set, else hidden_dim // num_heads.

        Raises:
            ValueError: if head_dim is unset and hidden_dim is not divisible by num_heads.
        """
        if self.head_dim is not None:
            return self.head_dim
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim={self.hidden_dim} is not divisible by num_heads={self.num_heads}; set head_dim explicitly"
            )
        return self.hidden_dim // self.num_heads
87+
88+
89+
class CausalSelfAttention(eqx.Module):
    """Multi-head causal self-attention with grouped-query KV heads and RoPE."""

    w_q: jax.Array
    w_k: jax.Array
    w_v: jax.Array
    w_o: jax.Array
    cfg: GrugModelConfig = eqx.field(static=True)

    @staticmethod
    def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "CausalSelfAttention":
        """Initialize the four projection matrices from truncated normals."""
        key_q, key_k, key_v, key_o = random.split(key, 4)
        d_model = cfg.hidden_dim
        n_heads = cfg.num_heads
        n_kv_heads = cfg.num_kv_heads
        d_head = cfg.inferred_head_dim
        return CausalSelfAttention(
            w_q=_init_weight(key_q, (d_model, n_heads * d_head), cfg.initializer_std),
            w_k=_init_weight(key_k, (d_model, n_kv_heads * d_head), cfg.initializer_std),
            w_v=_init_weight(key_v, (d_model, n_kv_heads * d_head), cfg.initializer_std),
            w_o=_init_weight(key_o, (n_heads * d_head, d_model), cfg.initializer_std),
            cfg=cfg,
        )

    @named_call
    def __call__(self, x: Float[Array, "B S D"], mask: AttentionMask | jax.Array) -> Float[Array, "B S D"]:
        """Apply causal self-attention to a batch of sequences."""
        d_head = self.cfg.inferred_head_dim
        seq_len = x.shape[1]

        # Project to flat QKV, then split the projection dim into heads.
        proj_q = jnp.einsum("bsh,hd->bsd", x, self.w_q)
        proj_k = jnp.einsum("bsh,hd->bsd", x, self.w_k)
        proj_v = jnp.einsum("bsh,hd->bsd", x, self.w_v)
        q = rearrange(proj_q, "... (n d) -> ... n d", d=d_head)
        k = rearrange(proj_k, "... (m d) -> ... m d", d=d_head)
        v = rearrange(proj_v, "... (m d) -> ... m d", d=d_head)

        q, k = apply_rotary_embedding(q, k, seq_len=seq_len, head_dim=d_head, rope=self.cfg.rope)
        ctx = attention(q, k, v, mask)
        # Merge heads back into the model dim and apply the output projection.
        merged = rearrange(ctx, "... n d -> ... (n d)")
        return jnp.einsum("bsh,hd->bsd", merged, self.w_o, out_sharding=Pbatch)
120+
121+
122+
class MOE(eqx.Module):
    """Mixture-of-experts feed-forward block with Mixtral-style top-k routing.

    Expert weights are fully replicated across chips (see module header);
    activations are sharded along the "data" mesh axis inside the shard_map'd
    expert computation below.
    """

    # Router projection [D, E]; expert weights are stacked along a leading E axis.
    router_w: jax.Array
    w1: jax.Array  # [E, D, I] gate projection
    w2: jax.Array  # [E, D, I] up projection
    w3: jax.Array  # [E, I, D] down projection
    cfg: GrugModelConfig = eqx.field(static=True)

    # Contract lhs dim 1 with rhs dim 1; lhs dim 0 is ragged, grouped by the
    # rhs leading (expert) dim — i.e. row block e of lhs hits expert e's matrix.
    _ragged_dim_numbers = jax.lax.RaggedDotDimensionNumbers(
        dot_dimension_numbers=(((1,), (1,)), ((), ())),
        lhs_ragged_dimensions=(0,),
        rhs_group_dimensions=(0,),
    )

    @staticmethod
    def _ragged_linear(x: jax.Array, w: jax.Array, group_sizes: jax.Array) -> jax.Array:
        """Ragged MoE linear: (TR, In) x (E, In, Out) with groups along TR."""
        return jax.lax.ragged_dot_general(
            lhs=x,
            rhs=w,
            group_sizes=group_sizes,
            ragged_dot_dimension_numbers=MOE._ragged_dim_numbers,
        )

    @staticmethod
    def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "MOE":
        """Initialize router and stacked expert weights from truncated normals."""
        k_router_w, k_w1, k_w2, k_w3 = random.split(key, 4)
        E, D, I = cfg.n_routed_experts, cfg.hidden_dim, cfg.intermediate_dim
        router_w = _init_weight(k_router_w, (D, E), cfg.initializer_std)
        w1 = _init_weight(k_w1, (E, D, I), cfg.initializer_std)
        w2 = _init_weight(k_w2, (E, D, I), cfg.initializer_std)
        w3 = _init_weight(k_w3, (E, I, D), cfg.initializer_std)
        return MOE(router_w, w1, w2, w3, cfg)

    @named_call
    def __call__(self, x: Float[Array, "B S D"]) -> tuple[Float[Array, "B S D"], dict]:
        """Route each token to its top-k experts; return (output, aux-loss extras)."""
        B, S, D = x.shape
        x_flat = jnp.reshape(x, (B * S, D))
        router_logits = jnp.einsum("td,de->te", x_flat, self.router_w)
        topk_weights, topk_idx, router_probs = self._route(router_logits)
        topk_idx_flat = jnp.reshape(topk_idx, (B * S * self.cfg.num_experts_per_tok,))
        mesh = _get_mesh()

        # Token activations stay batch-sharded (Pbatch); the replicated expert
        # weights enter with full specs (P()).
        @partial(
            shard_map,
            mesh=mesh,
            in_specs=(Pbatch, Pbatch, Pbatch, P(), P(), P()),
            out_specs=(Pbatch, P()),
        )
        def _moe_block(x_flat, topk_idx_flat, topk_weights, w1, w2, w3):
            x_repeat_sort, group_sizes, sort_idx = self._permute(x_flat, topk_idx_flat)
            w1_out = MOE._ragged_linear(x_repeat_sort, w1, group_sizes)  # [TR, I]
            w2_out = MOE._ragged_linear(x_repeat_sort, w2, group_sizes)  # [TR, I]
            gated = jax.nn.silu(w1_out) * w2_out  # [TR, I] SwiGLU-style gating
            out_repeat_sort = MOE._ragged_linear(gated, w3, group_sizes)  # [TR, D]
            out_repeat_unflat = self._unpermute(out_repeat_sort, sort_idx)
            # Combine each token's K expert outputs with its routing weights.
            out_flat = jnp.sum(out_repeat_unflat * topk_weights[..., None], axis=1)  # [T, D]

            # compute statistics and aux loss over global batch
            global_group_sizes = jax.lax.psum(group_sizes, "data")
            return out_flat, global_group_sizes

        out_flat, group_sizes = _moe_block(x_flat, topk_idx_flat, topk_weights, self.w1, self.w2, self.w3)
        out = jnp.reshape(out_flat, (B, S, D))

        extras = {}
        if self.cfg.lbl_coef is not None:
            # Switch-Transformer-style load-balancing loss: (E/K) * sum_e load_e * mean_prob_e.
            group_sizes_f = group_sizes.astype(jnp.float32)
            expert_loads = group_sizes_f / jnp.sum(group_sizes_f)
            extras["expert_loads"] = expert_loads
            f = expert_loads * (self.cfg.n_routed_experts / self.cfg.num_experts_per_tok)
            p = jnp.mean(router_probs.astype(jnp.float32), axis=0)  # [T, E] -> [E]
            extras["load_balancing_loss"] = jnp.asarray(self.cfg.lbl_coef, dtype=jnp.float32) * jnp.sum(f * p)

        if self.cfg.rzl_coef is not None:
            # Router z-loss penalizes large logsumexp of router logits.
            z = jsp.special.logsumexp(router_logits.astype(jnp.float32), axis=-1)
            extras["router_z_loss"] = jnp.asarray(self.cfg.rzl_coef, dtype=jnp.float32) * jnp.mean(z**2)

        return out, extras

    def _route(
        self, router_logits: Float[Array, "T E"]
    ) -> tuple[Float[Array, "T K"], Int[Array, "T K"], Float[Array, "T E"]]:
        """Select top-k experts per token and compute normalized routing weights."""
        router_probs = jax.nn.softmax(router_logits, axis=-1)
        _scores, topk_idx = jax.lax.top_k(router_logits, self.cfg.num_experts_per_tok)
        topk_weights = jnp.take_along_axis(router_probs, topk_idx, axis=-1)
        # Renormalize so each token's selected-expert weights sum to 1.
        topk_weights = topk_weights / jnp.sum(topk_weights, axis=-1, keepdims=True)
        return topk_weights, topk_idx.astype(jnp.int32), router_probs

    def _permute(
        self, x_flat: jax.Array, topk_idx_flat: jax.Array
    ) -> tuple[Float[Array, "TR D"], Int[Array, "E"], Int[Array, "TR"]]:
        """Sort tokens by assigned expert and compute per-expert group sizes for ragged_dot."""
        sort_idx = jnp.argsort(topk_idx_flat, axis=-1)
        # sort_idx indexes the (T*K)-long repeated stream; floor-dividing by K
        # recovers the source token id for each routed slot.
        x_repeat_sort = jnp.take(x_flat, sort_idx // self.cfg.num_experts_per_tok, axis=0)
        group_sizes = jnp.bincount(topk_idx_flat, length=self.cfg.n_routed_experts).astype(jnp.int32)
        return x_repeat_sort, group_sizes, sort_idx.astype(jnp.int32)

    def _unpermute(self, out_repeat_sort: jax.Array, sort_idx: jax.Array) -> Float[Array, "T K D"]:
        """Reverse the expert-sorted order back to the original token layout."""
        inv_sort_idx = jnp.argsort(sort_idx, axis=-1)
        out_repeat = jnp.take(out_repeat_sort, inv_sort_idx, axis=0)
        return jnp.reshape(out_repeat, (-1, self.cfg.num_experts_per_tok, self.cfg.hidden_dim))
225+
226+
227+
class RMSNorm(eqx.Module):
    """Root-mean-square layer norm with a learned scale, computed in float32."""

    weight: jax.Array
    eps: float = eqx.field(static=True)

    @staticmethod
    def init(dim: int, eps: float) -> "RMSNorm":
        """Create an RMSNorm with a unit scale vector of size `dim`."""
        return RMSNorm(weight=jnp.ones((dim,), dtype=jnp.float32), eps=eps)

    @named_call
    def __call__(self, x: Float[Array, "... D"]) -> Float[Array, "... D"]:
        """Normalize the last axis by its RMS, scale, and restore input dtype."""
        scale = unshard(self.weight)
        orig_dtype = x.dtype
        # Do the reduction in float32 for numerical stability.
        x32 = x.astype(jnp.float32)
        mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
        y = x32 * jax.lax.rsqrt(mean_sq + self.eps) * scale
        return y.astype(orig_dtype)
243+
244+
245+
class Block(eqx.Module):
    """Pre-norm transformer block: attention then MoE FFN, each with a residual."""

    rms_attn: RMSNorm
    attn: CausalSelfAttention
    rms_mlp: RMSNorm
    moe: MOE

    @staticmethod
    def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "Block":
        """Build the block's two norms, the attention layer, and the MoE layer."""
        key_attn, key_moe = random.split(key, 2)
        norm_attn = RMSNorm.init(cfg.hidden_dim, cfg.layer_norm_eps)
        norm_mlp = RMSNorm.init(cfg.hidden_dim, cfg.layer_norm_eps)
        return Block(
            rms_attn=norm_attn,
            attn=CausalSelfAttention.init(cfg, key=key_attn),
            rms_mlp=norm_mlp,
            moe=MOE.init(cfg, key=key_moe),
        )

    @named_call
    def __call__(
        self, x: Float[Array, "B S D"], mask: AttentionMask | jax.Array
    ) -> tuple[Float[Array, "B S D"], dict]:
        """Apply attention and MoE sublayers; return (hidden, MoE extras)."""
        attn_out = self.attn(self.rms_attn(x), mask)
        hidden = x + attn_out
        moe_out, extras = self.moe(self.rms_mlp(hidden))
        return hidden + moe_out, extras
269+
270+
271+
class Transformer(eqx.Module):
    """Grug MoE transformer: token embedding, MoE blocks, final RMSNorm, LM head."""

    token_embed: jax.Array  # [V, D]
    output_proj: jax.Array  # [D, V]
    blocks: tuple[Block, ...]
    final_norm: RMSNorm
    config: GrugModelConfig = eqx.field(static=True)

    @staticmethod
    def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "Transformer":
        """Initialize embeddings, output projection, and `cfg.num_layers` blocks."""
        embed_key, out_key, *block_keys = random.split(key, cfg.num_layers + 2)
        token_embed = _init_weight(embed_key, (cfg.vocab_size, cfg.hidden_dim), cfg.initializer_std)
        output_proj = _init_weight(out_key, (cfg.hidden_dim, cfg.vocab_size), cfg.initializer_std)
        blocks = tuple(Block.init(cfg, key=layer_key) for layer_key in block_keys)
        final_norm = RMSNorm.init(cfg.hidden_dim, cfg.layer_norm_eps)
        return Transformer(
            token_embed=token_embed,
            output_proj=output_proj,
            blocks=blocks,
            final_norm=final_norm,
            config=cfg,
        )

    @named_call
    def __call__(
        self,
        token_ids: Int[Array, "B S"],
        mask: AttentionMask | jax.Array | None = None,
    ) -> tuple[Float[Array, "B S D"], Float[Array, ""]]:
        """Run the trunk over token ids.

        Args:
            token_ids: integer token ids, shape [B, S].
            mask: attention mask; defaults to a causal mask.

        Returns:
            A pair (final normalized hidden states [B, S, D], scalar MoE aux loss).
        """
        # Fix: this method has always returned a (hidden, aux_loss) pair — see
        # `logits`/`next_token_loss` — but was annotated as a bare array.
        if mask is None:
            mask = AttentionMask.causal()

        hidden = self.token_embed.at[token_ids].get(out_sharding=Pbatch)
        all_extras = []
        for block in self.blocks:
            # Checkpoint each block to trade recompute for activation memory.
            hidden, extras = eqx.filter_checkpoint(block)(hidden, mask)
            all_extras.append(extras)
        aux_loss = self.parse_aux_loss(all_extras)
        return self.final_norm(hidden), aux_loss

    @named_call
    def logits(
        self,
        token_ids: Int[Array, "B S"],
        mask: AttentionMask | jax.Array | None = None,
    ) -> Float[Array, "B S V"]:
        """Full-vocabulary logits for each position; the aux loss is discarded."""
        hidden, _ = self(token_ids, mask=mask)
        return jnp.einsum("bsh,hd->bsd", hidden, self.output_proj, out_sharding=Pbatch)

    def next_token_loss(
        self,
        token_ids: Int[Array, "B S"],
        loss_weight: Float[Array, "B S"],
        *,
        mask: AttentionMask | jax.Array | None = None,
        reduction: str = "mean",
        logsumexp_weight: float | None = None,
        loss_dtype: jnp.dtype = jnp.float32,
    ) -> jax.Array:
        """Compute next-token cross-entropy loss (plus MoE aux losses) for a batch.

        Args:
            token_ids: input token ids; position t is trained to predict token t+1.
            loss_weight: per-position loss weights (e.g. to mask padding).
            mask: attention mask; defaults to causal.
            reduction: reduction mode forwarded to the fused loss.
            logsumexp_weight: optional z-loss weight forwarded to the fused loss.
            loss_dtype: dtype the loss is computed in.
        """
        hidden, aux_loss = self(token_ids, mask=mask)
        # Shift labels left by one; the final position gets a dummy label of 0.
        # NOTE(review): callers are presumably expected to zero loss_weight at
        # that final position — confirm at call sites.
        labels = jnp.concatenate([token_ids[:, 1:], token_ids[:, :1] * 0], axis=1).astype(jnp.int32)
        loss_weight = loss_weight.astype(loss_dtype)

        return (
            fused_linear_softmax_cross_entropy_loss(
                hidden,
                self.output_proj,
                labels,
                weight=loss_weight,
                reduction=reduction,
                logsumexp_weight=logsumexp_weight,
                dtype=loss_dtype,
            )
            + aux_loss
        )

    def parse_aux_loss(self, all_extras: list[dict]) -> Float[Array, ""]:
        """Sum per-layer MoE aux losses and log per-layer routing statistics.

        Args:
            all_extras: one extras dict per block, as returned by `MOE.__call__`.

        Returns:
            Scalar load-balancing + router-z loss (0 if both are disabled).
        """
        load_balancing_loss = 0
        router_z_loss = 0
        stats = {}
        for i, extras in enumerate(all_extras):
            if "load_balancing_loss" in extras:
                stats[f"train/layer_{i}/load_balancing_loss"] = jax.lax.stop_gradient(extras["load_balancing_loss"])
                load_balancing_loss += extras["load_balancing_loss"]
            if "router_z_loss" in extras:
                stats[f"train/layer_{i}/router_z_loss"] = jax.lax.stop_gradient(extras["router_z_loss"])
                router_z_loss += extras["router_z_loss"]
            if "expert_loads" in extras:
                expert_loads = extras["expert_loads"]  # [E], sums to 1
                n_experts = self.config.n_routed_experts

                # Entropy of the load distribution (log E when perfectly uniform);
                # the epsilon guards log(0) for unused experts.
                entropy = -jnp.sum(expert_loads * jnp.log(expert_loads + 1e-6))
                # Max load relative to the uniform share (1.0 == balanced).
                load_violation_max = jnp.max(expert_loads) * n_experts

                stats[f"train/layer_{i}/routing_entropy"] = jax.lax.stop_gradient(entropy)
                stats[f"train/layer_{i}/load_violation_max"] = jax.lax.stop_gradient(load_violation_max)
                for j in range(n_experts):
                    stats[f"train/layer_{i}/expert_{j}/load"] = jax.lax.stop_gradient(expert_loads[j])

        stats["train/load_balancing_loss"] = jax.lax.stop_gradient(load_balancing_loss)
        stats["train/router_z_loss"] = jax.lax.stop_gradient(router_z_loss)
        levanter.tracker.jit_log(stats)
        aux_loss = load_balancing_loss + router_z_loss
        return aux_loss
375+
376+
377+
def _init_weight(key: PRNGKeyArray, shape: tuple[int, ...], std: float) -> Float[Array, "..."]:
    """Sample a weight tensor from a ±3 truncated standard normal, scaled by `std`."""
    sample = random.truncated_normal(key, -3, 3, shape)
    return sample * std
379+
380+
381+
# Public API of this module (modules, config, and the top-level Transformer).
__all__ = [
    "CausalSelfAttention",
    "MOE",
    "RMSNorm",
    "Block",
    "Transformer",
    "GrugModelConfig",
]

0 commit comments

Comments
 (0)