diff --git a/.gitignore b/.gitignore index b8620580ae..75c7aef70e 100644 --- a/.gitignore +++ b/.gitignore @@ -229,5 +229,6 @@ scr/* /scratch .claude +.agents/tmp/ .codex .entire diff --git a/docs/recipes/change_grug.md b/docs/recipes/change_grug.md index 0d830e1283..0ed6cf84a2 100644 --- a/docs/recipes/change_grug.md +++ b/docs/recipes/change_grug.md @@ -73,7 +73,7 @@ Run the relevant checks: ```bash uv run python infra/pre-commit.py --all-files -uv run pytest tests/test_grug_base_template.py +uv run pytest tests/test_grug_variant_contracts.py ``` Add any additional focused tests needed for behavior changes. diff --git a/docs/recipes/github-pr-review.md b/docs/recipes/github-pr-review.md index 46f39f7f54..601461e248 100644 --- a/docs/recipes/github-pr-review.md +++ b/docs/recipes/github-pr-review.md @@ -23,6 +23,11 @@ Scope: - If a specification exists (issue description, design doc, acceptance criteria, or inline requirements), verify the code adheres to it and flag concrete mismatches. - Ignore formatting, import order, lint/style preferences, naming opinions, missing docstrings/comments, and generic best-practice advice. +Grug Variants: +- In `experiments/grug`, duplication is often intentional. This area is designed for high-velocity, short-lived research iteration. +- Do not flag copy/paste or DRY concerns by default in `experiments/grug` if behavior/contracts are correct. +- Only call out duplication there when it causes a concrete correctness issue, regression risk, or clear divergence from stated objectives. + Output contract: - Return exactly one final review. - Keep output compact and high-signal. diff --git a/experiments/grug/README.md b/experiments/grug/README.md index 71d214cc7c..ae1de1e00b 100644 --- a/experiments/grug/README.md +++ b/experiments/grug/README.md @@ -11,10 +11,10 @@ ## Entry-point guide - Start in `base/launch.py` for normal run edits. 
-- `GrugBaseLaunchConfig` is the user-facing knob surface (model/data/optimizer/trainer/eval/run metadata). +- Each variant `<variant>/launch.py` exposes its own `*LaunchConfig` as the user-facing knob surface. - `versioned(...)` marks config values that should affect executor step version/hash. - `this_output_path()` resolves to the current step's output root. -- `run_grug(...)` in `base/train.py` is the runtime entry point used by the `ExecutorStep`. +- `run_grug(...)` in each variant's `train.py` is the runtime entry point used by the `ExecutorStep`. - `P` in train/model code is the usual JAX alias for `PartitionSpec`; see the JAX explicit sharding tutorial: [Explicit Sharding (JAX)](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html). ## How to use it @@ -106,6 +106,29 @@ uv run lib/marin/src/marin/run/ray_run.py \ - Keep core training/eval metrics aligned with classic Levanter (`train/loss`, `throughput/*`, `eval/*`). - Prefer shared helpers only for generic infrastructure; keep variant behavior local to the template. +## Variant contract (enforced by tests) + +`tests/test_grug_variant_contracts.py` treats each subdirectory under `experiments/grug/` as a variant and +enforces these minimum interfaces: + +- If `<variant>/model.py` exists, it must define: + - `GrugModelConfig` constructable as `GrugModelConfig(vocab_size=...)` + - `Transformer` with `next_token_loss(...)` + - `debug_mesh_and_token_pspec(num_devices: int)` +- If `<variant>/train.py` exists, it must define: + - `initial_state(model_config, *, optimizer, mp, key)` (all required) + - `_make_train_step(...)` + - `run_grug(...)` +- If both `model.py` and `train.py` exist, the variant must lower a one-step train path under abstract mesh via + `eqx.filter_eval_shape`. +- Escape hatch: add `# GRUG NOVERIFY` anywhere in `<variant>/train.py` to exclude that variant from these contract + checks. 
+ +## Current variants + +- `base`: `experiments/grug/base/` +- `moe`: `experiments/grug/moe/` + ## Further guidance - Grug principles: [`/.agents/projects/grugformer.md`](../../.agents/projects/grugformer.md) diff --git a/experiments/grug/base/model.py b/experiments/grug/base/model.py index aea95c4442..2d3ef04ed0 100644 --- a/experiments/grug/base/model.py +++ b/experiments/grug/base/model.py @@ -194,7 +194,7 @@ def logits( hidden = self(token_ids, mask=mask) return jnp.einsum("bsh,hd->bsd", hidden, self.output_proj, out_sharding=Pbatch) - def compute_next_token_loss( + def next_token_loss( self, token_ids: Int[Array, "B S"], loss_weight: Float[Array, "B S"], @@ -224,6 +224,21 @@ def _init_weight(key: PRNGKeyArray, shape: tuple[int, ...], std: float) -> Float return std * random.truncated_normal(key, -3, 3, shape) +def debug_mesh_and_token_pspec(num_devices: int) -> tuple[jax.sharding.AbstractMesh, P]: + """Return a small abstract mesh and token sharding for lowering contract tests.""" + if num_devices <= 0: + raise ValueError(f"num_devices must be positive, got {num_devices}") + mesh = jax.sharding.AbstractMesh( + axis_sizes=(num_devices, 1), + axis_names=("data", "model"), + axis_types=( + jax.sharding.AxisType.Explicit, + jax.sharding.AxisType.Explicit, + ), + ) + return mesh, P(("data",), None) + + __all__ = [ "MLP", "Block", @@ -231,4 +246,5 @@ def _init_weight(key: PRNGKeyArray, shape: tuple[int, ...], std: float) -> Float "GrugModelConfig", "RMSNorm", "Transformer", + "debug_mesh_and_token_pspec", ] diff --git a/experiments/grug/base/train.py b/experiments/grug/base/train.py index 50dded278d..55385da1b1 100644 --- a/experiments/grug/base/train.py +++ b/experiments/grug/base/train.py @@ -148,7 +148,7 @@ def build_tagged_evaluator( def eval_loss_fn(model: Transformer, batch: LmExample | GrugLmExample) -> tuple[jax.Array, jax.Array, jax.Array]: if isinstance(batch, LmExample): batch = grug_lm_example_from_named(batch) - per_pos_loss = 
model.compute_next_token_loss( + per_pos_loss = model.next_token_loss( batch.tokens, batch.loss_weight, mask=batch.attn_mask, @@ -224,6 +224,22 @@ class GrugTrainState: ema_params: Transformer +def initial_state( + model_config: GrugModelConfig, + *, + optimizer: optax.GradientTransformation, + mp: jmp.Policy, + key: PRNGKeyArray, +) -> GrugTrainState: + params = mp.cast_to_param(Transformer.init(model_config, key=key)) + return GrugTrainState( + step=jnp.array(0, dtype=jnp.int32), + params=params, + opt_state=optimizer.init(params), + ema_params=params, + ) + + def _make_train_step( optimizer: optax.GradientTransformation, mp: jmp.Policy, @@ -246,7 +262,7 @@ def _make_train_step( def train_step(state: GrugTrainState, batch, *, compute_watch: bool = False): def loss_fn(params): compute_params = mp.cast_to_compute(params) - return compute_params.compute_next_token_loss( + return compute_params.next_token_loss( batch.tokens, batch.loss_weight, mask=batch.attn_mask, @@ -339,12 +355,11 @@ def run_grug(config: GrugRunConfig) -> None: @jax.jit def _init_state(model_rng): - params = trainer.mp.cast_to_param(Transformer.init(config.model, key=model_rng)) - return GrugTrainState( - step=jnp.array(0, dtype=jnp.int32), - params=params, - opt_state=optimizer.init(params), - ema_params=params, + return initial_state( + config.model, + optimizer=optimizer, + mp=trainer.mp, + key=model_rng, ) state = _init_state(model_key) @@ -475,5 +490,6 @@ def _init_state(model_rng): "GrugRunConfig", "GrugTrainState", "GrugTrainerConfig", + "initial_state", "run_grug", ] diff --git a/experiments/grug/moe/__init__.py b/experiments/grug/moe/__init__.py new file mode 100644 index 0000000000..fcd1e658ba --- /dev/null +++ b/experiments/grug/moe/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2025 The Marin Authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/experiments/grug/moe/launch.py b/experiments/grug/moe/launch.py new file mode 100644 index 0000000000..5212696cdd --- /dev/null +++ 
b/experiments/grug/moe/launch.py @@ -0,0 +1,180 @@ +# Copyright 2025 The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Template: grug-moe trial run. + +This keeps model, train loop, and launch wiring in `experiments/grug/moe` so +the MoE variant can be iterated independently from the dense base template. +""" + +import dataclasses +import os +from dataclasses import dataclass, field +from datetime import timedelta + +import jmp +from fray.cluster import ResourceConfig +from levanter.callbacks.profiler import ProfilerConfig +from levanter.checkpoint import CheckpointerConfig +from levanter.data.text import LmDataConfig +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.tracker import TrackerConfig +from levanter.tracker.wandb import WandbConfig +from levanter.trainer import TrainerConfig +from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned +from marin.processing.tokenize import add_validation_sets_to_mixture + +from experiments.defaults import default_validation_sets +from experiments.grug.moe.model import GrugModelConfig +from experiments.grug.moe.train import GrugEvalConfig, GrugRunConfig, GrugTrainerConfig, run_grug +from experiments.tootsie.exp1295_32b import nemotron_mix + + +@dataclass(frozen=True) +class GrugMoeLaunchConfig: + """Last-mile run config for the MoE grug template. + + Keep this as the main entry point for day-to-day edits (model/data/optimizer/trainer/eval knobs). + """ + + model: GrugModelConfig + data: LmDataConfig + output_path: str + run_id: str + steps: int + batch_size: int + seed: int + mp: str # jmp policy string, e.g. "params=float32,compute=bfloat16,output=bfloat16". 
+ tracker: TrackerConfig + optimizer: OptimizerConfig + grug_trainer: GrugTrainerConfig = field(default_factory=GrugTrainerConfig) + eval: GrugEvalConfig | None = field(default_factory=GrugEvalConfig) + + +GRUG_MOE_TRIAL_MODEL = GrugModelConfig( + vocab_size=128_256, + hidden_dim=512, + intermediate_dim=1792, + shared_expert_intermediate_dim=1792, + num_experts=8, + num_experts_per_token=2, + num_layers=6, + num_heads=8, + num_kv_heads=8, + max_seq_len=4096, + head_dim=None, +) + +NEMOTRON_MIX_WITH_DEFAULT_VALIDATION = add_validation_sets_to_mixture( + nemotron_mix, + default_validation_sets(tokenizer=nemotron_mix.tokenizer), +) + + +def _resolve_run_id(default_run_id: str) -> str: + """Resolve run id and append `FERRY_DATE` when launching from ferry workflows.""" + run_id = os.environ.get("GRUG_RUN_ID", default_run_id) + ferry_date = os.environ.get("FERRY_DATE") + if ferry_date: + run_id = f"{run_id}-{ferry_date}" + return run_id + + +def _resolve_tracker(tracker: TrackerConfig, run_id: str) -> TrackerConfig: + if isinstance(tracker, WandbConfig): + return dataclasses.replace(tracker, name=run_id) + return tracker + + +def run_grug_moe_trial(config: GrugMoeLaunchConfig) -> None: + # Map template launch knobs onto full Levanter TrainerConfig. 
+ trainer = TrainerConfig( + id=config.run_id, + seed=config.seed, + train_batch_size=config.batch_size, + num_train_steps=config.steps, + profiler=ProfilerConfig(enabled=False, start_step=5, num_steps=100, perfetto_link=False), + mp=jmp.get_policy(config.mp), + tracker=_resolve_tracker(config.tracker, config.run_id), + use_explicit_mesh_axes=True, + require_accelerator=True, + allow_nondivisible_batch_size=False, + checkpointer=CheckpointerConfig( + base_path=os.path.join(config.output_path, "checkpoints"), + append_run_id_to_base_path=False, + save_interval=timedelta(minutes=10), + keep=[{"every": 1000}], + ), + ) + + grug_trainer = dataclasses.replace(config.grug_trainer, trainer=trainer) + + run_config = GrugRunConfig( + model=config.model, + data=config.data, + optimizer=config.optimizer, + trainer=grug_trainer, + eval=config.eval, + ) + run_grug(run_config) + + +RESOLVED_RUN_ID = _resolve_run_id("grug-moe-trial") + + +grug_moe_trial = ExecutorStep( + name="grug/moe-trial", + fn=run_grug_moe_trial, + config=GrugMoeLaunchConfig( + model=versioned(GRUG_MOE_TRIAL_MODEL), + data=NEMOTRON_MIX_WITH_DEFAULT_VALIDATION, + # this_output_path() resolves to this step's output root (e.g. gs://.../grug/moe-trial-). + output_path=this_output_path(), + # Keep run id out of versioning so changing job metadata doesn't create a new output path. 
+ run_id=RESOLVED_RUN_ID, + steps=versioned(2_000), + batch_size=versioned(512), + seed=versioned(0), + mp=versioned("params=float32,compute=bfloat16,output=bfloat16"), + tracker=WandbConfig( + project="marin", + tags=["grug", "template", "moe"], + group="grug-moe-trial", + name=None, # filled from run_id in _resolve_tracker + ), + optimizer=versioned( + AdamConfig( + learning_rate=3e-3, + weight_decay=0.1, + lr_schedule="cosine", + decay=0.2, + min_lr_ratio=0.1, + warmup=1000, + ) + ), + grug_trainer=versioned( + GrugTrainerConfig( + z_loss_weight=1e-4, + ema_beta=None, + log_every=1, + ) + ), + eval=versioned( + GrugEvalConfig( + eval_batch_size=512, + steps_per_eval=200, + max_eval_batches=8, + eval_current=True, + eval_ema=False, + ) + ), + ), + resources=ResourceConfig.with_tpu("v5p-8"), +) + + +if __name__ == "__main__": + executor_main( + steps=[grug_moe_trial], + description="Template grug MoE trial run (~2000 steps) on Nemotron mix.", + ) diff --git a/experiments/grug/moe/model.py b/experiments/grug/moe/model.py new file mode 100644 index 0000000000..09930bda4b --- /dev/null +++ b/experiments/grug/moe/model.py @@ -0,0 +1,510 @@ +# Copyright 2025 The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""MoE grug variant model. + +This variant intentionally mirrors `experiments/grug/base/model.py` and applies +MoE-specific changes inline. Keeping the file largely self-contained follows the +grug copy-first workflow in `docs/recipes/change_grug.md`. 
+""" + +import dataclasses + +from dataclasses import dataclass + +import equinox as eqx +import jax +import jax.numpy as jnp +from einops import rearrange +from haliax.jax_utils import named_call +from jax import random +from jax.sharding import PartitionSpec as P +from jax.sharding import get_abstract_mesh, reshard +from jaxtyping import Array, Float, Int, PRNGKeyArray + +from levanter.grug.attention import AttentionMask, RotaryConfig, apply_rotary_embedding, attention +from levanter.grug.grug_moe import MoeActivation, moe_mlp +from levanter.grug.loss import fused_linear_softmax_cross_entropy_loss +from levanter.grug.sharding import Pvocab, unshard +from levanter.tracker.histogram import Histogram +from levanter.utils.activation import ActivationFunctionEnum + +_DEFAULT_EP_CAPACITY_FACTOR = 1.25 + + +def _mesh_has_axis(mesh: jax.sharding.AbstractMesh | None, axis_name: str) -> bool: + if mesh is None or mesh.empty: + return False + return axis_name in mesh.shape + + +def _mesh_axis_size(mesh: jax.sharding.AbstractMesh | None, axis_name: str) -> int: + if mesh is None or mesh.empty: + return 1 + return int(mesh.shape.get(axis_name, 1)) + + +def _batch_spec(mesh: jax.sharding.AbstractMesh | None) -> P: + if _mesh_has_axis(mesh, "expert"): + return P(("data", "expert")) + return P(("data",)) + + +@dataclass(frozen=True) +class GrugModelConfig: + """Hyperparameters for the compact grug MoE transformer.""" + + vocab_size: int + hidden_dim: int = 2048 + intermediate_dim: int = 5632 + shared_expert_intermediate_dim: int = 5632 + num_experts: int = 8 + num_experts_per_token: int = 2 + num_layers: int = 24 + num_heads: int = 16 + num_kv_heads: int = 16 + head_dim: int | None = None + max_seq_len: int = 4096 + layer_norm_eps: float = 1e-5 + initializer_std: float = 0.02 + rope: RotaryConfig = dataclasses.field(default_factory=RotaryConfig) + + def __post_init__(self) -> None: + _ = self.inferred_head_dim + if self.num_heads % self.num_kv_heads != 0: + raise 
ValueError("num_heads must be divisible by num_kv_heads for grouped-query attention") + if self.vocab_size <= 0: + raise ValueError("vocab_size must be positive") + if self.max_seq_len <= 0: + raise ValueError("max_seq_len must be positive") + if self.num_experts <= 0: + raise ValueError("num_experts must be positive") + if self.num_experts_per_token <= 0: + raise ValueError("num_experts_per_token must be positive") + if self.num_experts_per_token > self.num_experts: + raise ValueError("num_experts_per_token must be <= num_experts") + if self.shared_expert_intermediate_dim < 0: + raise ValueError("shared_expert_intermediate_dim must be non-negative") + + @property + def inferred_head_dim(self) -> int: + if self.head_dim is not None: + return self.head_dim + if self.hidden_dim % self.num_heads != 0: + raise ValueError( + f"hidden_dim={self.hidden_dim} is not divisible by num_heads={self.num_heads}; set head_dim explicitly" + ) + return self.hidden_dim // self.num_heads + + +class CausalSelfAttention(eqx.Module): + w_q: Float[Array, "D NH"] + w_k: Float[Array, "D MH"] + w_v: Float[Array, "D MH"] + w_o: Float[Array, "NH D"] + cfg: GrugModelConfig = eqx.field(static=True) + + @staticmethod + def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "CausalSelfAttention": + k_q, k_k, k_v, k_o = random.split(key, 4) + d, n, m, h = cfg.hidden_dim, cfg.num_heads, cfg.num_kv_heads, cfg.inferred_head_dim + return CausalSelfAttention( + w_q=reshard(_init_weight(k_q, (d, n * h), cfg.initializer_std), P("data", "model")), + w_k=reshard(_init_weight(k_k, (d, m * h), cfg.initializer_std), P("data", "model")), + w_v=reshard(_init_weight(k_v, (d, m * h), cfg.initializer_std), P("data", "model")), + w_o=reshard(_init_weight(k_o, (n * h, d), cfg.initializer_std), P("model", "data")), + cfg=cfg, + ) + + @named_call + def __call__(self, x: Float[Array, "B S D"], mask: AttentionMask | jax.Array) -> Float[Array, "B S D"]: + head_dim = self.cfg.inferred_head_dim + seq_len = x.shape[1] + 
batch_spec = _batch_spec(get_abstract_mesh()) + + q = rearrange(jnp.einsum("bsh,hd->bsd", x, self.w_q), "... (n d) -> ... n d", d=head_dim) + k = rearrange(jnp.einsum("bsh,hd->bsd", x, self.w_k), "... (m d) -> ... m d", d=head_dim) + v = rearrange(jnp.einsum("bsh,hd->bsd", x, self.w_v), "... (m d) -> ... m d", d=head_dim) + q, k = apply_rotary_embedding(q, k, seq_len=seq_len, head_dim=head_dim, rope=self.cfg.rope) + attn_out = attention(q, k, v, mask) + attn_out = rearrange(attn_out, "... n d -> ... (n d)") + return jnp.einsum("bsh,hd->bsd", attn_out, self.w_o, out_sharding=batch_spec) + + +class RMSNorm(eqx.Module): + weight: jax.Array + eps: float = eqx.field(static=True) + + @staticmethod + def init(dim: int, eps: float) -> "RMSNorm": + return RMSNorm(weight=jnp.ones((dim,), dtype=jnp.float32), eps=eps) + + @named_call + def __call__(self, x: Float[Array, "... D"]) -> Float[Array, "... D"]: + weight = unshard(self.weight) + dtype = x.dtype + x = x.astype(jnp.float32) + variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True) + normed = x * jax.lax.rsqrt(variance + self.eps) + return (normed * weight).astype(dtype) + + +def _shared_dense_mlp( + x: Float[Array, "B S D"], + shared_w_up_gate: Float[Array, "D J2"], + shared_w_down: Float[Array, "J D"], + *, + activation: MoeActivation = ActivationFunctionEnum.silu, +) -> Float[Array, "B S D"]: + if isinstance(activation, ActivationFunctionEnum): + activation_fn = activation.to_jax_fn() + else: + activation_fn = activation + + b, s, _ = x.shape + x_flat = rearrange(x, "b s d -> (b s) d") + shared_dim = shared_w_down.shape[0] + shared_up_gate = jnp.einsum("td,dm->tm", x_flat, shared_w_up_gate) + shared_gate, shared_up = jnp.split(shared_up_gate, [shared_dim], axis=-1) + out_flat = jnp.einsum("tm,md->td", activation_fn(shared_gate) * shared_up, shared_w_down) + return rearrange(out_flat, "(b s) d -> b s d", b=b, s=s) + + +def _routing_stats_from_selected_experts( + selected_experts: Int[Array, "T K"], + *, + 
num_experts: int, +) -> dict[str, jax.Array]: + expert_counts = jnp.sum(jax.nn.one_hot(selected_experts, num_experts, dtype=jnp.float32), axis=(0, 1)) + total_assignments = jnp.maximum(jnp.sum(expert_counts), 1.0) + expert_loads = expert_counts / total_assignments + routing_entropy = -jnp.sum(expert_loads * jnp.log(expert_loads + 1e-6)) + return { + "routing_counts": expert_counts, + "routing_entropy": routing_entropy, + } + + +def _summarize_router_metrics(router_metrics: dict[str, jax.Array]) -> dict[str, jax.Array | Histogram]: + routing_entropy = router_metrics["routing_entropy_per_layer"] + routing_counts = router_metrics["routing_counts_per_layer"] + num_layers = int(routing_entropy.shape[0]) + + out: dict[str, jax.Array | Histogram] = { + "train/router/routing_entropy_mean": jnp.mean(routing_entropy), + } + for i in range(num_layers): + out[f"train/router/layer_{i}/routing_entropy"] = routing_entropy[i] + out[f"train/router/layer_{i}/routing_hist"] = _histogram_from_expert_counts(routing_counts[i]) + return out + + +def _histogram_from_expert_counts(expert_counts: jax.Array) -> Histogram: + counts = jnp.asarray(expert_counts, dtype=jnp.float32) + num_experts = counts.shape[0] + expert_ids = jnp.arange(num_experts, dtype=jnp.float32) + num = jnp.sum(counts) + sum_values = jnp.sum(counts * expert_ids) + sum_squares = jnp.sum(counts * expert_ids * expert_ids) + nonzero = counts > 0 + min_value = jnp.where(nonzero, expert_ids, jnp.inf).min() + max_value = jnp.where(nonzero, expert_ids, -jnp.inf).max() + min_value = jnp.where(num > 0, min_value, 0.0) + max_value = jnp.where(num > 0, max_value, 0.0) + bucket_limits = jnp.arange(num_experts + 1, dtype=jnp.float32) + return Histogram( + min=min_value, + max=max_value, + num=num, + sum=sum_values, + sum_squares=sum_squares, + bucket_limits=bucket_limits, + bucket_counts=counts, + ) + + +class MoEMLP(eqx.Module): + router: jax.Array + w_up_gate: jax.Array + w_down: jax.Array + shared_w_up_gate: jax.Array | None + 
shared_w_down: jax.Array | None + cfg: GrugModelConfig = eqx.field(static=True) + + @staticmethod + def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "MoEMLP": + k_router, k_w_up_gate, k_w_down, k_shared_up_gate, k_shared_down = random.split(key, 5) + mesh = get_abstract_mesh() + + expert_axis_size = _mesh_axis_size(mesh, "expert") + if cfg.num_experts % expert_axis_size != 0: + raise ValueError(f"num_experts={cfg.num_experts} must be divisible by expert axis size={expert_axis_size}") + + expert_param_spec = P("expert", None, None) if _mesh_has_axis(mesh, "expert") else P(None, None, None) + + d, e, i, j = ( + cfg.hidden_dim, + cfg.num_experts, + cfg.intermediate_dim, + cfg.shared_expert_intermediate_dim, + ) + + shared_w_up_gate = None + shared_w_down = None + if j > 0: + # Keep shared expert weights replicated in this compact variant. + # This avoids introducing additional sharding/layout complexity + # while we iterate on routed-expert behavior. + shared_w_up_gate = reshard(_init_weight(k_shared_up_gate, (d, 2 * j), cfg.initializer_std), P(None, None)) + shared_w_down = reshard(_init_weight(k_shared_down, (j, d), cfg.initializer_std), P(None, None)) + + return MoEMLP( + router=reshard(_init_weight(k_router, (d, e), cfg.initializer_std), P(None, None)), + w_up_gate=reshard(_init_weight(k_w_up_gate, (e, d, 2 * i), cfg.initializer_std), expert_param_spec), + w_down=reshard(_init_weight(k_w_down, (e, i, d), cfg.initializer_std), expert_param_spec), + shared_w_up_gate=shared_w_up_gate, + shared_w_down=shared_w_down, + cfg=cfg, + ) + + @named_call + def __call__( + self, + x: Float[Array, "B S D"], + *, + return_router_stats: bool = False, + ) -> Float[Array, "B S D"] | tuple[Float[Array, "B S D"], dict[str, jax.Array]]: + b, s, _ = x.shape + x_flat = rearrange(x, "b s d -> (b s) d") + router_logits = jnp.einsum("td,de->te", x_flat, reshard(self.router, P(None, None))) + topk_logits, selected_experts = jax.lax.top_k(router_logits, self.cfg.num_experts_per_token) 
+ combine_weights = jax.nn.softmax(topk_logits, axis=-1).astype(x.dtype) + router_stats = ( + _routing_stats_from_selected_experts(selected_experts.astype(jnp.int32), num_experts=self.cfg.num_experts) + if return_router_stats + else None + ) + + routed_flat = moe_mlp( + x_flat, + selected_experts.astype(jnp.int32), + combine_weights, + self.w_up_gate, + self.w_down, + activation=ActivationFunctionEnum.silu, + mesh=get_abstract_mesh(), + capacity_factor=_DEFAULT_EP_CAPACITY_FACTOR, + ) + routed = rearrange(routed_flat, "(b s) d -> b s d", b=b, s=s) + routed = reshard(routed, _batch_spec(get_abstract_mesh())) + + out = routed + if self.shared_w_up_gate is None: + assert self.shared_w_down is None + else: + assert self.shared_w_down is not None + shared_out = _shared_dense_mlp( + x, + self.shared_w_up_gate, + self.shared_w_down, + activation=ActivationFunctionEnum.silu, + ) + out = routed + shared_out + + if return_router_stats: + assert router_stats is not None + return out, router_stats + return out + + +class Block(eqx.Module): + rms_attn: RMSNorm + attn: CausalSelfAttention + rms_mlp: RMSNorm + mlp: MoEMLP + + @staticmethod + def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "Block": + attn_key, mlp_key = random.split(key, 2) + return Block( + rms_attn=RMSNorm.init(cfg.hidden_dim, cfg.layer_norm_eps), + attn=CausalSelfAttention.init(cfg, key=attn_key), + rms_mlp=RMSNorm.init(cfg.hidden_dim, cfg.layer_norm_eps), + mlp=MoEMLP.init(cfg, key=mlp_key), + ) + + @named_call + def __call__( + self, + x: Float[Array, "B S D"], + mask: AttentionMask | jax.Array, + *, + return_router_stats: bool = False, + ) -> Float[Array, "B S D"] | tuple[Float[Array, "B S D"], dict[str, jax.Array]]: + x = x + self.attn(self.rms_attn(x), mask) + if return_router_stats: + mlp_out, router_stats = self.mlp(self.rms_mlp(x), return_router_stats=True) + x = x + mlp_out + return x, router_stats + + x = x + self.mlp(self.rms_mlp(x)) + return x + + +class Transformer(eqx.Module): + 
token_embed: jax.Array + output_proj: jax.Array + blocks: tuple[Block, ...] + final_norm: RMSNorm + config: GrugModelConfig = eqx.field(static=True) + + @staticmethod + def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "Transformer": + embed_key, out_key, *block_keys = random.split(key, cfg.num_layers + 2) + token_embed = reshard(_init_weight(embed_key, (cfg.vocab_size, cfg.hidden_dim), cfg.initializer_std), Pvocab) + output_proj = reshard(_init_weight(out_key, (cfg.hidden_dim, cfg.vocab_size), cfg.initializer_std), Pvocab) + blocks = tuple(Block.init(cfg, key=layer_key) for layer_key in block_keys) + final_norm = RMSNorm.init(cfg.hidden_dim, cfg.layer_norm_eps) + + return Transformer( + token_embed=token_embed, + output_proj=output_proj, + blocks=blocks, + final_norm=final_norm, + config=cfg, + ) + + @named_call + def __call__( + self, + token_ids: Int[Array, "B S"], + mask: AttentionMask | jax.Array | None = None, + *, + return_router_stats: bool = False, + ) -> Float[Array, "B S D"] | tuple[Float[Array, "B S D"], dict[str, jax.Array]]: + if mask is None: + mask = AttentionMask.causal() + + batch_spec = _batch_spec(get_abstract_mesh()) + hidden = self.token_embed.at[token_ids].get(out_sharding=batch_spec) + if return_router_stats: + all_router_stats: list[dict[str, jax.Array]] = [] + for block in self.blocks: + hidden, router_stats = eqx.filter_checkpoint(block)(hidden, mask, return_router_stats=True) + all_router_stats.append(router_stats) + + router_metrics = { + "routing_entropy_per_layer": jnp.stack([s["routing_entropy"] for s in all_router_stats], axis=0), + "routing_counts_per_layer": jnp.stack([s["routing_counts"] for s in all_router_stats], axis=0), + } + return self.final_norm(hidden), router_metrics + + for block in self.blocks: + hidden = eqx.filter_checkpoint(block)(hidden, mask) + return self.final_norm(hidden) + + @named_call + def logits( + self, + token_ids: Int[Array, "B S"], + mask: AttentionMask | jax.Array | None = None, + ) -> 
Float[Array, "B S V"]: + batch_spec = _batch_spec(get_abstract_mesh()) + hidden = self(token_ids, mask=mask) + return jnp.einsum("bsh,hd->bsd", hidden, self.output_proj, out_sharding=batch_spec) + + def compute_loss( + self, + token_ids: Int[Array, "B S"], + loss_weight: Float[Array, "B S"], + *, + mask: AttentionMask | jax.Array | None = None, + reduction: str = "mean", + logsumexp_weight: float | None = None, + loss_dtype: jnp.dtype = jnp.float32, + return_router_metrics: bool = False, + ) -> jax.Array | tuple[jax.Array, dict[str, jax.Array | Histogram]]: + """Compute next-token cross-entropy loss for a batch.""" + router_metrics: dict[str, jax.Array] | None = None + if return_router_metrics: + hidden, router_metrics = self(token_ids, mask=mask, return_router_stats=True) + else: + hidden = self(token_ids, mask=mask) + labels = jnp.concatenate([token_ids[:, 1:], token_ids[:, :1] * 0], axis=1).astype(jnp.int32) + loss_weight = loss_weight.astype(loss_dtype) + + loss = fused_linear_softmax_cross_entropy_loss( + hidden, + self.output_proj, + labels, + weight=loss_weight, + reduction=reduction, + logsumexp_weight=logsumexp_weight, + dtype=loss_dtype, + ) + if return_router_metrics: + assert router_metrics is not None + return loss, _summarize_router_metrics(router_metrics) + return loss + + def next_token_loss( + self, + token_ids: Int[Array, "B S"], + loss_weight: Float[Array, "B S"], + *, + mask: AttentionMask | jax.Array | None = None, + reduction: str = "mean", + logsumexp_weight: float | None = None, + loss_dtype: jnp.dtype = jnp.float32, + return_router_metrics: bool = False, + ) -> jax.Array | tuple[jax.Array, dict[str, jax.Array | Histogram]]: + return self.compute_loss( + token_ids, + loss_weight, + mask=mask, + reduction=reduction, + logsumexp_weight=logsumexp_weight, + loss_dtype=loss_dtype, + return_router_metrics=return_router_metrics, + ) + + +def _init_weight(key: PRNGKeyArray, shape: tuple[int, ...], std: float) -> Float[Array, "..."]: + return std * 
random.truncated_normal(key, -3, 3, shape) + + +def debug_mesh_and_token_pspec(num_devices: int) -> tuple[jax.sharding.AbstractMesh, P]: + """Return a small abstract mesh and token sharding for lowering contract tests.""" + if num_devices <= 0: + raise ValueError(f"num_devices must be positive, got {num_devices}") + # Keep expert axis at 2 when possible to exercise EP lowering, otherwise + # fall back to expert=1. + expert = 2 if num_devices % 2 == 0 else 1 + data = max(1, num_devices // expert) + mesh = jax.sharding.AbstractMesh( + axis_sizes=(data, expert, 1), + axis_names=("data", "expert", "model"), + axis_types=( + jax.sharding.AxisType.Explicit, + jax.sharding.AxisType.Explicit, + jax.sharding.AxisType.Explicit, + ), + ) + return mesh, P(("data", "expert"), None) + + +GrugMoeModelConfig = GrugModelConfig + + +__all__ = [ + "Block", + "CausalSelfAttention", + "GrugModelConfig", + "GrugMoeModelConfig", + "MoEMLP", + "MoeActivation", + "RMSNorm", + "Transformer", + "debug_mesh_and_token_pspec", + "moe_mlp", +] diff --git a/experiments/grug/moe/train.py b/experiments/grug/moe/train.py new file mode 100644 index 0000000000..65bef58c03 --- /dev/null +++ b/experiments/grug/moe/train.py @@ -0,0 +1,507 @@ +# Copyright 2025 The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import dataclasses +import functools +import logging +import time +from dataclasses import dataclass, field + +import jax +import jax.numpy as jnp +import jmp +import optax +from haliax import Axis +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P +from jax.tree_util import register_dataclass +from jaxtyping import PRNGKeyArray + +import levanter.callbacks as callbacks +import levanter.tracker +from levanter.callbacks.state_adapter import StateCallbackRunner +from levanter.callbacks.watch import WatchConfig, compute_watch_stats +from levanter.checkpoint import load_checkpoint +from levanter.data import 
AsyncDataset, DataLoader +from levanter.data.mixture import MixtureDataset, rescale_mixture_schedule_for_batch_schedule +from levanter.data.text import GrugLmExample, LmDataConfig +from levanter.data.text.examples import grug_lm_example_from_named +from levanter.eval import TaggedEvaluator, cb_tagged_evaluate +from levanter.models.lm_model import LmExample +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.schedule import BatchSchedule +from levanter.trainer import TrainerConfig +from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.jax_utils import parameter_count +from levanter.utils.logging import LoadingTimeTrackerIterator + +from experiments.grug.moe.model import GrugModelConfig, Transformer + +# This file intentionally mirrors `experiments/grug/base/train.py` with +# variant-specific model/loss/FLOP wiring, per the grug copy-first workflow in +# `docs/recipes/change_grug.md`. + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class GrugTrainerConfig: + """Runtime knobs for grug training.""" + + trainer: TrainerConfig = field(default_factory=lambda: TrainerConfig(use_explicit_mesh_axes=True)) + train_batch_pspec: P = field(default_factory=lambda: P(("data",))) + data_seed: int | None = None + log_every: int = 1 + ema_beta: float | None = None # EMA coefficient for eval/checkpoint model; None disables EMA. + z_loss_weight: float = 0.0 # Weight on logsumexp (z-loss) stabilization term. 
+ + +@dataclass(frozen=True) +class GrugEvalConfig: + """Perplexity eval settings for grug training.""" + + eval_batch_size: int = 512 + eval_batch_pspec: P = field(default_factory=lambda: P(("data",))) + steps_per_eval: int | None = 1000 + max_eval_batches: int | None = None + prefix: str = "eval" + eval_current: bool = True + eval_ema: bool = True + compute_bpb: bool = True + + +@dataclass(frozen=True) +class GrugRunConfig: + """Top-level config for grug training.""" + + model: GrugModelConfig + data: LmDataConfig + optimizer: OptimizerConfig = field(default_factory=AdamConfig) + trainer: GrugTrainerConfig = field(default_factory=GrugTrainerConfig) + eval: GrugEvalConfig | None = field(default_factory=GrugEvalConfig) + + +def build_train_dataset( + data_config: LmDataConfig, + *, + max_seq_len: int, + batch_schedule: BatchSchedule, + key: PRNGKeyArray, +) -> MixtureDataset[GrugLmExample]: + pos = Axis("position", max_seq_len) + mix_key, shuffle_key = jax.random.split(key) + weights = data_config.train_weights + if isinstance(weights, list): + weights = rescale_mixture_schedule_for_batch_schedule(weights, batch_schedule) + + initial_batch_size = batch_schedule.batch_size_at_step(0) + datasets = data_config.train_sets(pos, key=shuffle_key, initial_batch_size=initial_batch_size) + return MixtureDataset( + datasets=datasets, + weights=weights, + stop_strategy=data_config.stop_strategy, + key=mix_key, + block_size=data_config.mixture_block_size, + ) + + +def build_train_loader( + dataset: AsyncDataset[GrugLmExample], + *, + batch_schedule: BatchSchedule, + mesh: Mesh, + batch_pspec: P = P(("data",)), +) -> DataLoader[GrugLmExample]: + # DataLoader uses this batch axis mapping to shard batches across the distributed mesh. 
+ axis_resource = batch_pspec[0] + return DataLoader( + dataset, + batch_schedule.schedule, + mesh=mesh, + axis_resources={"__BATCH__": axis_resource}, + batch_axis_name="__BATCH__", + allow_nondivisible_batch_size=False, + ) + + +def build_tagged_evaluator( + *, + data_config: LmDataConfig, + max_seq_len: int, + mesh: Mesh, + eval_cfg: GrugEvalConfig, +) -> TaggedEvaluator[LmExample | GrugLmExample, Transformer] | None: + pos = Axis("position", max_seq_len) + tagged_eval_sets = data_config.tagged_eval_sets(pos) + if len(tagged_eval_sets) == 0: + logger.warning("No evaluation datasets provided.") + return None + + max_examples_per_dataset = None + if eval_cfg.max_eval_batches is not None: + max_examples_per_dataset = eval_cfg.max_eval_batches * eval_cfg.eval_batch_size + + tokenizer = data_config.the_tokenizer if eval_cfg.compute_bpb else None + batch_axis_resource = eval_cfg.eval_batch_pspec[0] + eval_axis_mapping = {"batch": batch_axis_resource} + eval_batch = Axis("batch", eval_cfg.eval_batch_size) + eval_array_sharding = NamedSharding(mesh, P(batch_axis_resource, None)) + + def eval_loss_fn(model: Transformer, batch: LmExample | GrugLmExample) -> tuple[jax.Array, jax.Array, jax.Array]: + if isinstance(batch, LmExample): + batch = grug_lm_example_from_named(batch) + per_pos_loss = model.next_token_loss( + batch.tokens, + batch.loss_weight, + mask=batch.attn_mask, + reduction="none", + logsumexp_weight=None, + ) + per_pos_loss = jax.sharding.reshard(per_pos_loss, eval_array_sharding) + per_pos_weight = jax.sharding.reshard(batch.loss_weight, eval_array_sharding) + per_pos_token_id = jnp.roll(batch.tokens, -1, axis=-1) + return per_pos_loss, per_pos_weight, per_pos_token_id + + return TaggedEvaluator( + EvalBatch=eval_batch, + tagged_eval_sets=tagged_eval_sets, + loss_fn=eval_loss_fn, + tokenizer=tokenizer, + device_mesh=mesh, + axis_mapping=eval_axis_mapping, + max_examples_per_dataset=max_examples_per_dataset, + ) + + +def _compute_flops( + *, + model_config: 
GrugModelConfig, +) -> tuple[float, dict[str, float]]: + flops_per_token = lm_flops_per_token( + hidden_dim=model_config.hidden_dim, + intermediate_dim=model_config.intermediate_dim, + num_layers=model_config.num_layers, + num_kv_heads=model_config.num_kv_heads, + num_heads=model_config.num_heads, + seq_len=model_config.max_seq_len, + vocab_size=model_config.vocab_size, + glu=True, + num_experts=model_config.num_experts, + num_shared_experts=1 if model_config.shared_expert_intermediate_dim > 0 else 0, + num_experts_per_tok=model_config.num_experts_per_token, + ) + flops_per_example = 3 * flops_per_token * model_config.max_seq_len + + flops_summary: dict[str, float] = { + "throughput/flops_per_token_analytic": flops_per_token, + "throughput/flops_per_example_analytic": flops_per_example, + } + + return flops_per_example, flops_summary + + +def _make_mixture_stage_callback(train_dataset: MixtureDataset, batch_schedule: BatchSchedule): + last_mixture_stage = -1 + + def log_mixture_stage(step_info): + nonlocal last_mixture_stage + seq_index = batch_schedule.global_data_offset_by_step(step_info.step) + block_id = seq_index // train_dataset.block_size + stage = train_dataset._get_stage_for_block(block_id) + if stage == last_mixture_stage: + return + + weights = train_dataset.weight_stages[stage][1] + mixture_log = {f"mixture/weight/{name}": weight for name, weight in weights.items()} + mixture_log["mixture/stage"] = stage + levanter.tracker.log(mixture_log, step=step_info.step) + last_mixture_stage = stage + + return log_mixture_stage + + +@register_dataclass +@dataclass(frozen=True) +class GrugTrainState: + step: jax.Array + params: Transformer + opt_state: optax.OptState + ema_params: Transformer + + +def initial_state( + model_config: GrugModelConfig, + *, + optimizer: optax.GradientTransformation, + mp: jmp.Policy, + key: PRNGKeyArray, +) -> GrugTrainState: + params = mp.cast_to_param(Transformer.init(model_config, key=key)) + return GrugTrainState( + 
step=jnp.array(0, dtype=jnp.int32), + params=params, + opt_state=optimizer.init(params), + ema_params=params, + ) + + +def _make_train_step( + optimizer: optax.GradientTransformation, + mp: jmp.Policy, + *, + z_loss_weight: float, + ema_beta: float | None, + watch_config: WatchConfig | None = None, +): + one = jnp.array(1, dtype=jnp.int32) + z_loss = z_loss_weight if z_loss_weight > 0 else None + if watch_config is not None: + if isinstance(watch_config.watch_targets, str): + watch_targets = tuple(t.strip() for t in watch_config.watch_targets.split(",")) + else: + watch_targets = tuple(watch_config.watch_targets) + else: + watch_targets = () + + @functools.partial(jax.jit, donate_argnums=(0,), static_argnames=("compute_watch",)) + def train_step(state: GrugTrainState, batch, *, compute_watch: bool = False): + def loss_fn(params): + compute_params = mp.cast_to_compute(params) + return compute_params.next_token_loss( + batch.tokens, + batch.loss_weight, + mask=batch.attn_mask, + reduction="mean", + logsumexp_weight=z_loss, + return_router_metrics=True, + ) + + (loss, summarized_metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params) + metrics = {"train/loss": loss, **summarized_metrics} + updates, opt_state = optimizer.update(grads, state.opt_state, state.params) + params = optax.apply_updates(state.params, updates) + + if ema_beta is None: + ema_params = params + else: + ema_params = jax.tree_util.tree_map( + lambda old, new: ema_beta * old + (1.0 - ema_beta) * new, + state.ema_params, + params, + ) + + watch_stats = None + if watch_config is not None and compute_watch: + watch_stats = compute_watch_stats( + watch_targets=watch_targets, + include_norms=watch_config.include_norms, + include_per_parameter_norms=watch_config.include_per_parameter_norms, + include_histogram=watch_config.include_histograms, + split_scan_layers=watch_config.split_scan_layers, + params=state.params, + grads=grads, + updates=updates, + opt_state=state.opt_state, + 
model_tree_type=type(state.params), + ) + + next_state = dataclasses.replace( + state, + step=state.step + one, + params=params, + opt_state=opt_state, + ema_params=ema_params, + ) + + return next_state, metrics, watch_stats + + return train_step + + +def run_grug(config: GrugRunConfig) -> None: + """Entry point for the grug template training loop.""" + trainer = config.trainer.trainer + trainer.initialize() + levanter.tracker.log_configuration(config) + + run_id = trainer.id + if run_id is None: + raise ValueError("trainer.id was not initialized") + + optimizer = config.optimizer.build(trainer.num_train_steps) + watch_config = trainer.watch + train_step = _make_train_step( + optimizer, + trainer.mp, + z_loss_weight=config.trainer.z_loss_weight, + ema_beta=config.trainer.ema_beta, + watch_config=watch_config if watch_config.is_enabled else None, + ) + + data_key, model_key = jax.random.split(jax.random.PRNGKey(trainer.seed), 2) + if config.trainer.data_seed is not None: + data_key = jax.random.PRNGKey(config.trainer.data_seed) + + # Build data/model state under the trainer mesh so all arrays are sharded consistently. 
+ with trainer.use_device_mesh(): + mesh = trainer.device_mesh + batch_schedule = trainer.batch_schedule + + train_dataset = build_train_dataset( + config.data, + max_seq_len=config.model.max_seq_len, + batch_schedule=batch_schedule, + key=data_key, + ) + train_loader = build_train_loader( + train_dataset, + batch_schedule=batch_schedule, + mesh=mesh, + batch_pspec=config.trainer.train_batch_pspec, + ) + + @jax.jit + def _init_state(model_rng): + return initial_state( + config.model, + optimizer=optimizer, + mp=trainer.mp, + key=model_rng, + ) + + state = _init_state(model_key) + + checkpointer = trainer.checkpointer.create(run_id) + checkpoint_path = trainer.load_checkpoint_path + if checkpoint_path is None and checkpointer is not None: + checkpoint_path = trainer.checkpointer.expanded_path(run_id) + if checkpoint_path is None: + if trainer.load_checkpoint: + raise FileNotFoundError("load_checkpoint=True but no checkpoint path is configured.") + elif trainer.load_checkpoint is not False: + try: + state = load_checkpoint( + state, + checkpoint_path, + discover_latest=True, + axis_mapping=None, + mesh=mesh, + allow_partial=trainer.allow_partial_checkpoint, + ) + except FileNotFoundError: + if trainer.load_checkpoint is True: + raise + logger.info(f"Checkpoint not found at {checkpoint_path}. 
Starting from scratch.") + + levanter.tracker.log_summary({"parameter_count": parameter_count(state.params)}) + + flops_per_example, flops_summary = _compute_flops(model_config=config.model) + levanter.tracker.log_summary(flops_summary) + + eval_cfg = config.eval + evaluator = None + if eval_cfg is not None: + evaluator = build_tagged_evaluator( + data_config=config.data, + max_seq_len=config.model.max_seq_len, + mesh=mesh, + eval_cfg=eval_cfg, + ) + + profiler_cfg = trainer.profiler + profiler_num_steps = profiler_cfg.resolve_num_profile_steps(num_train_steps=trainer.num_train_steps) + profiler_enabled = profiler_cfg.is_enabled and profiler_num_steps > 0 + + log_every = max(1, config.trainer.log_every) + iterator = LoadingTimeTrackerIterator(train_loader.iter_from_step(int(state.step))) + + state_callbacks = StateCallbackRunner[GrugTrainState]( + step_getter=lambda s: s.step, + model_getter=lambda s: s.params, + eval_model_getter=lambda s: s.ema_params, + opt_state_getter=lambda s: s.opt_state, + ) + state_callbacks.add_hook( + callbacks.log_performance_stats(config.model.max_seq_len, batch_schedule, flops_per_example), + every=log_every, + ) + state_callbacks.add_hook(callbacks.pbar_logger(total=trainer.num_train_steps), every=log_every) + state_callbacks.add_hook(callbacks.log_step_info(trainer.num_train_steps), every=log_every) + if profiler_enabled: + state_callbacks.add_hook( + callbacks.profile( + str(trainer.log_dir / run_id / "profiler"), + profiler_cfg.start_step, + profiler_num_steps, + profiler_cfg.perfetto_link, + ), + every=1, + ) + state_callbacks.add_hook(_make_mixture_stage_callback(train_dataset, batch_schedule), every=1) + if evaluator is not None and eval_cfg is not None: + interval = eval_cfg.steps_per_eval + eval_ema = eval_cfg.eval_ema and config.trainer.ema_beta is not None + if interval is not None and interval > 0 and (eval_cfg.eval_current or eval_ema): + state_callbacks.add_hook( + cb_tagged_evaluate( + evaluator, + 
prefix=eval_cfg.prefix, + eval_current=eval_cfg.eval_current, + eval_ema=eval_ema, + ), + every=interval, + ) + + last_loss: float | jax.Array = 0.0 + last_step_duration = 0.0 + + # Main optimization loop. + try: + while int(state.step) < trainer.num_train_steps: + batch = next(iterator) + step_start = time.perf_counter() + current_step = int(state.step) + # grad_watch runs only on its configured interval. + compute_watch = ( + watch_config.is_enabled and watch_config.interval > 0 and current_step % watch_config.interval == 0 + ) + state, metrics, watch_stats = train_step(state, batch, compute_watch=compute_watch) + step = int(state.step) - 1 + + jax.block_until_ready(metrics["train/loss"]) + duration = time.perf_counter() - step_start + hook_start = time.perf_counter() + state_callbacks.run(state, loss=metrics["train/loss"], step_duration=duration) + last_loss = metrics["train/loss"] + last_step_duration = duration + levanter.tracker.log({"throughput/hook_time": time.perf_counter() - hook_start}, step=step) + levanter.tracker.log({"throughput/loading_time": iterator.this_load_time}, step=step) + router_metrics = {key: value for key, value in metrics.items() if key.startswith("train/router/")} + if router_metrics: + levanter.tracker.log(router_metrics, step=step) + + if watch_stats is not None: + levanter.tracker.log(watch_stats, step=step) + + if checkpointer is not None: + checkpointer.on_step(tree={"train_state": state}, step=int(state.step)) + finally: + # Mirror classic trainer behavior: force callbacks on the last completed step. 
+ state_callbacks.run(state, loss=last_loss, step_duration=last_step_duration, force=True) + if checkpointer is not None: + checkpointer.on_step(tree={"train_state": state}, step=int(state.step), force=True) + checkpointer.wait_until_finished() + + levanter.tracker.current_tracker().finish() + + +__all__ = [ + "GrugEvalConfig", + "GrugRunConfig", + "GrugTrainState", + "GrugTrainerConfig", + "initial_state", + "run_grug", +] diff --git a/lib/levanter/src/levanter/grug/grug_moe.py b/lib/levanter/src/levanter/grug/grug_moe.py new file mode 100644 index 0000000000..9d3d3069a3 --- /dev/null +++ b/lib/levanter/src/levanter/grug/grug_moe.py @@ -0,0 +1,325 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Canonical compact Grug MoE kernels. + +Implementation overview: +- Routing keeps the argsort-grouped dispatch path that emerged as the stable + default from https://github.com/marin-community/marin/issues/2704 and commit + 89318a910 (and its parent). +- Expert parallelism keeps the ring-style strategy from + https://github.com/marin-community/marin/issues/2710: token-sharded + `all_gather` for dispatch, then `psum_scatter` for collection. +- This module intentionally provides functional kernels only; model/module + wiring lives in the Grug model files. +""" + +import math + +from collections.abc import Callable +from functools import partial +from typing import TypeAlias + +import jax +import jax.numpy as jnp +from haliax.jax_utils import named_call +from jax import shard_map +from jax.sharding import PartitionSpec as P, get_abstract_mesh +from jaxtyping import Array, Float, Int + +from haliax.nn.linear import gmm_sharded +from levanter.utils.activation import ActivationFunctionEnum + +_DEFAULT_EP_CAPACITY_FACTOR = 1.25 +# #2710 used 1.25 as the practical EP ring default to avoid over/under-packing. 
+ +MoeActivation: TypeAlias = ActivationFunctionEnum | Callable[[jax.Array], jax.Array] + + +def _mesh_has_axis(mesh: jax.sharding.AbstractMesh | None, axis_name: str) -> bool: + if mesh is None or mesh.empty: + return False + return axis_name in mesh.shape + + +def _mesh_axis_size(mesh: jax.sharding.AbstractMesh | None, axis_name: str) -> int: + if mesh is None or mesh.empty: + return 1 + return int(mesh.shape.get(axis_name, 1)) + + +def _batch_spec(mesh: jax.sharding.AbstractMesh | None) -> P: + if _mesh_has_axis(mesh, "expert"): + return P(("data", "expert")) + return P(("data",)) + + +def _prepare_moe_dispatch( + x: Float[Array, "T D"], + selected_experts: Int[Array, "T K"], + combine_weights: Float[Array, "T K"], + *, + num_experts: int, +) -> tuple[ + Float[Array, "TK D"], + Float[Array, "TK"], + Int[Array, "TK"], + Int[Array, "E"], +]: + """Flatten + argsort by expert into grouped layout for GMM.""" + # #2704: keep argsort-grouped dispatch as the canonical compact routing + # strategy, matching the behavior carried forward from 89318a910. 
+ tokens, topk = selected_experts.shape + expert_ids = selected_experts.reshape(tokens * topk) + dispatch_weights = combine_weights.reshape(tokens * topk) + + sort_idx = jnp.argsort(expert_ids, axis=0) + token_ids = jnp.arange(tokens * topk, dtype=jnp.int32) // topk + token_ids_sort = token_ids[sort_idx] + x_sort = x[token_ids_sort] + w_sort = dispatch_weights[sort_idx].astype(x.dtype) + group_sizes = jnp.bincount(expert_ids, length=num_experts).astype(jnp.int32) + return x_sort, w_sort, token_ids_sort, group_sizes + + +def _moe_mlp_local( + x: Float[Array, "T D"], + selected_experts: Int[Array, "T K"], + combine_weights: Float[Array, "T K"], + moe_w13: Float[Array, "E D I2"], + moe_w2: Float[Array, "E I D"], + *, + activation_fn: Callable[[jax.Array], jax.Array], + num_experts: int, +) -> tuple[Float[Array, "T D"], Int[Array, ""]]: + """Per-shard non-EP MoE FFN path with argsort routing + grouped matmul.""" + x_dispatch, w_dispatch, token_dispatch, group_sizes = _prepare_moe_dispatch( + x, + selected_experts, + combine_weights, + num_experts=num_experts, + ) + + w13_out = gmm_sharded(x_dispatch, moe_w13, group_sizes) + moe_dim = moe_w2.shape[1] + gate, up = jnp.split(w13_out, [moe_dim], axis=-1) + out_dispatch = gmm_sharded(activation_fn(gate) * up, moe_w2, group_sizes) + + out = jnp.zeros_like(x).at[token_dispatch].add(out_dispatch * w_dispatch[:, None], mode="drop") + return out, jnp.array(0, dtype=jnp.int32) + + +def _batch_spec_from_x(x: jax.Array, mesh: jax.sharding.AbstractMesh | None) -> P: + sharding = getattr(x, "sharding", None) + spec = getattr(sharding, "spec", None) + if spec is not None and len(spec) > 0: + return P(spec[0]) + return _batch_spec(mesh) + + +def _moe_mlp_ep_ring_local( + x_local: Float[Array, "TL D"], + selected_experts_local: Int[Array, "TL K"], + combine_weights_local: Float[Array, "TL K"], + moe_w13_local: Float[Array, "EL D I2"], + moe_w2_local: Float[Array, "EL I D"], + *, + activation_fn: Callable[[jax.Array], jax.Array], + 
num_experts: int, + capacity_factor: float, +) -> tuple[Float[Array, "TL D"], Int[Array, ""]]: + """Ring-style EP routed path: all-gather dispatch + psum-scatter collect.""" + # #2710 ring EP strategy: gather tokens and their selected-expert routing + # assignments across expert shards, then psum-scatter back to local tokens. + x_global = jax.lax.all_gather(x_local, "expert", tiled=True) + selected_experts_global = jax.lax.all_gather(selected_experts_local, "expert", tiled=True) + combine_weights_global = jax.lax.all_gather(combine_weights_local, "expert", tiled=True) + + tokens = x_global.shape[0] + topk = selected_experts_global.shape[1] + assignments = tokens * topk + expert_flat = selected_experts_global.reshape(assignments) + weight_flat = combine_weights_global.reshape(assignments) + token_flat = jnp.arange(assignments, dtype=jnp.int32) // topk + + sort_idx = jnp.argsort(expert_flat, axis=0) + expert_sorted = jnp.take(expert_flat, sort_idx, axis=0) + token_sorted = jnp.take(token_flat, sort_idx, axis=0) + weight_sorted = jnp.take(weight_flat, sort_idx, axis=0).astype(x_local.dtype) + + local_experts = moe_w13_local.shape[0] + if num_experts % local_experts != 0: + raise ValueError( + f"num_experts={num_experts} must be divisible by local expert count={local_experts} in EP mode" + ) + + ep_size = num_experts // local_experts + local_capacity = int(math.ceil(capacity_factor * assignments / ep_size)) + local_capacity = max(local_experts, local_capacity) + + expert_axis = jax.lax.axis_index("expert") + expert_start = expert_axis * local_experts + expert_end = expert_start + local_experts + local_mask = jnp.logical_and(expert_sorted >= expert_start, expert_sorted < expert_end) + + local_idx = jnp.nonzero(local_mask, size=local_capacity, fill_value=0)[0] + local_count = jnp.sum(local_mask, dtype=jnp.int32) + dropped_local = jnp.maximum(local_count - local_capacity, 0) + valid = jnp.arange(local_capacity, dtype=jnp.int32) < local_count + valid_weight = 
valid.astype(jnp.float32) + + token_local = jnp.take(token_sorted, local_idx, axis=0) + expert_local = jnp.take(expert_sorted, local_idx, axis=0) - expert_start + weight_local = jnp.take(weight_sorted, local_idx, axis=0) + + x_take = jnp.take(x_global, token_local, axis=0) + x_dispatch = jnp.where(valid[:, None], x_take, jnp.zeros_like(x_take)) + weight_dispatch = jnp.where(valid, weight_local, jnp.zeros_like(weight_local)) + expert_local = jnp.where(valid, expert_local, 0) + + group_sizes = jnp.bincount(expert_local, weights=valid_weight, length=local_experts).astype(jnp.int32) + # `local_idx` pads by appending invalid rows at the end; keep GMM segment + # boundaries aligned by attributing padding to the final expert segment. + group_sizes = group_sizes.at[-1].add(local_capacity - jnp.sum(group_sizes, dtype=jnp.int32)) + + w13_out = gmm_sharded(x_dispatch, moe_w13_local, group_sizes) + moe_dim = moe_w2_local.shape[1] + gate, up = jnp.split(w13_out, [moe_dim], axis=-1) + out_dispatch = gmm_sharded(activation_fn(gate) * up, moe_w2_local, group_sizes) + + out_global = jnp.zeros_like(x_global).at[token_local].add(out_dispatch * weight_dispatch[:, None], mode="drop") + # #2710 ring EP strategy: collect only this shard's token slice after + # reducing contributions from experts across the EP mesh. 
+ out_local = jax.lax.psum_scatter(out_global, "expert", scatter_dimension=0, tiled=True) + dropped_total = jax.lax.psum(dropped_local, ("data", "expert")) + return out_local, dropped_total + + +@named_call +def moe_mlp( + x: Float[Array, "T D"], + selected_experts: Int[Array, "T K"], + combine_weights: Float[Array, "T K"], + w_up_gate: Float[Array, "E D I2"], + w_down: Float[Array, "E I D"], + *, + activation: MoeActivation = ActivationFunctionEnum.silu, + mesh: jax.sharding.AbstractMesh | None = None, + capacity_factor: float = _DEFAULT_EP_CAPACITY_FACTOR, + report_capacity_overflow: bool = False, +) -> Float[Array, "T D"] | tuple[Float[Array, "T D"], Int[Array, ""]]: + """Functional routed MoE MLP core used by Grug modules and benchmarks. + + This helper handles dispatch/permute/unpermute (+EP collectives) from + precomputed token-to-expert assignments. Routing logits/top-k selection + stays in the caller (e.g. model MLP block). + + Set `report_capacity_overflow=True` to also return a scalar count of + dropped expert assignments from EP capacity clipping. 
+ """ + if mesh is None: + mesh = get_abstract_mesh() + + if isinstance(activation, ActivationFunctionEnum): + activation_fn = activation.to_jax_fn() + else: + activation_fn = activation + + if x.ndim != 2: + raise ValueError(f"x must be rank-2 [T, D], got shape={x.shape}") + if selected_experts.ndim != 2: + raise ValueError(f"selected_experts must be rank-2 [T, K], got shape={selected_experts.shape}") + if selected_experts.shape != combine_weights.shape: + raise ValueError( + "selected_experts and combine_weights must have identical [T, K] shapes; " + f"got {selected_experts.shape} vs {combine_weights.shape}" + ) + if selected_experts.shape[0] != x.shape[0]: + raise ValueError( + f"selected_experts/combine_weights token dim ({selected_experts.shape[0]}) must match x token " + f"dim ({x.shape[0]})" + ) + + num_experts = int(w_up_gate.shape[0]) + if w_down.shape[0] != num_experts: + raise ValueError( + f"w_down expert dimension ({w_down.shape[0]}) must match w_up_gate expert dimension ({num_experts})" + ) + + has_expert_axis = _mesh_has_axis(mesh, "expert") + expert_axis_size = _mesh_axis_size(mesh, "expert") + + if mesh is None or mesh.empty: + out, dropped = _moe_mlp_local( + x, + selected_experts, + combine_weights, + w_up_gate, + w_down, + activation_fn=activation_fn, + num_experts=num_experts, + ) + if report_capacity_overflow: + return out, dropped + return out + + batch_spec = _batch_spec_from_x(x, mesh) + local_expert_spec = P("expert", None, None) if has_expert_axis else P(None, None, None) + + if has_expert_axis and expert_axis_size > 1: + if num_experts % expert_axis_size != 0: + raise ValueError(f"num_experts={num_experts} must be divisible by expert axis size={expert_axis_size}") + + # #2710: prefer ring EP collectives when a real expert mesh is present. 
+ shard_fn = shard_map( + partial( + _moe_mlp_ep_ring_local, + activation_fn=activation_fn, + num_experts=num_experts, + capacity_factor=capacity_factor, + ), + mesh=mesh, + in_specs=( + batch_spec, + batch_spec, + batch_spec, + P("expert", None, None), + P("expert", None, None), + ), + out_specs=(batch_spec, P()), + check_vma=False, + ) + out, dropped = shard_fn(x, selected_experts, combine_weights, w_up_gate, w_down) + if report_capacity_overflow: + return out, dropped + return out + + # Fallback path for no expert axis (or expert axis size 1) keeps routing + # semantics without EP collectives. + shard_fn = shard_map( + partial( + _moe_mlp_local, + activation_fn=activation_fn, + num_experts=num_experts, + ), + mesh=mesh, + in_specs=( + batch_spec, + batch_spec, + batch_spec, + local_expert_spec, + local_expert_spec, + ), + out_specs=(batch_spec, P()), + check_vma=False, + ) + out, dropped = shard_fn(x, selected_experts, combine_weights, w_up_gate, w_down) + if report_capacity_overflow: + return out, dropped + return out + + +__all__ = [ + "MoeActivation", + "moe_mlp", +] diff --git a/lib/levanter/src/levanter/grug/loss.py b/lib/levanter/src/levanter/grug/loss.py index 56705919ce..903f92b5d8 100644 --- a/lib/levanter/src/levanter/grug/loss.py +++ b/lib/levanter/src/levanter/grug/loss.py @@ -17,6 +17,35 @@ ) +def _batch_axis_spec(x: jax.Array): + x_type = jax.typeof(x) + sharding = getattr(x_type, "sharding", None) + spec = getattr(sharding, "spec", None) + if spec is not None and len(spec) > 0: + return spec[0] + sharding = getattr(x, "sharding", None) + spec = getattr(sharding, "spec", None) + if spec is not None and len(spec) > 0: + return spec[0] + return ("data",) + + +def _axis_names_from_spec(axis_spec) -> tuple[str, ...]: + if axis_spec is None: + return () + if isinstance(axis_spec, tuple): + return tuple(str(name) for name in axis_spec) + return (str(axis_spec),) + + +def _psum_over_axes(x: jax.Array, axis_names: tuple[str, ...]) -> jax.Array: + if 
len(axis_names) == 0: + return x + if len(axis_names) == 1: + return jax.lax.psum(x, axis_names[0]) + return jax.lax.psum(x, axis_names) + + @named_call def fused_linear_softmax_cross_entropy_loss( hidden: jax.Array, @@ -60,6 +89,8 @@ def fused_linear_softmax_cross_entropy_loss( raise ValueError(f"Unknown reduction: {reduction}") weight_array = weight if weight is not None else jnp.ones_like(labels, dtype=dtype) + batch_axis_spec = _batch_axis_spec(hidden) + batch_axis_names = _axis_names_from_spec(batch_axis_spec) def _loss_shard( shard_hidden: jax.Array, @@ -67,12 +98,9 @@ def _loss_shard( shard_labels: jax.Array, shard_weight: jax.Array, ) -> jax.Array: - print(f"hid sharding: {jax.typeof(shard_hidden)}") flat_hidden = shard_hidden.reshape((-1, hidden_dim)) flat_labels = shard_labels.reshape((-1,)).astype(jnp.int32) flat_weight = shard_weight.reshape((-1,)) - print(f"flat sharding: {jax.typeof(flat_hidden)}") - print(flat_hidden.shape, flat_labels.shape, flat_weight.shape) loss = fused_cross_entropy_loss_and_logsumexp_penalty( flat_hidden, @@ -93,16 +121,16 @@ def _loss_shard( local_sum = jnp.sum(loss) local_denom = jnp.sum(flat_weight) - total_sum = jax.lax.psum(local_sum, "data") + total_sum = _psum_over_axes(local_sum, batch_axis_names) if reduction_mode == "sum": return total_sum - total_denom = jax.lax.psum(local_denom, "data") + total_denom = _psum_over_axes(local_denom, batch_axis_names) return jnp.where(total_denom != 0, total_sum / total_denom, jnp.zeros_like(total_denom)) - out_specs = P(("data",)) if reduction_mode is None else P() + out_specs = P(batch_axis_spec) if reduction_mode is None else P() return jax.shard_map( _loss_shard, - in_specs=(P(("data",)), P(None, None), P(("data",)), P(("data",))), + in_specs=(P(batch_axis_spec), P(None, None), P(batch_axis_spec), P(batch_axis_spec)), out_specs=out_specs, check_vma=False, )(hidden, lm_head, labels, weight_array) diff --git a/lib/levanter/src/levanter/grug/sharding.py 
b/lib/levanter/src/levanter/grug/sharding.py index 8d57df5bb8..6e0a4a0ba7 100644 --- a/lib/levanter/src/levanter/grug/sharding.py +++ b/lib/levanter/src/levanter/grug/sharding.py @@ -12,4 +12,4 @@ def unshard(x: jax.Array) -> jax.Array: - return reshard(x, P((None,))) + return reshard(x, P(None)) diff --git a/lib/levanter/src/levanter/utils/activation.py b/lib/levanter/src/levanter/utils/activation.py index 201cc3551c..2bc83feaa4 100644 --- a/lib/levanter/src/levanter/utils/activation.py +++ b/lib/levanter/src/levanter/utils/activation.py @@ -6,6 +6,7 @@ from functools import partial import jax +import jax.numpy as jnp import haliax as hax import haliax.nn as hnn @@ -13,10 +14,20 @@ _A = typing.TypeVar("_A", hax.Scalar, hax.NamedArray, jax.Array) ActivationFunction = typing.Callable[[_A], _A] +JaxActivationFunction = typing.Callable[[jax.Array], jax.Array] + + +def _quick_gelu_jax(x: jax.Array) -> jax.Array: + return x * jax.nn.sigmoid(1.702 * x) + + +def _relu2_jax(x: jax.Array) -> jax.Array: + return jnp.square(jax.nn.relu(x)) class ActivationFunctionEnum(str, enum.Enum): relu = "relu" + relu2 = "relu2" silu = "silu" swish = "swish" gelu = "gelu" @@ -30,10 +41,16 @@ def to_fn(self) -> ActivationFunction: raise ValueError("xielu is parameterized; use XIELUActivation directly.") return TO_FN[self] + def to_jax_fn(self) -> JaxActivationFunction: + if self is ActivationFunctionEnum.xielu: + raise ValueError("xielu is parameterized; use XIELUActivation directly.") + return TO_JAX_FN[self] + # type: ignore TO_FN: dict[ActivationFunctionEnum, ActivationFunction] = { ActivationFunctionEnum.relu: hnn.relu, + ActivationFunctionEnum.relu2: hnn.relu_squared, ActivationFunctionEnum.silu: hnn.silu, ActivationFunctionEnum.swish: hnn.swish, ActivationFunctionEnum.gelu: partial(hnn.gelu, approximate=False), @@ -41,3 +58,15 @@ def to_fn(self) -> ActivationFunction: ActivationFunctionEnum.quick_gelu: hnn.quick_gelu, ActivationFunctionEnum.tanh: hax.tanh, } + + +TO_JAX_FN: 
dict[ActivationFunctionEnum, JaxActivationFunction] = { + ActivationFunctionEnum.relu: jax.nn.relu, + ActivationFunctionEnum.relu2: _relu2_jax, + ActivationFunctionEnum.silu: jax.nn.silu, + ActivationFunctionEnum.swish: jax.nn.swish, + ActivationFunctionEnum.gelu: partial(jax.nn.gelu, approximate=False), + ActivationFunctionEnum.gelu_new: partial(jax.nn.gelu, approximate=True), + ActivationFunctionEnum.quick_gelu: _quick_gelu_jax, + ActivationFunctionEnum.tanh: jnp.tanh, +} diff --git a/lib/levanter/tests/grug/test_grugformer_moe.py b/lib/levanter/tests/grug/test_grugformer_moe.py new file mode 100644 index 0000000000..7a39e53bab --- /dev/null +++ b/lib/levanter/tests/grug/test_grugformer_moe.py @@ -0,0 +1,292 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest + +import jax +import jax.numpy as jnp +from jax._src import config as jax_config +from jax.sharding import AbstractMesh, AxisType, Mesh, NamedSharding, PartitionSpec as P, use_abstract_mesh + +from levanter.grug.grug_moe import moe_mlp +from levanter.utils.activation import ActivationFunctionEnum + + +def _make_dense_mesh() -> Mesh: + devices = jax.devices() + if not devices: + raise RuntimeError("No JAX devices available") + mesh_devices = np.array(devices).reshape(len(devices), 1) + return Mesh( + mesh_devices, + axis_names=("data", "model"), + axis_types=(AxisType.Explicit, AxisType.Explicit), + ) + + +def _make_ep_mesh_or_none() -> Mesh | None: + devices = jax.devices() + if len(devices) < 2 or len(devices) % 2 != 0: + return None + mesh_devices = np.array(devices).reshape(len(devices) // 2, 2, 1) + return Mesh( + mesh_devices, + axis_names=("data", "expert", "model"), + axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Explicit), + ) + + +def _make_abstract_moe_mesh(*, data: int, expert: int, model: int) -> AbstractMesh: + return AbstractMesh( + axis_sizes=(data, expert, model), + axis_names=("data", "expert", "model"), + 
axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Explicit), + ) + + +class _reset_abstract_mesh: + def __enter__(self): + self._prev = jax_config.abstract_mesh_context_manager.swap_local(jax_config.config_ext.unset) + return self + + def __exit__(self, exc_type, exc, tb): + jax_config.abstract_mesh_context_manager.set_local(self._prev) + return False + + +def _make_inputs( + *, + key: jax.Array, + tokens: int, + hidden_dim: int, + intermediate_dim: int, + num_experts: int, + topk: int, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + k_x, k_sel, k_logits, k_w13, k_w2 = jax.random.split(key, 5) + x = jax.random.normal(k_x, (tokens, hidden_dim), dtype=jnp.float32) + selected_experts = jax.random.randint(k_sel, (tokens, topk), 0, num_experts, dtype=jnp.int32) + combine_logits = jax.random.normal(k_logits, (tokens, topk), dtype=jnp.float32) + combine_weights = jax.nn.softmax(combine_logits, axis=-1) + w_up_gate = jax.random.normal(k_w13, (num_experts, hidden_dim, 2 * intermediate_dim), dtype=jnp.float32) + w_down = jax.random.normal(k_w2, (num_experts, intermediate_dim, hidden_dim), dtype=jnp.float32) + return x, selected_experts, combine_weights, w_up_gate, w_down + + +def test_moe_mlp_runs_without_ep_axis(): + mesh = _make_dense_mesh() + tokens = max(8, len(jax.devices()) * 8) + hidden_dim = 32 + intermediate_dim = 64 + num_experts = 4 + topk = 2 + + with jax.set_mesh(mesh): + x, selected_experts, combine_weights, w_up_gate, w_down = _make_inputs( + key=jax.random.key(0), + tokens=tokens, + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_experts=num_experts, + topk=topk, + ) + + out = moe_mlp( + x, + selected_experts, + combine_weights, + w_up_gate, + w_down, + activation=ActivationFunctionEnum.silu, + mesh=None, + ) + assert out.shape == (tokens, hidden_dim) + assert jnp.isfinite(out).all() + + jit_fn = jax.jit( + lambda x, sel, cw, up_gate, down: moe_mlp( + x, sel, cw, up_gate, down, 
activation=ActivationFunctionEnum.silu, mesh=None + ) + ) + out_jit = jit_fn(x, selected_experts, combine_weights, w_up_gate, w_down) + np.testing.assert_allclose(np.asarray(out), np.asarray(out_jit), rtol=1e-5, atol=1e-5) + + +def test_moe_ring_ep_path_lowers_on_abstract_mesh(): + mesh = _make_abstract_moe_mesh(data=2, expert=2, model=1) + + tokens = 16 + hidden_dim = 32 + intermediate_dim = 64 + num_experts = 4 + topk = 2 + + with _reset_abstract_mesh(), use_abstract_mesh(mesh): + x = jax.ShapeDtypeStruct( + shape=(tokens, hidden_dim), + dtype=jnp.float32, + sharding=NamedSharding(mesh, P(("data", "expert"), None)), + ) + selected_experts = jax.ShapeDtypeStruct( + shape=(tokens, topk), + dtype=jnp.int32, + sharding=NamedSharding(mesh, P(("data", "expert"), None)), + ) + combine_weights = jax.ShapeDtypeStruct( + shape=(tokens, topk), + dtype=jnp.float32, + sharding=NamedSharding(mesh, P(("data", "expert"), None)), + ) + w_up_gate = jax.ShapeDtypeStruct( + shape=(num_experts, hidden_dim, 2 * intermediate_dim), + dtype=jnp.float32, + sharding=NamedSharding(mesh, P("expert", None, None)), + ) + w_down = jax.ShapeDtypeStruct( + shape=(num_experts, intermediate_dim, hidden_dim), + dtype=jnp.float32, + sharding=NamedSharding(mesh, P("expert", None, None)), + ) + + def f(x, sel, cw, up_gate, down): + return moe_mlp( + x, + sel, + cw, + up_gate, + down, + activation=ActivationFunctionEnum.silu, + mesh=mesh, + ) + + platform = jax.devices()[0].platform if jax.devices() else jax.default_backend() + lowered = ( + jax.jit(f) + .trace(x, selected_experts, combine_weights, w_up_gate, w_down) + .lower(lowering_platforms=(platform,)) + ) + assert lowered is not None + + +def test_moe_mlp_runs_with_ep_axis_when_available(): + mesh = _make_ep_mesh_or_none() + if mesh is None: + pytest.skip("requires an even number of >=2 devices") + + tokens = len(jax.devices()) * 8 + hidden_dim = 32 + intermediate_dim = 64 + num_experts = 4 + topk = 2 + + with jax.set_mesh(mesh): + x, 
selected_experts, combine_weights, w_up_gate, w_down = _make_inputs( + key=jax.random.key(1), + tokens=tokens, + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_experts=num_experts, + topk=topk, + ) + + batch_sharding = NamedSharding(mesh, P(("data", "expert"), None)) + expert_sharding = NamedSharding(mesh, P("expert", None, None)) + x = jax.sharding.reshard(x, batch_sharding) + selected_experts = jax.sharding.reshard(selected_experts, batch_sharding) + combine_weights = jax.sharding.reshard(combine_weights, batch_sharding) + w_up_gate = jax.sharding.reshard(w_up_gate, expert_sharding) + w_down = jax.sharding.reshard(w_down, expert_sharding) + + out = moe_mlp( + x, + selected_experts, + combine_weights, + w_up_gate, + w_down, + activation=ActivationFunctionEnum.silu, + mesh=None, + ) + assert out.shape == (tokens, hidden_dim) + assert jnp.isfinite(out).all() + + +def test_functional_moe_mlp_accepts_enum_and_callable_activation(): + tokens = 16 + hidden_dim = 16 + intermediate_dim = 24 + num_experts = 8 + topk = 2 + + x, selected_experts, combine_weights, w_up_gate, w_down = _make_inputs( + key=jax.random.key(2), + tokens=tokens, + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_experts=num_experts, + topk=topk, + ) + + y_enum = moe_mlp( + x, + selected_experts, + combine_weights, + w_up_gate, + w_down, + activation=ActivationFunctionEnum.silu, + mesh=None, + ) + y_callable = moe_mlp( + x, + selected_experts, + combine_weights, + w_up_gate, + w_down, + activation=lambda t: jax.nn.silu(t), + mesh=None, + ) + np.testing.assert_allclose(np.asarray(y_callable), np.asarray(y_enum), rtol=1e-5, atol=1e-5) + + +def test_moe_mlp_reports_positive_drop_count_in_ep_when_over_capacity(): + mesh = _make_ep_mesh_or_none() + if mesh is None: + pytest.skip("requires an even number of >=2 devices") + + tokens = len(jax.devices()) * 8 + hidden_dim = 16 + intermediate_dim = 24 + num_experts = 4 + topk = 2 + + key = jax.random.key(5) + x = 
jax.random.normal(key, (tokens, hidden_dim), dtype=jnp.float32) + selected_experts = jnp.zeros((tokens, topk), dtype=jnp.int32) + combine_weights = jnp.full((tokens, topk), 0.5, dtype=jnp.float32) + w_up_gate = jax.random.normal( + jax.random.key(6), (num_experts, hidden_dim, 2 * intermediate_dim), dtype=jnp.float32 + ) + w_down = jax.random.normal(jax.random.key(7), (num_experts, intermediate_dim, hidden_dim), dtype=jnp.float32) + + with jax.set_mesh(mesh): + batch_sharding = NamedSharding(mesh, P(("data", "expert"), None)) + expert_sharding = NamedSharding(mesh, P("expert", None, None)) + x = jax.sharding.reshard(x, batch_sharding) + selected_experts = jax.sharding.reshard(selected_experts, batch_sharding) + combine_weights = jax.sharding.reshard(combine_weights, batch_sharding) + w_up_gate = jax.sharding.reshard(w_up_gate, expert_sharding) + w_down = jax.sharding.reshard(w_down, expert_sharding) + + out, dropped = moe_mlp( + x, + selected_experts, + combine_weights, + w_up_gate, + w_down, + mesh=None, + report_capacity_overflow=True, + ) + + assert out.shape == (tokens, hidden_dim) + assert dropped.shape == () + assert int(dropped) > 0 diff --git a/tests/test_grug_base_template.py b/tests/test_grug_base_template.py deleted file mode 100644 index 02f477cbb1..0000000000 --- a/tests/test_grug_base_template.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright 2025 The Marin Authors -# SPDX-License-Identifier: Apache-2.0 - -import json -import logging -import tempfile -import uuid -from io import StringIO -from pathlib import Path - -import equinox as eqx -import jax -import jax.numpy as jnp -import jmp -import optax -import pytest -from jax._src import config as jax_config -from jax.sharding import AbstractMesh, AxisType, NamedSharding, PartitionSpec as P, use_abstract_mesh - -from levanter.callbacks.watch import WatchConfig -from levanter.checkpoint import CheckpointerConfig -from levanter.data.dataset import ListAsyncDataset -from levanter.data.text import 
DirectDatasetComponent, LmDataConfig -from levanter.data.text.examples import GrugLmExample -from levanter.distributed import DistributedConfig, RayConfig -from levanter.grug.attention import AttentionMask as GrugAttentionMask -from levanter.grug.sharding import Pbatch -from levanter.tracker.json_logger import JsonLoggerConfig -from levanter.trainer import TrainerConfig - -from experiments.grug.base.model import GrugModelConfig, Transformer -from experiments.grug.base.train import ( - GrugRunConfig, - GrugTrainerConfig, - GrugTrainState, - _make_train_step, - run_grug, -) - - -def _make_abstract_mesh(*, data: int, model: int) -> AbstractMesh: - return AbstractMesh( - axis_sizes=(data, model), - axis_names=("data", "model"), - axis_types=(AxisType.Explicit, AxisType.Explicit), - ) - - -class _reset_abstract_mesh: - def __enter__(self): - self._prev = jax_config.abstract_mesh_context_manager.swap_local(jax_config.config_ext.unset) - return self - - def __exit__(self, exc_type, exc, tb): - jax_config.abstract_mesh_context_manager.set_local(self._prev) - return False - - -class DummyModel(eqx.Module): - w: jax.Array - - def compute_next_token_loss( - self, - token_ids: jax.Array, - loss_weight: jax.Array, - *, - mask=None, - reduction: str = "mean", - logsumexp_weight: float | None = None, - ) -> jax.Array: - del token_ids, loss_weight, mask, reduction, logsumexp_weight - return jnp.mean(jnp.square(self.w)) - - -def _build_state(params: DummyModel, optimizer: optax.GradientTransformation) -> GrugTrainState: - return GrugTrainState( - step=jnp.array(0, dtype=jnp.int32), - params=params, - opt_state=optimizer.init(params), - ema_params=params, - ) - - -def test_grug_base_train_step_with_watch_matches_base_step(): - optimizer = optax.adam(1e-2) - mp = jmp.get_policy("f32") - - state_for_base = _build_state(DummyModel(jnp.array([1.0, -2.0], dtype=jnp.float32)), optimizer) - state_for_watch = _build_state(DummyModel(jnp.array([1.0, -2.0], dtype=jnp.float32)), optimizer) - 
batch = GrugLmExample( - tokens=jnp.zeros((1, 4), dtype=jnp.int32), - loss_weight=jnp.ones((1, 4), dtype=jnp.float32), - attn_mask=GrugAttentionMask.causal(), - ) - - base_step = _make_train_step(optimizer, mp, z_loss_weight=0.0, ema_beta=None) - watch_step = _make_train_step( - optimizer, - mp, - z_loss_weight=0.0, - ema_beta=None, - watch_config=WatchConfig( - watch_targets=["grads", "params", "updates"], - include_norms=True, - include_per_parameter_norms=True, - include_histograms=False, - split_scan_layers=True, - interval=1, - ), - ) - - next_base, metrics_base, base_watch_stats = base_step(state_for_base, batch, compute_watch=False) - next_watch, metrics_watch, watch_stats = watch_step(state_for_watch, batch, compute_watch=True) - - assert int(next_base.step) == 1 - assert int(next_watch.step) == 1 - assert jnp.allclose(next_base.params.w, next_watch.params.w) - assert jnp.allclose(next_base.ema_params.w, next_watch.ema_params.w) - assert jnp.allclose(metrics_base["train/loss"], metrics_watch["train/loss"]) - assert base_watch_stats is None - assert watch_stats - assert any(key.startswith("grad/") for key in watch_stats) - assert any(key.startswith("params/") for key in watch_stats) - assert any(key.startswith("updates/") for key in watch_stats) - - -def test_grug_base_run_emits_expected_metrics_with_json_tracker(): - vocab_size = 128 - seq_len = 32 - examples = [] - for i in range(8): - tokens = (jnp.arange(seq_len, dtype=jnp.int32) + i) % vocab_size - examples.append(GrugLmExample.causal(tokens)) - - dataset = ListAsyncDataset(examples) - data_config = LmDataConfig( - components={"direct": DirectDatasetComponent(datasets={"train": dataset})}, - vocab_size=vocab_size, - tokenizer="passthrough", - ) - - logger_name = f"test_grug_base_json_tracker_{uuid.uuid4().hex}" - stream = StringIO() - handler = logging.StreamHandler(stream) - logger = logging.getLogger(logger_name) - logger.handlers.clear() - logger.propagate = False - logger.addHandler(handler) - 
logger.setLevel(logging.INFO) - - with tempfile.TemporaryDirectory() as tmpdir: - trainer_config = TrainerConfig( - id="test-grug-base-metrics", - num_train_steps=1, - train_batch_size=max(1, len(jax.devices())), - tracker=JsonLoggerConfig(logger_name=logger_name), - require_accelerator=False, - use_explicit_mesh_axes=True, - distributed=DistributedConfig(initialize_jax_distributed=False), - ray=RayConfig(auto_start_cluster=False), - log_dir=Path(tmpdir) / "logs", - checkpointer=CheckpointerConfig(base_path=str(Path(tmpdir) / "checkpoints")), - ) - - run_grug( - GrugRunConfig( - model=GrugModelConfig( - vocab_size=vocab_size, - hidden_dim=32, - intermediate_dim=64, - num_layers=2, - num_heads=2, - num_kv_heads=2, - max_seq_len=seq_len, - ), - data=data_config, - trainer=GrugTrainerConfig(trainer=trainer_config, log_every=1), - eval=None, - ) - ) - - logger.removeHandler(handler) - records = [json.loads(line) for line in stream.getvalue().splitlines() if line.strip()] - finish_records = [record for record in records if record.get("event") == "finish"] - assert len(finish_records) == 1 - summary = finish_records[0]["summary"] - - required_keys = [ - "train/loss", - "global_step", - "throughput/duration", - "throughput/hook_time", - "throughput/loading_time", - "throughput/total_tokens", - "throughput/examples_per_second", - "throughput/tokens_per_second", - "throughput/flops_per_example_analytic", - ] - for key in required_keys: - assert key in summary - - -@pytest.mark.parametrize(("data", "model"), [(4, 1), (2, 2)]) -def test_grug_base_loss_lowers_on_abstract_4_device_mesh(data: int, model: int): - seq = 256 if jax.default_backend() == "tpu" else 16 - cfg = GrugModelConfig( - vocab_size=256, - hidden_dim=128, - intermediate_dim=256, - num_layers=1, - num_heads=8, - num_kv_heads=8, - max_seq_len=seq, - ) - - mesh = _make_abstract_mesh(data=data, model=model) - # Some test setups leave an abstract mesh context active; reset before setting a new size. 
- with _reset_abstract_mesh(), use_abstract_mesh(mesh): - key = jax.ShapeDtypeStruct(shape=(2,), dtype=jnp.uint32, sharding=NamedSharding(mesh, P())) - params = jax.eval_shape(lambda k: Transformer.init(cfg, key=k), key) - - def loss_fn(p): - token_ids = jnp.zeros((8, seq), dtype=jnp.int32) - token_ids = jax.sharding.reshard(token_ids, Pbatch) - loss_weight = jnp.ones((8, seq), dtype=jnp.float32) - loss_weight = jax.sharding.reshard(loss_weight, Pbatch) - return p.compute_next_token_loss(token_ids, loss_weight, mask=GrugAttentionMask.causal(), reduction="mean") - - platform = jax.devices()[0].platform if jax.devices() else jax.default_backend() - lowered = jax.jit(loss_fn).trace(params).lower(lowering_platforms=(platform,)) - - assert lowered is not None diff --git a/tests/test_grug_variant_contracts.py b/tests/test_grug_variant_contracts.py new file mode 100644 index 0000000000..63f9e02e9d --- /dev/null +++ b/tests/test_grug_variant_contracts.py @@ -0,0 +1,264 @@ +# Copyright 2025 The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Contract tests for grug variants under experiments/grug/*. + +These checks are intentionally variant-discovered: if a subdirectory contains +`model.py` and/or `train.py`, it is expected to satisfy the corresponding +lowering and training contracts. 
+""" + +import dataclasses +import importlib +import json +import logging +import uuid +from io import StringIO + +from pathlib import Path + +import equinox as eqx +import jax +import jax.numpy as jnp +import jmp +import optax +import pytest +from jax._src import config as jax_config +from jax.sharding import NamedSharding, PartitionSpec as P, use_abstract_mesh + +from levanter.checkpoint import CheckpointerConfig +from levanter.data.dataset import ListAsyncDataset +from levanter.data.text import DirectDatasetComponent, LmDataConfig +from levanter.data.text.examples import GrugLmExample +from levanter.distributed import DistributedConfig, RayConfig +from levanter.grug.attention import AttentionMask as GrugAttentionMask +from levanter.tracker.json_logger import JsonLoggerConfig +from levanter.trainer import TrainerConfig + + +def _discover_grug_variants_with_file(filename: str) -> list[str]: + grug_dir = Path(__file__).resolve().parents[1] / "experiments" / "grug" + variants: list[str] = [] + found_any = False + for child in sorted(grug_dir.iterdir()): + if not child.is_dir() or child.name.startswith("__"): + continue + if (child / filename).is_file(): + found_any = True + if _variant_has_noverify(child): + continue + variants.append(child.name) + if not variants and not found_any: + raise AssertionError(f"No grug variants with {filename} found under {grug_dir}") + return variants + + +def _variant_module_name(variant: str, module: str) -> str: + return f"experiments.grug.{variant}.{module}" + + +def _variant_has_noverify(variant_dir: Path) -> bool: + train_file = variant_dir / "train.py" + if not train_file.is_file(): + return False + return "# GRUG NOVERIFY" in train_file.read_text(encoding="utf-8") + + +class _reset_abstract_mesh: + def __enter__(self): + self._prev = jax_config.abstract_mesh_context_manager.swap_local(jax_config.config_ext.unset) + return self + + def __exit__(self, exc_type, exc, tb): + 
jax_config.abstract_mesh_context_manager.set_local(self._prev) + return False + + +def _discover_grug_variants_with_model_and_train() -> list[str]: + model_variants = set(_discover_grug_variants_with_file("model.py")) + train_variants = set(_discover_grug_variants_with_file("train.py")) + variants = sorted(model_variants & train_variants) + if not variants and model_variants and train_variants: + return [] + if not variants: + raise AssertionError("No grug variants with both model.py and train.py found") + return variants + + +@pytest.mark.parametrize( + "variant", + _discover_grug_variants_with_file("model.py"), +) +def test_grug_variant_loss_lowers_on_abstract_mesh(variant: str): + module_name = _variant_module_name(variant, "model") + module = importlib.import_module(module_name) + config_cls = module.GrugModelConfig + transformer_cls = module.Transformer + + seq = 256 if jax.default_backend() == "tpu" else 16 + cfg = config_cls(vocab_size=256, max_seq_len=seq) + mesh_fn = getattr(module, "debug_mesh_and_token_pspec", None) + if mesh_fn is None: + raise AssertionError(f"{module_name} must define debug_mesh_and_token_pspec(num_devices)") + mesh, token_pspec = mesh_fn(num_devices=4) + + with _reset_abstract_mesh(), use_abstract_mesh(mesh): + key = jax.ShapeDtypeStruct(shape=(2,), dtype=jnp.uint32, sharding=NamedSharding(mesh, P())) + + def init_model(k): + return transformer_cls.init(cfg, key=k) + + params = jax.eval_shape(init_model, key) + if not hasattr(params, "next_token_loss"): + raise AssertionError(f"{module_name}.Transformer must define next_token_loss") + + def loss_fn(p): + token_ids = jnp.zeros((8, seq), dtype=jnp.int32) + token_ids = jax.sharding.reshard(token_ids, token_pspec) + loss_weight = jnp.ones((8, seq), dtype=jnp.float32) + loss_weight = jax.sharding.reshard(loss_weight, token_pspec) + return p.next_token_loss( + token_ids, + loss_weight, + mask=GrugAttentionMask.causal(), + reduction="mean", + ) + + platform = jax.devices()[0].platform if 
jax.devices() else jax.default_backend() + lowered = jax.jit(loss_fn).trace(params).lower(lowering_platforms=(platform,)) + + assert lowered is not None + + +def _small_model_config(model_config_cls, *, vocab_size: int, seq_len: int): + base_kwargs = { + "vocab_size": vocab_size, + "hidden_dim": 32, + "intermediate_dim": 64, + "num_layers": 2, + "num_heads": 2, + "num_kv_heads": 2, + "max_seq_len": seq_len, + "num_experts": 4, + "num_experts_per_token": 2, + "shared_expert_intermediate_dim": 64, + } + field_names = {field.name for field in dataclasses.fields(model_config_cls)} + kwargs = {k: v for k, v in base_kwargs.items() if k in field_names} + return model_config_cls(**kwargs) + + +@pytest.mark.parametrize( + "variant", + _discover_grug_variants_with_model_and_train(), +) +def test_grug_variant_one_step_contract_lowers_with_default_ctor(variant: str): + train_module = importlib.import_module(_variant_module_name(variant, "train")) + model_module = importlib.import_module(_variant_module_name(variant, "model")) + model_config_cls = model_module.GrugModelConfig + make_train_step = train_module._make_train_step + initial_state = train_module.initial_state + mesh_fn = getattr(model_module, "debug_mesh_and_token_pspec", None) + if mesh_fn is None: + raise AssertionError(f"{_variant_module_name(variant, 'model')} must define debug_mesh_and_token_pspec") + + cfg = model_config_cls(vocab_size=1024) + optimizer = optax.adam(1e-2) + mp = jmp.get_policy("f32") + train_step = make_train_step(optimizer, mp, z_loss_weight=0.0, ema_beta=None) + mesh, token_pspec = mesh_fn(num_devices=4) + batch = GrugLmExample( + tokens=jnp.zeros((8, 4), dtype=jnp.int32), + loss_weight=jnp.ones((8, 4), dtype=jnp.float32), + attn_mask=GrugAttentionMask.causal(), + ) + + def one_step(): + sharded_batch = dataclasses.replace( + batch, + tokens=jax.sharding.reshard(batch.tokens, token_pspec), + loss_weight=jax.sharding.reshard(batch.loss_weight, token_pspec), + ) + state = initial_state(cfg, 
optimizer=optimizer, mp=mp, key=jax.random.PRNGKey(0)) + return train_step(state, sharded_batch, compute_watch=False) + + with _reset_abstract_mesh(), use_abstract_mesh(mesh): + out_state_shape, out_metrics_shape, out_watch_shape = eqx.filter_eval_shape(one_step) + + assert out_state_shape.step.shape == () + assert "train/loss" in out_metrics_shape + assert out_metrics_shape["train/loss"].shape == () + assert out_watch_shape is None + + +def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path): + train_module = importlib.import_module("experiments.grug.base.train") + model_module = importlib.import_module("experiments.grug.base.model") + + vocab_size = 128 + seq_len = 32 + examples = [] + for i in range(8): + tokens = (jnp.arange(seq_len, dtype=jnp.int32) + i) % vocab_size + examples.append(GrugLmExample.causal(tokens)) + + dataset = ListAsyncDataset(examples) + data_config = LmDataConfig( + components={"direct": DirectDatasetComponent(datasets={"train": dataset})}, + vocab_size=vocab_size, + tokenizer="passthrough", + ) + + logger_name = f"test_grug_json_tracker_base_{uuid.uuid4().hex}" + stream = StringIO() + handler = logging.StreamHandler(stream) + logger = logging.getLogger(logger_name) + logger.handlers.clear() + logger.propagate = False + logger.addHandler(handler) + logger.setLevel(logging.INFO) + + try: + variant_tmp = tmp_path / "base" + variant_tmp.mkdir(parents=True, exist_ok=True) + trainer_config = TrainerConfig( + id="test-grug-base-metrics", + num_train_steps=1, + train_batch_size=max(1, len(jax.devices())), + tracker=JsonLoggerConfig(logger_name=logger_name), + require_accelerator=False, + use_explicit_mesh_axes=True, + distributed=DistributedConfig(initialize_jax_distributed=False), + ray=RayConfig(auto_start_cluster=False), + log_dir=variant_tmp / "logs", + checkpointer=CheckpointerConfig(base_path=str(variant_tmp / "checkpoints")), + ) + + run_cfg = train_module.GrugRunConfig( + 
model=_small_model_config(model_module.GrugModelConfig, vocab_size=vocab_size, seq_len=seq_len), + data=data_config, + trainer=train_module.GrugTrainerConfig(trainer=trainer_config, log_every=1), + eval=None, + ) + train_module.run_grug(run_cfg) + finally: + logger.removeHandler(handler) + + records = [json.loads(line) for line in stream.getvalue().splitlines() if line.strip()] + finish_records = [record for record in records if record.get("event") == "finish"] + assert len(finish_records) == 1 + summary = finish_records[0]["summary"] + + required_keys = [ + "train/loss", + "global_step", + "throughput/duration", + "throughput/hook_time", + "throughput/loading_time", + "throughput/total_tokens", + "throughput/examples_per_second", + "throughput/tokens_per_second", + "throughput/flops_per_example_analytic", + ] + for key in required_keys: + assert key in summary