From 4b505af719e345ced88f9f2d8134ef29bba560f7 Mon Sep 17 00:00:00 2001 From: Pranshu Chaturvedi Date: Thu, 19 Feb 2026 11:35:06 -0800 Subject: [PATCH 1/7] Sync OLMoE/Mixtral sweep relaunch and W&B MoE logging defaults --- AGENTS.md | 7 + experiments/defaults.py | 13 +- experiments/speedrun/custom_mixtral.py | 389 +++++-- .../speedrun/grugformer_moe/grugformer_moe.py | 956 ++++++++++++++++++ ...rugformer_moe_nemotron_dclm_fineweb_10b.py | 449 ++++++++ ...vs_dense_nemotron_dclm_fineweb_edu_100b.py | 836 +++++++++++++++ .../speedrun/olmoe_1b7b_nemotron_40b.py | 102 +- ...oe_m_nemotron_dclm_fineweb_40b_lr_sweep.py | 409 ++++++++ experiments/speedrun/prebuilt_caches.py | 6 +- lib/fray/src/fray/v1/cluster/ray/deps.py | 5 +- lib/fray/src/fray/v2/ray_backend/deps.py | 5 +- lib/haliax/src/haliax/nn/normalization.py | 19 +- lib/levanter/docs/Performance-Guide.md | 57 ++ lib/levanter/src/levanter/grad_accum.py | 29 +- lib/levanter/src/levanter/grug/sharding.py | 9 + .../fused_cross_entropy_loss/pallas_tpu.py | 20 + lib/levanter/src/levanter/utils/jax_utils.py | 176 +++- 17 files changed, 3338 insertions(+), 149 deletions(-) create mode 100644 experiments/speedrun/grugformer_moe/grugformer_moe.py create mode 100644 experiments/speedrun/grugformer_moe/grugformer_moe_nemotron_dclm_fineweb_10b.py create mode 100644 experiments/speedrun/mixtral_vs_dense_nemotron_dclm_fineweb_edu_100b.py create mode 100644 experiments/speedrun/olmoe_m_nemotron_dclm_fineweb_40b_lr_sweep.py diff --git a/AGENTS.md b/AGENTS.md index 5fa800b09f..069958b95e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -117,3 +117,10 @@ DO NOT: ## Environment - Prefer to use `uv` when possible. If you can't (for instance, due to sandbox restrictions) you can use `.venv/bin/python` + +## Ray Run Notes + +- In shared clusters, `OwnerDiedError`, raylet "missed too many heartbeats" messages, and autoscaler resize logs can be noise from other workloads. 
Prefer judging health by your *job's* step logs/status and whether `run_on_pod_ray` is retrying preemptions. +- If a job appears stuck before logging metrics, it is often waiting on TPU slice scheduling (e.g. `SliceActor`/`TPUHostActor` pending creation). Use `uv run python scripts/ray/cluster.py --config infra/marin-us-central1.yaml list-jobs` and `... job-logs ` to confirm. +- Grugformer MoE smoke runs: prefer `--smoke --dataset nemotron_cc --tpu-type v5p-16 --seq-len 1024 --global-batch-size 32 --num-train-steps 20 --dataset-tokenizer meta-llama/Meta-Llama-3.1-8B --legacy-axis-resources`. The launcher defaults to fused (Pallas) CE; `xla` CE will materialize full logits and can OOM at realistic token counts. +- Grugformer MoE experts: default to the Megablox GMM pathway (`--use-gmm`). The ragged-dot pathway can trigger huge HBM temporaries during compile (e.g. expert-linear shapes like `bf16[64,262144,1024]`) and crash TPU workers; use `--no-use-gmm` only for debugging/ablations. diff --git a/experiments/defaults.py b/experiments/defaults.py index b096946aa9..9b609e5186 100644 --- a/experiments/defaults.py +++ b/experiments/defaults.py @@ -285,7 +285,10 @@ def default_train( eval_harness_tasks: Sequence[EvalTaskConfig] = CORE_TASKS, wandb_name: str | None = None, wandb_group: str | None = None, + wandb_project: str | None = None, override_output_path: str | None = None, + checkpointer_save_interval: timedelta | None = None, + checkpointer_keep: list[dict] | None = None, ) -> ExecutorStep: """ Train a language model using the default configuration. @@ -300,6 +303,10 @@ def default_train( eval_harness_tasks: List of evaluation harness tasks. Defaults to the CORE set of tasks. Use () or [] to disable wandb_name: Optional W&B display name for this run. Defaults to W&B's auto-generated name. wandb_group: Optional W&B group to organize related runs (e.g., a sweep). If unset, defaults to $WANDB_GROUP. + wandb_project: Optional W&B project name. 
Defaults to "marin" when unset. + checkpointer_save_interval: Optional override for the checkpointer time-based save interval. + checkpointer_keep: Optional override for the checkpointer step-based keep policies. Passing an empty list keeps + only time-based (temporary) checkpoints. """ pretraining_data = _prepare_data_config(tokenized, use_default_validation) @@ -346,7 +353,7 @@ def default_train( data=pretraining_data, trainer=TrainerConfig( tracker=WandbConfig( - project="marin", + project=wandb_project or "marin", name=wandb_name, tags=[*tags], group=wandb_group, @@ -358,8 +365,8 @@ def default_train( num_train_steps=train_config.num_train_steps, steps_per_eval=train_config.steps_per_eval if train_config.steps_per_eval is not None else 1000, checkpointer=CheckpointerConfig( - save_interval=timedelta(minutes=10), - keep=[dict(every=steps_per_export)], + save_interval=checkpointer_save_interval or timedelta(minutes=10), + keep=checkpointer_keep if checkpointer_keep is not None else [dict(every=steps_per_export)], ), model_averaging=model_averaging, mesh=MeshConfig( diff --git a/experiments/speedrun/custom_mixtral.py b/experiments/speedrun/custom_mixtral.py index 39bb243bd0..441e179fa7 100644 --- a/experiments/speedrun/custom_mixtral.py +++ b/experiments/speedrun/custom_mixtral.py @@ -7,8 +7,8 @@ import logging import os from dataclasses import dataclass -from functools import partial from collections.abc import Callable +from functools import partial import equinox as eqx import jax @@ -27,9 +27,11 @@ import levanter.tracker from levanter.compat.hf_checkpoints import HFCheckpointConverter from levanter.layers.attention import Attention, AttentionBackend, AttentionConfig, AttentionMask +from levanter.layers.normalization import LayerNormConfigBase, RmsNormConfig from levanter.layers.rotary import DefaultRotaryEmbeddingsConfig, RotaryEmbeddingsConfig from levanter.models.llama import LlamaEmbedding, LlamaMlp -from levanter.models.lm_model import LmConfig, 
LmHeadModel +from levanter.models.loss import maybe_fused_next_token_loss +from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel from levanter.models.mistral import MistralConfig from levanter.utils.activation import ActivationFunctionEnum from levanter.utils.flop_utils import lm_flops_per_token @@ -56,11 +58,11 @@ def _log_libtpu_args_once(): @LmConfig.register_subclass("custom_mixtral") @dataclass(frozen=True) -class MixtralConfig(MistralConfig): +class CustomMixtralConfig(MistralConfig): """Config for MistralModel Args: - seq_len (int, optional): maximum length of the input sequence. Defaults to 8192. + seq_len (int, optional): model context length / attention window. Defaults to 4096. hidden_dim (int, optional): dimension of the hidden state. Defaults to 4096. intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 14336. num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 32. @@ -76,7 +78,10 @@ class MixtralConfig(MistralConfig): rzl_coef (`float`, optional): aux loss factor for router z-loss. Defaults to 0.001 """ - seq_len: int = 8192 + # The base `MistralConfig.max_seq_len` is used by some training codepaths for sanity checks. + # We override the default here to match our typical Mixtral runs (seq_len=4096). + max_seq_len: int = 4096 + seq_len: int = 4096 hidden_dim: int = 4096 intermediate_dim: int = 14336 num_layers: int = 32 @@ -95,7 +100,14 @@ class MixtralConfig(MistralConfig): rzl_coef: float | None = 0.001 # MoE optimization config - use_gmm: bool = False + use_gmm: bool = True + + # Expert parallelism (experimental) + expert_parallelism: int | None = None + # Which mesh axis to use for expert-parallel all-to-all. Defaults to the standard DP axis. 
+ expert_mesh_axis: str = "data" + expert_logical_axis: str = "experts" + enable_expert_dispatch: bool = False # Attention-related config upcast_attn: bool = False @@ -103,6 +115,39 @@ class MixtralConfig(MistralConfig): attn_backend: AttentionBackend | None = None flash_attention_block_size: int | None = None + # Cross-entropy (next-token loss) optimization config. + cross_entropy_block_size: int | None = 1024 + cross_entropy_b_block_size: int | None = None + cross_entropy_h_block_size: int | None = None + cross_entropy_implementation: str | None = "legacy" + log_moe_metrics: bool = True + + # QK normalization (applies LayerNorm/RMSNorm in attention on q and k vectors). + use_qk_norm: bool = False + qk_norm: LayerNormConfigBase | None = None + qk_norm_eps: float = 1e-6 + + # Router/routing knobs. + # + # If False (default): compute full softmax(router_logits) and renormalize top-k probs. + # If True: select top-k experts, then compute a softmax over the selected logits only. + router_topk_then_softmax: bool = False + # If True: compute router/gating math (logits, selection logits, top-k, softmax, aux terms) in fp32. + # This is intended as a stability knob for MoE sweeps; the rest of the model stays in its normal dtype. + router_fp32: bool = False + + # DeepSeek-style "auxiliary-free load balancing" (ALF-LB). + # + # We add a per-layer `router_bias` used only for expert selection, and optionally add an aux term + # that produces gradients only for `router_bias` (via stop_gradient), approximating the paper's + # bias updates without coupling the main model parameters to an auxiliary loss. + alf_lb_loss_scale: float = 0.0 + alf_lb_use_sign: bool = True + alf_lb_center_bias: bool = True + + # Use dense (single-expert) routing for the first N transformer layers. 
+ dense_first_n_layers: int = 0 + gradient_checkpointing: ScanCheckpointSpec = True scan_layers: bool = True @@ -128,14 +173,15 @@ class MixtralConfig(MistralConfig): HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) def __post_init__(self): + # Keep `max_seq_len` >= `seq_len` so `default_train` checks (train_seq_len <= max_seq_len) work, + # while still allowing packed sequences longer than the attention window for sliding-window attention. + object.__setattr__(self, "max_seq_len", max(self.max_seq_len, self.seq_len)) super().__post_init__() assert ( self.num_experts_per_tok <= self.n_routed_experts ), f"num_experts_per_tok={self.num_experts_per_tok} greater than by n_routed_experts={self.n_routed_experts}." - def hf_checkpoint_converter( - self, ref_checkpoint: str | None = None - ) -> HFCheckpointConverter["MixtralConfig"]: # type: ignore + def hf_checkpoint_converter(self, ref_checkpoint: str | None = None) -> HFCheckpointConverter["MixtralConfig"]: # type: ignore return HFCheckpointConverter( self.__class__, reference_checkpoint=self.reference_checkpoint if ref_checkpoint is None else ref_checkpoint, @@ -209,7 +255,7 @@ def mk_LayerNorm(self, axis: Axis) -> hnn.RmsNorm: axis, eps=self.layer_norm_epsilon, use_weight=self.use_layer_norm_weight, use_bias=self.use_bias ) - def flops_per_token(self, vocab_size: int, context_length: int): + def flops_per_token(self, vocab_size: int, context_length: int) -> float: return lm_flops_per_token( hidden_dim=self.hidden_dim, intermediate_dim=self.intermediate_dim, @@ -245,6 +291,9 @@ def total_trainable_params(self, vocab_size): def attention_config(self) -> AttentionConfig: """Convert this MixtralConfig to an AttentionConfig for use with Attention.""" + qk_norm = self.qk_norm + if qk_norm is None and self.use_qk_norm: + qk_norm = RmsNormConfig(eps=self.qk_norm_eps, use_weight=True, use_bias=False) return AttentionConfig( Embed=self.Embed, num_heads=self.num_heads, @@ -254,6 
+303,7 @@ def attention_config(self) -> AttentionConfig: attn_backend=self.attn_backend, flash_attention_block_size=self.flash_attention_block_size, rope=self.rope, + qk_norm=qk_norm, ) @@ -288,18 +338,33 @@ def init( @named_call def __call__(self, x: NamedArray, group_sizes: NamedArray, *, key=None) -> NamedArray: + """ + Apply a gated linear unit (GLU) transformation with grouped operations. + + This method implements a feed-forward network layer with gating mechanism, + commonly used in transformer-based models like Mixtral. The computation follows + the pattern: (W1(x) * activation) * W3(x) -> W2(gated_output). + + Args: + x (NamedArray): Input tensor with named dimensions. + group_sizes (NamedArray): Array specifying the size of groups for grouped + linear transformations. Used to partition the input/output dimensions + into groups, enabling efficient computation or mixture-of-experts style + processing where different groups are processed separately. + key (optional): Random number generator key for reproducibility in + stochastic operations. Split into 3 keys for the three linear layers. + + Returns: + NamedArray: Output tensor after applying the gated feed-forward + transformation, with the same named dimensions as the input. 
+ """ k1, k2, k3 = maybe_rng_split(key, 3) w1_output = self.w1(x, group_sizes, key=k1) - activated = self.act(w1_output) - w3_output = self.w3(x, group_sizes, key=k3) - gated = activated * w3_output - final_output = self.w2(gated, group_sizes, key=k2) - return final_output def to_state_dict(self, prefix: str | None = None) -> StateDict: @@ -315,6 +380,20 @@ def to_state_dict(self, prefix: str | None = None) -> StateDict: return out + def from_state_dict(self, state_dict: StateDict, prefix: str | None = None) -> "MixtralMoEMlp": + w: list[list[Array]] = [[], [], []] + num_experts = self.w1.Experts.size + for i in range(num_experts): + for j in range(3): + key = f"{prefix}.{i}.w{j + 1}.weight" + val = jnp.swapaxes(state_dict[key], -1, -2)[..., None, :, :] + w[j].append(val) + + for j in range(3): + w[j] = jnp.concat(w[j], axis=1) + + return eqx.tree_at(lambda m: [m.w1.weight.array, m.w2.weight.array, m.w3.weight.array], self, w) + class MixtralSparseMoeBlock(eqx.Module): """Mixture-of-Experts""" @@ -322,6 +401,7 @@ class MixtralSparseMoeBlock(eqx.Module): config: MistralConfig = eqx.field(static=True) gate: hnn.Linear # projection from Embed to Experts experts: MixtralMoEMlp + router_bias: NamedArray @staticmethod def init(config: MistralConfig, *, key) -> "MixtralSparseMoeBlock": @@ -338,36 +418,54 @@ def init(config: MistralConfig, *, key) -> "MixtralSparseMoeBlock": use_gmm=config.use_gmm, ) - return MixtralSparseMoeBlock(config, gate, experts) + router_bias = hax.zeros(config.Experts) + + return MixtralSparseMoeBlock(config, gate, experts, router_bias) - def _route(self, router_probs: NamedArray, Token: Axis, TopExperts: Axis): + def _route( + self, + selection_logits: NamedArray, + router_logits: NamedArray, + Token: Axis, + TopExperts: Axis, + *, + topk_then_softmax: bool, + ): @partial( shard_map, mesh=hax.partitioning._get_mesh(), - in_specs=hax.partitioning.pspec_for_axis(router_probs.axes), + in_specs=( + 
hax.partitioning.pspec_for_axis(selection_logits.axes), + hax.partitioning.pspec_for_axis(router_logits.axes), + ), out_specs=( hax.partitioning.pspec_for_axis((Token, TopExperts)), hax.partitioning.pspec_for_axis((Token, TopExperts)), ), **_SHARD_MAP_CHECK_KWARGS, ) - def sharded_route(router_probs_): - selected_weights_, selected_experts_ = jax.lax.top_k(router_probs_, TopExperts.size) - selected_weights_ = selected_weights_ / selected_weights_.sum(-1, keepdims=True) + def sharded_route(selection_logits_: Array, router_logits_: Array): + _selected_scores, selected_experts_ = jax.lax.top_k(selection_logits_, TopExperts.size) + + if topk_then_softmax: + selected_logits_ = jnp.take_along_axis(router_logits_, selected_experts_, axis=-1) + selected_weights_ = jax.nn.softmax(selected_logits_, axis=-1) + else: + router_probs_ = jax.nn.softmax(router_logits_, axis=-1) + selected_weights_ = jnp.take_along_axis(router_probs_, selected_experts_, axis=-1) + selected_weights_ = selected_weights_ / selected_weights_.sum(-1, keepdims=True) return selected_weights_, selected_experts_ with jax.named_scope("route"): - selected_weights_, selected_experts_ = sharded_route(router_probs.array) + selected_weights_, selected_experts_ = sharded_route(selection_logits.array, router_logits.array) - selected_weights = NamedArray(selected_weights_, (Token, TopExperts)) - selected_experts = NamedArray(selected_experts_, (Token, TopExperts)) + selected_weights = hax.named(selected_weights_, (Token, TopExperts)) + selected_experts = hax.named(selected_experts_, (Token, TopExperts)) return selected_weights, selected_experts - def _permute( - self, x_flat: NamedArray, topk_idx_flat: NamedArray, TokenRepeat: Axis - ) -> tuple[NamedArray, NamedArray, NamedArray, Axis]: + def _permute(self, x_flat: NamedArray, topk_idx_flat: NamedArray, TokenRepeat: Axis): Experts = self.config.Experts @partial( @@ -393,15 +491,11 @@ def permute_sharded(x_flat_: Array, topk_idx_flat_: Array): with 
jax.named_scope("permute"): x_repeat_sort_, group_sizes_, sort_idx_ = permute_sharded(x_flat.array, topk_idx_flat.array) - # tensor shapes reported inside shard_map are per-shard. Using that raw length to size the axis causes - # mismatches once the full global array is materialized (e.g., when calling the GMM kernel). - # Instead, keep the original TokenRepeat axis length, which already encodes the global token-repeat count. - token_repeat_axis = Axis(TokenRepeat.name, TokenRepeat.size) - x_repeat_sort = NamedArray(x_repeat_sort_, (token_repeat_axis, self.config.Embed)) - group_sizes = NamedArray(group_sizes_, (Experts,)) - sort_idx = NamedArray(sort_idx_, (token_repeat_axis,)) + x_repeat_sort = hax.named(x_repeat_sort_, (TokenRepeat, self.config.Embed)) + group_sizes = hax.named(group_sizes_, (Experts,)) + sort_idx = hax.named(sort_idx_, (TokenRepeat,)) - return x_repeat_sort, group_sizes, sort_idx, token_repeat_axis + return x_repeat_sort, group_sizes, sort_idx def _unpermute( self, @@ -431,12 +525,12 @@ def unpermute_sharded(out_repeat_sort_: Array, sort_idx_: Array): with jax.named_scope("unpermute"): out_repeat_unflat_ = unpermute_sharded(out_repeat_sort.array, sort_idx.array) - out_repeat_unflat = NamedArray(out_repeat_unflat_, (Token, TopExperts, self.config.Embed)) + out_repeat_unflat = hax.named(out_repeat_unflat_, (Token, TopExperts, self.config.Embed)) return out_repeat_unflat @named_call - def __call__(self, x: NamedArray, *, key=None) -> NamedArray: + def __call__(self, x: NamedArray, *, key=None, force_dense: bool = False) -> NamedArray: if x.has_axis("batch"): squash_axes = [x.resolve_axis("batch"), x.resolve_axis(self.config.Pos.name)] else: @@ -449,17 +543,83 @@ def __call__(self, x: NamedArray, *, key=None) -> NamedArray: x_flat = hax.flatten_axes(x, old_axes=squash_axes, new_axis="token") # [Batch, Pos, Embed] -> [Token, Embed] Token = x_flat.resolve_axis("token") - router_logits = self.gate(x_flat, key=k_gate) - + # Optionally compute router 
math in fp32 for numerical stability. + router_fp32 = getattr(self.config, "router_fp32", False) + x_for_gate = x_flat.astype(jnp.float32) if router_fp32 else x_flat + router_logits = self.gate(x_for_gate, key=k_gate) + if router_fp32 and router_logits.array.dtype != jnp.float32: + router_logits = router_logits.astype(jnp.float32) router_probs = hnn.softmax(router_logits, axis=Experts) - topk_weights, topk_idx = self._route(router_probs, Token, TopExperts) + selection_logits = router_logits + if getattr(self.config, "alf_lb_loss_scale", 0.0) > 0: + bias = self.router_bias + if getattr(self.config, "alf_lb_center_bias", True): + bias = bias - hax.mean(bias, axis=Experts) + if router_fp32 and bias.array.dtype != jnp.float32: + bias = bias.astype(jnp.float32) + selection_logits = selection_logits + bias + + topk_weights, topk_idx = self._route( + selection_logits, + router_logits, + Token, + TopExperts, + topk_then_softmax=getattr(self.config, "router_topk_then_softmax", False), + ) + if router_fp32 and topk_weights.array.dtype != x_flat.array.dtype: + # Keep routing decisions in fp32, but cast weights back to the model activation dtype to + # avoid upcasting the expert weighted-sum (bandwidth/memory). 
+ topk_weights = topk_weights.astype(x_flat.array.dtype) + + if force_dense: + idx_arr = jnp.zeros_like(topk_idx.array) + w_arr = jnp.zeros_like(topk_weights.array) + w_arr = w_arr.at[:, 0].set(1.0) + topk_weights = hax.named(w_arr, (Token, TopExperts)) + topk_idx = hax.named(idx_arr, (Token, TopExperts)) topk_idx_flat = hax.flatten_axes(topk_idx, old_axes=[Token, TopExperts], new_axis="token_repeat") TokenRepeat = topk_idx_flat.resolve_axis("token_repeat") - x_repeat_sort, group_sizes, sort_idx, TokenRepeat = self._permute(x_flat, topk_idx_flat, TokenRepeat) - - out_repeat_sort = self.experts(x_repeat_sort, group_sizes, key=k_experts) + x_repeat_sort, group_sizes, sort_idx = self._permute(x_flat, topk_idx_flat, TokenRepeat) + + if self.config.expert_parallelism is not None: + if not self.config.enable_expert_dispatch: + raise NotImplementedError( + "expert_parallelism is set, but expert dispatch is disabled. " + "Sharding the 'experts' axis without cross-device token dispatch (e.g. all-to-all) " + "does not implement true expert-parallel MoE." + ) + + ep_axis = self.config.expert_mesh_axis + mesh = hax.partitioning._get_mesh() + if mesh is None: + raise ValueError("expert_parallelism is set but no JAX mesh is active.") + if ep_axis not in mesh.shape: + raise ValueError( + f"expert_mesh_axis={ep_axis!r} is not a mesh axis. Available axes: {sorted(mesh.shape.keys())}" + ) + + ep_size = self.config.expert_parallelism + mesh_ep_size = mesh.shape[ep_axis] + if mesh_ep_size != ep_size: + raise ValueError( + f"expert_parallelism={ep_size} does not match mesh[{ep_axis!r}]={mesh_ep_size}. " + "Adjust your MeshConfig axes or disable expert_parallelism." + ) + if ep_size != self.config.n_routed_experts: + raise NotImplementedError( + "Expert dispatch currently requires expert_parallelism == n_routed_experts " + f"({ep_size} != {self.config.n_routed_experts})." + ) + + raise NotImplementedError( + "Expert-parallel dispatch is not implemented yet. 
" + "Set expert_parallelism=None (default) to run with replicated experts, " + "or contribute an all-to-all token dispatch implementation here." + ) + else: + out_repeat_sort = self.experts(x_repeat_sort, group_sizes, key=k_experts) out_repeat_unflat = self._unpermute( out_repeat_sort, sort_idx, topk_weights, Token, TokenRepeat, TopExperts @@ -470,22 +630,40 @@ def __call__(self, x: NamedArray, *, key=None) -> NamedArray: # aux loss extras = {} expert_loads = group_sizes / hax.sum(group_sizes, axis=Experts) - extras = { "expert_loads": expert_loads, } - if self.config.lbl_coef is not None: + if self.config.lbl_coef is not None and getattr(self.config, "alf_lb_loss_scale", 0.0) <= 0: + # Shapes: + # - expert_loads: [Experts] where Experts.size == n_routed_experts + # - router_probs: [Token, Experts] where Token is the flattened token axis (T = B*S) f = expert_loads * self.config.n_routed_experts / self.config.num_experts_per_tok - p = hax.mean(router_probs, axis=Token) - extras["load_balancing_loss"] = self.config.lbl_coef * hax.sum(f * p, axis=Experts) + # - f: [Experts] + p = hax.mean(router_probs, axis=Token) # [Token, Experts] -> [Experts] + extras["load_balancing_loss"] = self.config.lbl_coef * hax.sum(f * p, axis=Experts) # [] (scalar) if self.config.rzl_coef is not None: extras["router_z_loss"] = self.config.rzl_coef * hax.mean( - hnn.logsumexp(router_logits, axis=Experts) ** 2, axis=Token + # router_logits: [Token, Experts] -> logsumexp: [Token] -> mean over Token: [] (scalar) + hnn.logsumexp(router_logits, axis=Experts) ** 2, + axis=Token, ) + alf_scale = float(getattr(self.config, "alf_lb_loss_scale", 0.0)) + if alf_scale > 0: + target = TokenRepeat.size / Experts.size + delta_arr = group_sizes.array - target + if getattr(self.config, "alf_lb_use_sign", True): + delta_arr = jnp.sign(delta_arr) + delta_arr = jax.lax.stop_gradient(delta_arr) + delta = hax.named(delta_arr, Experts) + extras["alf_lb_bias_loss"] = alf_scale * hax.sum(self.router_bias * delta, 
axis=Experts) + return hax.unflatten_axis(out, axis=Token, new_axes=squash_axes), extras # [Batch, Pos, Embed] +MixtralConfig = CustomMixtralConfig + + class MixtralDecoderLayer(eqx.Module): config: MixtralConfig = eqx.field(static=True) self_attn: Attention @@ -493,9 +671,10 @@ class MixtralDecoderLayer(eqx.Module): input_layernorm: hnn.RmsNorm post_attention_layernorm: hnn.RmsNorm shared_mlp: LlamaMlp | None + force_dense_moe: bool @staticmethod - def init(config: MistralConfig, *, key) -> "MixtralDecoderLayer": + def init(config: MistralConfig, force_dense_moe: bool = False, *, key) -> "MixtralDecoderLayer": k_attn, k_moe, k_mlp = jrandom.split(key, 3) attn_config = config.attention_config() @@ -512,7 +691,15 @@ def init(config: MistralConfig, *, key) -> "MixtralDecoderLayer": key=k_mlp, use_bias=config.use_bias, ) - return MixtralDecoderLayer(config, attn, block_sparse_moe, ln_1, ln_2, shared_mlp) + return MixtralDecoderLayer( + config, + attn, + block_sparse_moe, + ln_1, + ln_2, + shared_mlp, + bool(force_dense_moe), + ) @named_call def __call__(self, x: NamedArray, mask: NamedArray | AttentionMask | None, *, key=None) -> NamedArray: @@ -527,7 +714,7 @@ def __call__(self, x: NamedArray, mask: NamedArray | AttentionMask | None, *, ke residual = x x = self.post_attention_layernorm(x) mlp_output = self.shared_mlp(x, key=k_mlp) if self.shared_mlp is not None else 0 - moe_output, extras = self.block_sparse_moe(x, key=k_mlp) + moe_output, extras = self.block_sparse_moe(x, key=k_mlp, force_dense=self.force_dense_moe) output = residual + mlp_output + moe_output return output, extras @@ -547,6 +734,7 @@ def init(config: MistralConfig, *, key) -> "MixtralTransformer": layers = S.init(config.Layers, MixtralDecoderLayer, gradient_checkpointing=config.gradient_checkpointing)( config, + force_dense_moe=False, key=shaped_rng_split(key, config.num_layers), ) ln_f = config.mk_LayerNorm(config.Embed) @@ -557,31 +745,44 @@ def init(config: MistralConfig, *, key) -> 
"MixtralTransformer": def __call__( self, x: NamedArray, attn_mask: NamedArray | None, *, key, pos_ids: NamedArray | None = None ) -> NamedArray: - _log_libtpu_args_once() keys = maybe_rng_split(key, self.config.num_layers) if key is not None else None x, extras = self.layers.scan(x, mask=attn_mask, key=keys) x = self.norm(x) - # moe logging - expert_loads = extras["expert_loads"] - entropy = -hax.sum(expert_loads * hax.log(expert_loads + 1e-6), axis=self.config.Experts) - - stats = {} - for i in range(self.config.num_layers): - stats[f"moe/layer{i}/routing_entropy"] = jax.lax.stop_gradient(entropy.array[i]) - for j in range(self.config.n_routed_experts): - stats[f"moe/layer{i}/expert{j}_load"] = jax.lax.stop_gradient(expert_loads.array[i, j]) - - if self.config.lbl_coef is not None: + if "load_balancing_loss" in extras: extras["load_balancing_loss"] = hax.sum(extras["load_balancing_loss"], axis=self.config.Layers) - # Use stop_gradient to ensure concrete values are logged, not tracers - stats["train/load_balancing_loss"] = jax.lax.stop_gradient(extras["load_balancing_loss"].array) - if self.config.rzl_coef is not None: + if "router_z_loss" in extras: extras["router_z_loss"] = hax.sum(extras["router_z_loss"], axis=self.config.Layers) - # Use stop_gradient to ensure concrete values are logged, not tracers + if "alf_lb_bias_loss" in extras: + extras["alf_lb_bias_loss"] = hax.sum(extras["alf_lb_bias_loss"], axis=self.config.Layers) + stats: dict[str, Array] = {} + if "load_balancing_loss" in extras: + stats["train/load_balancing_loss"] = jax.lax.stop_gradient(extras["load_balancing_loss"].array) + if "router_z_loss" in extras: stats["train/router_z_loss"] = jax.lax.stop_gradient(extras["router_z_loss"].array) - - levanter.tracker.jit_log(stats) + if "alf_lb_bias_loss" in extras: + stats["train/alf_lb_bias_loss"] = jax.lax.stop_gradient(extras["alf_lb_bias_loss"].array) + + if self.config.log_moe_metrics: + expert_loads = extras["expert_loads"] + entropy = 
-hax.sum(expert_loads * hax.log(expert_loads + 1e-6), axis=self.config.Experts) + load_violation = expert_loads * self.config.n_routed_experts - 1.0 + load_violation_max = hax.max(load_violation, axis=self.config.Experts) + global_load_violation_max = hax.max(load_violation_max, axis=self.config.Layers) + + for i in range(self.config.num_layers): + stats[f"moe/layer{i}/routing_entropy"] = jax.lax.stop_gradient(entropy.array[i]) + stats[f"moe/layer{i}/load_violation_max"] = jax.lax.stop_gradient(load_violation_max.array[i]) + for j in range(self.config.n_routed_experts): + stats[f"moe/layer{i}/expert{j}_load"] = jax.lax.stop_gradient(expert_loads.array[i, j]) + stats["moe/load_violation_max"] = jax.lax.stop_gradient(global_load_violation_max.array) + dense_first_n_layers = int(getattr(self.config, "dense_first_n_layers", 0) or 0) + if 0 < dense_first_n_layers < self.config.num_layers: + sparse_layers = jnp.arange(self.config.num_layers) >= dense_first_n_layers + sparse_load_violation_max = jnp.where(sparse_layers, load_violation_max.array, -jnp.inf) + stats["moe/load_violation_max_sparse_layers"] = jax.lax.stop_gradient(jnp.max(sparse_load_violation_max)) + if stats: + levanter.tracker.jit_log(stats) return x, extras @@ -663,10 +864,12 @@ def activations( x, extras = self.transformer(x, attn_mask=attn_mask, key=key, pos_ids=pos_ids) aux_loss = 0 - if self.config.lbl_coef is not None: + if "load_balancing_loss" in extras: aux_loss += extras["load_balancing_loss"] - if self.config.rzl_coef is not None: + if "router_z_loss" in extras: aux_loss += extras["router_z_loss"] + if "alf_lb_bias_loss" in extras: + aux_loss += extras["alf_lb_bias_loss"] return x, aux_loss def get_lm_head(self) -> hax.NamedArray: @@ -675,6 +878,50 @@ def get_lm_head(self) -> hax.NamedArray: else: return self.lm_head.weight + def compute_next_token_loss( # type: ignore[override] + self, + example: LmExample, + *, + key=None, + reduction: hax.ReductionFunction | None = hax.mean, + reduction_axis: 
hax.AxisSelection | None = None, + logsumexp_weight: float | None = None, + loss_dtype: jnp.dtype | None = jnp.float32, + logit_soft_cap: float | None = None, + ) -> jnp.ndarray | NamedArray: + activations = self.activations(example.tokens, example.attn_mask, key=key) + + aux_loss = 0 + if isinstance(activations, tuple): + activations, aux_loss = activations + + implementation = self.config.cross_entropy_implementation or "auto" + if implementation in ("auto", "pallas_tpu", "legacy"): + block_size = self.config.cross_entropy_block_size + if block_size is None: + raise ValueError("cross_entropy_block_size must be set for fused cross-entropy") + elif implementation in ("xla", "reference"): + block_size = None + else: + raise ValueError(f"Unknown cross_entropy_implementation: {implementation!r}") + + loss = maybe_fused_next_token_loss( + self.Pos, + self.Embed, + self.Vocab, + activations, + self.get_lm_head(), + example.tokens, + loss_weight=example.loss_weight, + reduction=reduction, + reduction_axis=reduction_axis, + logsumexp_weight=logsumexp_weight, + block_size=block_size, + dtype=loss_dtype, + logit_soft_cap=logit_soft_cap, + ) + return loss + aux_loss + def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[MixtralConfig]": new_Vocab = self.Vocab.resize(new_size) k1, k2 = maybe_rng_split(key, 2) diff --git a/experiments/speedrun/grugformer_moe/grugformer_moe.py b/experiments/speedrun/grugformer_moe/grugformer_moe.py new file mode 100644 index 0000000000..a58bae86b0 --- /dev/null +++ b/experiments/speedrun/grugformer_moe/grugformer_moe.py @@ -0,0 +1,956 @@ +# Copyright 2025 The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +""" +Grugformer MoE experiment (router + ragged expert MLP). + +This is an experiment-only implementation: it keeps the MoE logic local to this entrypoint +and does not modify the canonical `levanter.grug` core. + +Design goals: +- "Grug simple": explicit tensor shapes, minimal abstractions. 
+- "Vanilla custom_mixtral logic": top-k routing + sort/permute dispatch + GMM (Megablox) or ragged_dot_general expert MLP, + with load-balancing loss and router z-loss. +- Replicated experts (no expert-parallel all-to-all). +""" + +# nodryrun + +import dataclasses +import logging +import os +from dataclasses import dataclass +from functools import partial +from typing import TypeVar + +import haliax as hax +import jax +import jax.numpy as jnp +import jax.scipy as jsp +from einops import rearrange +from fray.cluster import ResourceConfig +from haliax import Axis, NamedArray +from haliax.nn.linear import gmm_sharded +from haliax.partitioning import _get_mesh +from jax.experimental.shard_map import shard_map +from jax.sharding import NamedSharding, PartitionSpec as P, reshard +from jax.tree_util import register_dataclass +from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree + +from levanter.grug.attention import AttentionMask, RotaryConfig, apply_rotary_embedding, attention +from levanter.grug.sharding import Pbatch_moe, Pvocab, unshard +from levanter.layers.attention import AttentionMask as LevanterAttentionMask +from levanter.models.loss import maybe_fused_next_token_loss +from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel +from levanter.utils.flop_utils import lm_flops_per_token +from marin.execution.executor import executor_main +from marin.speedrun.speedrun import Author, SpeedrunConfig, default_speedrun + +from experiments.llama import llama3_tokenizer_vocab_size +from experiments.simple_train_config import SimpleTrainConfig + +logger = logging.getLogger("ray") + +# Debug helper: only enabled when explicitly requested via env var to avoid impacting performance. 
+_DEBUG_FINITE = os.environ.get("GRUGMOE_DEBUG_FINITE", "") not in ("", "0", "false", "False") + + +def _maybe_log_nonfinite(x: jax.Array, *, name: str) -> jax.Array: + """Print a one-line summary if `x` contains NaN/Inf (process 0 only).""" + if not _DEBUG_FINITE or jax.process_index() != 0: + return x + + is_finite = jnp.all(jnp.isfinite(x)) + + def _print(_: None) -> jax.Array: + x_f32 = x.astype(jnp.float32) + x_clean = jnp.nan_to_num(x_f32, nan=0.0, posinf=0.0, neginf=0.0) + jax.debug.print( + "NONFINITE {name}: any_nan={any_nan} any_inf={any_inf} min={minv} max={maxv}", + name=name, + any_nan=jnp.any(jnp.isnan(x_f32)), + any_inf=jnp.any(jnp.isinf(x_f32)), + minv=jnp.min(x_clean), + maxv=jnp.max(x_clean), + ) + return x + + return jax.lax.cond(is_finite, lambda _: x, _print, operand=None) + +# Ruff/Pyflakes treats string literals in annotations as forward refs and checks that bare +# identifiers (e.g. "D") are defined. These are jaxtyping dimension labels, not runtime symbols. +D = TypeVar("D") + + +def _pbatch() -> P: + return Pbatch_moe() + + +#### Conventions +# +# Mesh meanings: +# - "data": data parallel sharding axis. We also shard parameters over this axis (ZeRO-ish). +# - "model": model parallel sharding axis (TP). 
+# +# Dim names used in comments: +# - B = batch +# - S = sequence length +# - D = hidden dim +# - I = intermediate dim (per-expert) +# - E = number of routed experts +# - K = experts per token (top-k) +# - T = flattened tokens (= B*S) +# - TR = token-repeat (= T*K) + + +@dataclass(frozen=True) +class GrugMoeModelConfig: + # Core grug hyperparams + vocab_size: int + hidden_dim: int = 2048 + intermediate_dim: int = 5632 + num_layers: int = 24 + num_heads: int = 16 + num_kv_heads: int = 16 + head_dim: int | None = None + max_seq_len: int = 4096 + layer_norm_eps: float = 1e-5 + initializer_std: float = 0.02 + rope: RotaryConfig = dataclasses.field(default_factory=RotaryConfig) + cross_entropy_block_size: int | None = 32768 + cross_entropy_implementation: str | None = "xla" + + # MoE hyperparams (vanilla Mixtral-ish) + n_routed_experts: int = 8 # E + num_experts_per_tok: int = 2 # K + lbl_coef: float | None = 0.01 + rzl_coef: float | None = 0.001 + router_fp32: bool = False + router_topk_then_softmax: bool = False + use_gmm: bool = True + + def __post_init__(self) -> None: + _ = self.inferred_head_dim + if self.num_heads % self.num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads for grouped-query attention") + if self.vocab_size <= 0: + raise ValueError("vocab_size must be positive") + if self.max_seq_len <= 0: + raise ValueError("max_seq_len must be positive") + if self.num_experts_per_tok <= 0: + raise ValueError("num_experts_per_tok must be positive") + if self.n_routed_experts <= 0: + raise ValueError("n_routed_experts must be positive") + if self.num_experts_per_tok > self.n_routed_experts: + raise ValueError("num_experts_per_tok cannot exceed n_routed_experts") + + @property + def inferred_head_dim(self) -> int: + if self.head_dim is not None: + return self.head_dim + if self.hidden_dim % self.num_heads != 0: + raise ValueError( + f"hidden_dim={self.hidden_dim} is not divisible by num_heads={self.num_heads}; set head_dim explicitly" 
+ ) + return self.hidden_dim // self.num_heads + + +@register_dataclass +@dataclass(frozen=True) +class GrugAttentionParams: + w_q: jax.Array + w_k: jax.Array + w_v: jax.Array + w_o: jax.Array + + +@register_dataclass +@dataclass(frozen=True) +class GrugMoeBlockParams: + attn: GrugAttentionParams + rms_attn: jax.Array + rms_mlp: jax.Array + router_w: jax.Array # [D, E] + w1: jax.Array # [E, D, I] (gate_proj) + w3: jax.Array # [E, D, I] (up_proj) + w2: jax.Array # [E, I, D] (down_proj) + + +@register_dataclass +@dataclass(frozen=True) +class GrugMoeParameters: + token_embed: jax.Array + output_proj: jax.Array + blocks: tuple[GrugMoeBlockParams, ...] + final_norm: jax.Array + + +def _init_weight(key: PRNGKeyArray, shape: tuple[int, ...], std: float) -> Float[Array, "..."]: + return std * jax.random.truncated_normal(key, -3, 3, shape) + + +@partial(jax.jit, static_argnames=("cfg",)) +def init_parameters(cfg: GrugMoeModelConfig, *, key: PRNGKeyArray) -> GrugMoeParameters: + head_dim = cfg.inferred_head_dim + key, embed_key, out_key = jax.random.split(key, 3) + layer_keys = jax.random.split(key, cfg.num_layers) + + token_embed = reshard(_init_weight(embed_key, (cfg.vocab_size, cfg.hidden_dim), cfg.initializer_std), Pvocab) + output_proj = reshard(_init_weight(out_key, (cfg.hidden_dim, cfg.vocab_size), cfg.initializer_std), Pvocab) + final_norm = reshard(jnp.ones((cfg.hidden_dim,), dtype=jnp.float32), P(None)) + + blocks: list[GrugMoeBlockParams] = [] + # extract shape sizes for brevity and consistency + hidden_dim = cfg.hidden_dim + num_heads = cfg.num_heads + num_kv_heads = cfg.num_kv_heads + intermediate_dim = cfg.intermediate_dim + num_experts = cfg.n_routed_experts + for i in range(cfg.num_layers): + ( + k_q, + k_k, + k_v, + k_o, + k_router, + k_w1, + k_w2, + k_w3, + ) = jax.random.split(layer_keys[i], 8) + + attn = GrugAttentionParams( + w_q=reshard(_init_weight(k_q, (hidden_dim, num_heads * head_dim), cfg.initializer_std), P("data", "model")), + w_k=reshard( + 
_init_weight(k_k, (hidden_dim, num_kv_heads * head_dim), cfg.initializer_std), P("data", "model") + ), + w_v=reshard( + _init_weight(k_v, (hidden_dim, num_kv_heads * head_dim), cfg.initializer_std), P("data", "model") + ), + w_o=reshard(_init_weight(k_o, (num_heads * head_dim, hidden_dim), cfg.initializer_std), P("model", "data")), + ) + + # Router maps D -> E. Keep the expert axis replicated (no expert-parallel sharding). + router_w = reshard(_init_weight(k_router, (hidden_dim, num_experts), cfg.initializer_std), P("data", None)) + + # Expert weights are replicated over the data axis and sharded over the model axis (TP). + # This keeps the GMM pathway simple and avoids ragged-dot auto-sharding pathologies. + w1 = reshard( + _init_weight(k_w1, (num_experts, hidden_dim, intermediate_dim), cfg.initializer_std), + P(None, None, "model"), + ) + w3 = reshard( + _init_weight(k_w3, (num_experts, hidden_dim, intermediate_dim), cfg.initializer_std), + P(None, None, "model"), + ) + w2 = reshard( + _init_weight(k_w2, (num_experts, intermediate_dim, hidden_dim), cfg.initializer_std), + P(None, "model", None), + ) + + # keep rms replicated + rms_attn = jnp.ones((hidden_dim,), dtype=jnp.float32) + rms_mlp = jnp.ones((hidden_dim,), dtype=jnp.float32) + + blocks.append( + GrugMoeBlockParams( + attn=attn, + rms_attn=rms_attn, + rms_mlp=rms_mlp, + router_w=router_w, + w1=w1, + w3=w3, + w2=w2, + ) + ) + + return GrugMoeParameters( + token_embed=token_embed, + output_proj=output_proj, + blocks=tuple(blocks), + final_norm=final_norm, + ) + + +def rms_norm(x: Float[Array, "... D"], weight: Float[Array, "D"], eps: float) -> Float[Array, "... D"]: + weight = unshard(weight) + # Levanter runs with mixed precision (bf16 compute, fp32 params) + strict dtype promotion. + # Do RMSNorm math in fp32, then cast back to the input dtype. 
+ out_dtype = x.dtype + x_f32 = x.astype(jnp.float32) + w_f32 = weight.astype(jnp.float32) + variance = jnp.mean(jnp.square(x_f32), axis=-1, keepdims=True) + inv = jax.lax.rsqrt(variance + eps) + y = (x_f32 * inv) * w_f32 + return y.astype(out_dtype) + + +def _ragged_moe_linear( + x: jax.Array, + w: jax.Array, + group_sizes: jax.Array, +) -> jax.Array: + """Ragged MoE linear: (TR, In) x (E, In, Out) with groups along TR. + + Shapes: + - x: [TR, In] + - w: [E, In, Out] + - group_sizes: [E] (sum == TR) + - out: [TR, Out] + """ + # Everything other than the contracting dimension is treated as ragged. + dim_numbers = jax.lax.RaggedDotDimensionNumbers( + dot_dimension_numbers=(((1,), (1,)), ((), ())), + lhs_ragged_dimensions=(0,), + rhs_group_dimensions=(0,), + ) + # `ragged_dot_general` doesn't yet have a built-in sharding rule. On SPMD runs we + # drop this op into full auto-sharding mode. + # + # This MoE is still *not* expert-parallel: expert weights are replicated and all routing + # / dispatch happens per-device for that device's local tokens. + mesh = _get_mesh() + if mesh is not None and not getattr(mesh, "empty", False): + w_sharding = getattr(w, "sharding", None) + w_spec = getattr(w_sharding, "spec", None) + out_axis = w_spec[-1] if w_spec is not None and len(w_spec) == w.ndim else None + batch_spec = _pbatch()[0] + out_sharding = NamedSharding(mesh, P(batch_spec, out_axis)) + + ragged = jax.sharding.auto_axes( + lambda lhs, rhs, gs: jax.lax.ragged_dot_general( + lhs=lhs, + rhs=rhs, + group_sizes=gs, + ragged_dot_dimension_numbers=dim_numbers, + ) + ) + try: + return ragged(x, w, group_sizes, out_sharding=out_sharding) + except TypeError: + # Some JAX builds spell this kwarg differently. 
+ return ragged(x, w, group_sizes, out_shardings=out_sharding) # type: ignore[call-arg] + + return jax.lax.ragged_dot_general( + lhs=x, + rhs=w, + group_sizes=group_sizes, + ragged_dot_dimension_numbers=dim_numbers, + ) + + +def _gmm_moe_linear( + x: jax.Array, + w: jax.Array, + group_sizes: jax.Array, + *, + w_spec: P, + out_axis: str | None, + ar: bool, +) -> jax.Array: + """GMM-based MoE linear using Megablox grouped matmul. + + Shapes: + - x: [TR, In] + - w: [E, In, Out] + - group_sizes: [E] (sum == TR) + - out: [TR, Out] + """ + mesh = _get_mesh() + if mesh is not None and not getattr(mesh, "empty", False): + out_specs = P(_pbatch()[0], out_axis) + gmm_fn = shard_map( + lambda lhs, rhs, gs: gmm_sharded(lhs, rhs, gs, ar=ar), + mesh=mesh, + in_specs=(_pbatch(), w_spec, P(None)), + out_specs=out_specs, + check_rep=False, + ) + return gmm_fn(x, w, group_sizes) + + return gmm_sharded(x, w, group_sizes, ar=ar) + + + +def _route( + selection_logits: jax.Array, + router_logits: jax.Array, + *, + top_k: int, + topk_then_softmax: bool, +) -> tuple[jax.Array, jax.Array, jax.Array]: + """Top-k route tokens to experts. + + Shapes: + - selection_logits: [T, E] + - router_logits: [T, E] + - topk_weights: [T, K] + - topk_idx: [T, K] (int32 expert ids) + - router_probs: [T, E] + """ + router_probs = jax.nn.softmax(router_logits.astype(jnp.float32), axis=-1) + _scores, topk_idx = jax.lax.top_k(selection_logits, top_k) + + if topk_then_softmax: + selected_logits = jnp.take_along_axis(router_logits, topk_idx, axis=-1) + topk_weights = jax.nn.softmax(selected_logits, axis=-1) + else: + topk_weights = jnp.take_along_axis(router_probs, topk_idx, axis=-1) + denom = jnp.sum(topk_weights, axis=-1, keepdims=True) + # Softmax underflow can make `denom==0` in low precision; avoid NaNs by falling back to uniform weights. 
+ topk_weights = jnp.where( + denom > 0, + topk_weights / denom, + jnp.full_like(topk_weights, 1.0 / top_k), + ) + + return topk_weights, topk_idx.astype(jnp.int32), router_probs + + +def _permute( + x_flat: jax.Array, + topk_idx_flat: jax.Array, + *, + num_experts_per_tok: int, + n_routed_experts: int, +) -> tuple[jax.Array, jax.Array, jax.Array]: + """Sort token-repeat stream by expert id. + + Shapes: + - x_flat: [T, D] + - topk_idx_flat: [TR] where TR = T*K + - x_repeat_sort: [TR, D] + - group_sizes: [E] + - sort_idx: [TR] + """ + sort_idx = jnp.argsort(topk_idx_flat, axis=-1) + x_repeat_sort = jnp.take(x_flat, sort_idx // num_experts_per_tok, axis=0) + group_sizes = jnp.bincount(topk_idx_flat, length=n_routed_experts).astype(jnp.int32) + return x_repeat_sort, group_sizes, sort_idx.astype(jnp.int32) + + +def _unpermute( + out_repeat_sort: jax.Array, + sort_idx: jax.Array, + *, + num_experts_per_tok: int, + hidden_dim: int, +) -> jax.Array: + """Invert expert sort and unflatten token-repeat back to [T, K, D].""" + inv_sort_idx = jnp.argsort(sort_idx, axis=-1) + out_repeat = jnp.take(out_repeat_sort, inv_sort_idx, axis=0) + return jnp.reshape(out_repeat, (-1, num_experts_per_tok, hidden_dim)) + + +def moe_mlp(block: GrugMoeBlockParams, x: Float[Array, "B S D"], cfg: GrugMoeModelConfig) -> tuple[jax.Array, jax.Array]: + """MoE MLP with Mixtral-style routing/dispatch and auxiliary router losses.""" + B, S, D = x.shape + E = cfg.n_routed_experts + K = cfg.num_experts_per_tok + T = B * S + TR = T * K + + x_flat = jnp.reshape(x, (T, D)) # [B, S, D] -> [T, D] + + x_for_gate = x_flat.astype(jnp.float32) if cfg.router_fp32 else x_flat + router_logits = jnp.einsum("td,de->te", x_for_gate, block.router_w) # [T, D] @ [D, E] -> [T, E] + if cfg.router_fp32 and router_logits.dtype != jnp.float32: + router_logits = router_logits.astype(jnp.float32) + + router_logits = _maybe_log_nonfinite(router_logits, name="router_logits") + + selection_logits = router_logits + + mesh = 
_get_mesh() + if mesh is not None and not getattr(mesh, "empty", False): + route = shard_map( + lambda sel, rlog: _route( + sel, + rlog, + top_k=K, + topk_then_softmax=cfg.router_topk_then_softmax, + ), + mesh=mesh, + in_specs=(_pbatch(), _pbatch()), + out_specs=(_pbatch(), _pbatch(), _pbatch()), + check_rep=False, + ) + topk_weights, topk_idx, router_probs = route(selection_logits, router_logits) + topk_weights = _maybe_log_nonfinite(topk_weights, name="topk_weights") + router_probs = _maybe_log_nonfinite(router_probs, name="router_probs") + else: + topk_weights, topk_idx, router_probs = _route( + selection_logits, + router_logits, + top_k=K, + topk_then_softmax=cfg.router_topk_then_softmax, + ) + + topk_idx_flat = jnp.reshape(topk_idx, (TR,)) # [T, K] -> [TR] + + if mesh is not None and not getattr(mesh, "empty", False): + permute = shard_map( + lambda x_t, idx_tr: _permute( + x_t, + idx_tr, + num_experts_per_tok=K, + n_routed_experts=E, + ), + mesh=mesh, + in_specs=(_pbatch(), _pbatch()), + out_specs=(_pbatch(), P(None), _pbatch()), + check_rep=False, + ) + x_repeat_sort, group_sizes, sort_idx = permute(x_flat, topk_idx_flat) + else: + x_repeat_sort, group_sizes, sort_idx = _permute( + x_flat, + topk_idx_flat, + num_experts_per_tok=K, + n_routed_experts=E, + ) + + # Expert MLP on the sorted token-repeat stream. All expert math is per-shard (replicated across E). 
+ # + # Shapes: + # - x_repeat_sort: [TR, D] + # - group_sizes: [E], sum(group_sizes) == TR + # - w1/w3: [E, D, I], w2: [E, I, D] + if cfg.use_gmm: + w1_out = _gmm_moe_linear( + x_repeat_sort, + block.w1, + group_sizes, + w_spec=P(None, None, "model"), + out_axis="model", + ar=False, + ) # [TR, I] + w1_out = _maybe_log_nonfinite(w1_out, name="w1_out") + w3_out = _gmm_moe_linear( + x_repeat_sort, + block.w3, + group_sizes, + w_spec=P(None, None, "model"), + out_axis="model", + ar=False, + ) # [TR, I] + w3_out = _maybe_log_nonfinite(w3_out, name="w3_out") + gated = jax.nn.silu(w1_out) * w3_out # [TR, I] + out_repeat_sort = _gmm_moe_linear( + gated, + block.w2, + group_sizes, + w_spec=P(None, "model", None), + out_axis=None, + ar=True, + ) # [TR, D] + else: + w1_out = _ragged_moe_linear(x_repeat_sort, block.w1, group_sizes) # [TR, I] + w1_out = _maybe_log_nonfinite(w1_out, name="w1_out") + w3_out = _ragged_moe_linear(x_repeat_sort, block.w3, group_sizes) # [TR, I] + w3_out = _maybe_log_nonfinite(w3_out, name="w3_out") + gated = jax.nn.silu(w1_out) * w3_out # [TR, I] + out_repeat_sort = _ragged_moe_linear(gated, block.w2, group_sizes) # [TR, D] + out_repeat_sort = _maybe_log_nonfinite(out_repeat_sort, name="out_repeat_sort") + + if mesh is not None and not getattr(mesh, "empty", False): + unpermute = shard_map( + lambda out_tr_d, sidx_tr: _unpermute( + out_tr_d, + sidx_tr, + num_experts_per_tok=K, + hidden_dim=D, + ), + mesh=mesh, + in_specs=(_pbatch(), _pbatch()), + out_specs=_pbatch(), + check_rep=False, + ) + out_repeat_unflat = unpermute(out_repeat_sort, sort_idx) # [T, K, D] + else: + out_repeat_unflat = _unpermute(out_repeat_sort, sort_idx, num_experts_per_tok=K, hidden_dim=D) + + out_flat = jnp.sum(out_repeat_unflat * topk_weights[..., None], axis=1) # [T, D] + out = jnp.reshape(out_flat, (B, S, D)) # [T, D] -> [B, S, D] + + # --- Aux router losses (vanilla Mixtral-ish) --- + aux = jnp.array(0.0, dtype=jnp.float32) + + if cfg.lbl_coef is not None: + # 
group_sizes: [E] counts assignments over token-repeat stream (TR = T*K) + # expert_loads: [E] sums to 1 + group_sizes_f = group_sizes.astype(jnp.float32) + expert_loads = group_sizes_f / jnp.sum(group_sizes_f) + f = expert_loads * (E / K) # [E] + p = jnp.mean(router_probs.astype(jnp.float32), axis=0) # [T, E] -> [E] + aux = aux + jnp.asarray(cfg.lbl_coef, dtype=jnp.float32) * jnp.sum(f * p) # scalar + + if cfg.rzl_coef is not None: + z = jsp.special.logsumexp(router_logits.astype(jnp.float32), axis=-1) # [T] + aux = aux + jnp.asarray(cfg.rzl_coef, dtype=jnp.float32) * jnp.mean(z**2) # scalar + + return out, aux + + +def _transformer_hidden( + params: GrugMoeParameters, + token_ids: Int[Array, "B S"], + cfg: GrugMoeModelConfig, + *, + mask: AttentionMask | jax.Array | None, +) -> tuple[Float[Array, "B S D"], jax.Array]: + head_dim = cfg.inferred_head_dim + seq_len = token_ids.shape[1] + if _DEBUG_FINITE and jax.process_index() == 0: + bad = jnp.any((token_ids < 0) | (token_ids >= cfg.vocab_size)) + def _print_tok(_: None) -> jax.Array: + jax.debug.print( + "BAD TOKENS: min={minv} max={maxv} vocab={vocab}", + minv=jnp.min(token_ids), + maxv=jnp.max(token_ids), + vocab=cfg.vocab_size, + ) + return token_ids + token_ids = jax.lax.cond(bad, _print_tok, lambda _: token_ids, operand=None) + + if mask is None: + mask = AttentionMask.causal() + + hidden = params.token_embed.at[token_ids].get(out_sharding=_pbatch()) # [B, S, D] + hidden = _maybe_log_nonfinite(hidden, name="hidden/embed") + + aux_total = jnp.array(0.0, dtype=jnp.float32) + + for block in params.blocks: + attn_in = rms_norm(hidden, block.rms_attn, cfg.layer_norm_eps) + q = rearrange(jnp.einsum("bsh,hd->bsd", attn_in, block.attn.w_q), "... (n d) -> ... n d", d=head_dim) + k = rearrange(jnp.einsum("bsh,hd->bsd", attn_in, block.attn.w_k), "... (m d) -> ... m d", d=head_dim) + v = rearrange(jnp.einsum("bsh,hd->bsd", attn_in, block.attn.w_v), "... (m d) -> ... 
m d", d=head_dim) + q, k = apply_rotary_embedding(q, k, seq_len=seq_len, head_dim=head_dim, rope=cfg.rope) + attn_out = attention(q, k, v, mask) + attn_out = _maybe_log_nonfinite(attn_out, name="attn_out") + attn_out = rearrange(attn_out, "... n d -> ... (n d)") + attn_out = jnp.einsum("bsh,hd->bsd", attn_out, block.attn.w_o, out_sharding=_pbatch()) + + hidden = hidden + attn_out + mlp_in = rms_norm(hidden, block.rms_mlp, cfg.layer_norm_eps) + mlp_out, aux = moe_mlp(block, mlp_in, cfg) + mlp_out = _maybe_log_nonfinite(mlp_out, name="mlp_out") + hidden = hidden + mlp_out + aux_total = aux_total + aux + + hidden = rms_norm(hidden, params.final_norm, cfg.layer_norm_eps) + hidden = _maybe_log_nonfinite(hidden, name="hidden/final") + return hidden, aux_total + + +def activations( + params: GrugMoeParameters, + token_ids: Int[Array, "B S"], + cfg: GrugMoeModelConfig, + *, + mask: AttentionMask | jax.Array | None = None, +) -> tuple[Float[Array, "B S D"], jax.Array]: + """Return final hidden states (and aux loss scalar).""" + return _transformer_hidden(params, token_ids, cfg, mask=mask) + + +class GrugMoeWrapper(LmHeadModel[PyTree]): + """Minimal LmHeadModel wrapper around this experiment-local Grug+MoE implementation.""" + + params: GrugMoeParameters + grug_config: GrugMoeModelConfig + + @property + def config(self) -> GrugMoeModelConfig: + return self.grug_config + + @property + def Pos(self) -> Axis: + return Axis("position", self.grug_config.max_seq_len) + + @property + def KeyPos(self) -> Axis: + return self.Pos.alias("key_position") + + @property + def Vocab(self) -> Axis: + return Axis("vocab", self.grug_config.vocab_size) + + @property + def Embed(self) -> Axis: + return Axis("embed", self.grug_config.hidden_dim) + + @classmethod + def init(cls, Vocab: Axis, config: GrugMoeModelConfig, *, key: PRNGKeyArray) -> "GrugMoeWrapper": + cfg = dataclasses.replace(config, vocab_size=Vocab.size) + params = init_parameters(cfg, key=key) + return cls(params=params, 
grug_config=cfg) + + def activations( + self, + input_ids: NamedArray, + attn_mask: LevanterAttentionMask | NamedArray | None = None, + *, + key=None, + pos_ids: NamedArray | None = None, + ) -> tuple[NamedArray, jax.Array]: + del key, pos_ids # grug core doesn't use PRNGs/pos_ids yet + + mask = _mask_from_levanter(attn_mask) + hidden, aux = activations(self.params, input_ids.array, self.grug_config, mask=mask) + return hax.named(hidden, (*input_ids.axes, self.Embed)), aux + + def get_lm_head(self) -> NamedArray: + return hax.named(self.params.output_proj, (self.Embed, self.Vocab)) + + def compute_next_token_loss( + self, + example: LmExample, + *, + key=None, + reduction: hax.ReductionFunction | None = hax.mean, + reduction_axis: hax.AxisSelection | None = None, + logsumexp_weight: float | None = None, + loss_dtype: jnp.dtype | None = jnp.float32, + logit_soft_cap: float | None = None, + ) -> jnp.ndarray | NamedArray: + activations = self.activations(example.tokens, example.attn_mask, key=key) + + aux_loss = 0 + if isinstance(activations, tuple): + activations, aux_loss = activations + + implementation = self.grug_config.cross_entropy_implementation or "auto" + if implementation in ("auto", "pallas_tpu"): + block_size = self.grug_config.cross_entropy_block_size + if block_size is None: + raise ValueError("cross_entropy_block_size must be set for Pallas cross-entropy") + elif implementation in ("xla", "reference"): + block_size = None + else: + raise ValueError(f"Unknown cross_entropy_implementation: {implementation!r}") + + loss = maybe_fused_next_token_loss( + self.Pos, + self.Embed, + self.Vocab, + activations, + self.get_lm_head(), + example.tokens, + loss_weight=example.loss_weight, + reduction=reduction, + reduction_axis=reduction_axis, + logsumexp_weight=logsumexp_weight, + block_size=block_size, + dtype=loss_dtype, + logit_soft_cap=logit_soft_cap, + ) + return loss + aux_loss + + def resize_vocab(self, new_size: int, key: PRNGKeyArray | None = None) -> 
"GrugMoeWrapper": + raise NotImplementedError("GrugMoeWrapper does not yet support resizing the vocabulary.") + + +def _mask_from_levanter(attn_mask: LevanterAttentionMask | NamedArray | None) -> AttentionMask | jax.Array | None: + mask: AttentionMask | jax.Array | None = None + if isinstance(attn_mask, LevanterAttentionMask): + if attn_mask.explicit_mask is not None: + raise NotImplementedError("Grug does not support explicit masks yet.") + if attn_mask.causal_offset is not None: + raise NotImplementedError("Grug does not support causal offsets yet.") + segment_ids = None + if attn_mask.segment_ids is not None: + q_seg, kv_seg = attn_mask.segment_ids + segment_ids = (q_seg.array, kv_seg.array) + mask = AttentionMask( + is_causal=attn_mask.is_causal, + segment_ids=segment_ids, + sliding_window=attn_mask.sliding_window, + ) + elif isinstance(attn_mask, NamedArray): + raise NotImplementedError( + "NamedArray attention masks are not supported by Grug (pass a Levanter AttentionMask)." + ) + return mask + + +def _get_num_train_steps(param_count: int, batch_size: int, max_seq_len: int, tpp: int = 20) -> int: + total_tokens = param_count * tpp + return max(1, total_tokens // (batch_size * max_seq_len)) + + +def _size_presets() -> dict[str, "GrugformerMoeConfig"]: + base = dict(max_seq_len=2048, head_dim=None, n_routed_experts=8, num_experts_per_tok=2) + return { + "130m": GrugformerMoeConfig( + hidden_dim=512, intermediate_dim=1792, num_layers=6, num_heads=8, num_kv_heads=8, **base + ), + "300m": GrugformerMoeConfig( + hidden_dim=768, intermediate_dim=2688, num_layers=12, num_heads=12, num_kv_heads=12, **base + ), + "520m": GrugformerMoeConfig( + hidden_dim=1024, intermediate_dim=3584, num_layers=24, num_heads=16, num_kv_heads=16, **base + ), + } + + +def _resource_presets(use_tpu: bool = False): + if use_tpu: + return { + "130m": ResourceConfig.with_tpu("v5p-8"), + "300m": ResourceConfig.with_tpu("v5p-8"), + "520m": ResourceConfig.with_tpu("v5p-8"), + } + return { + 
"130m": ResourceConfig.with_gpu("A100-80G", count=1), + "300m": ResourceConfig.with_gpu("A100-80G", count=1), + "520m": ResourceConfig.with_gpu("A100-80G", count=2), + } + + +def _batch_sizes() -> dict[str, int]: + return {"130m": 128, "300m": 128, "520m": 128} + + +@LmConfig.register_subclass("grugformer_moe") +@dataclass(frozen=True) +class GrugformerMoeConfig(LmConfig[GrugMoeWrapper]): + """Speedrun LmConfig wrapper around an experiment-local Grug+MoE transformer.""" + + # LmConfig field + max_seq_len: int = 2048 + + # Core hyperparams + hidden_dim: int = 1024 + intermediate_dim: int = 2752 + num_layers: int = 12 + num_heads: int = 16 + num_kv_heads: int = 16 + head_dim: int | None = None + + # MoE hyperparams + n_routed_experts: int = 8 + num_experts_per_tok: int = 2 + lbl_coef: float | None = 0.01 + rzl_coef: float | None = 0.001 + router_fp32: bool = False + router_topk_then_softmax: bool = False + use_gmm: bool = True + cross_entropy_block_size: int | None = 32768 + cross_entropy_implementation: str | None = "xla" + + # ---- LmConfig API ---- + @property + def model_type(self) -> type[GrugMoeWrapper]: + return GrugMoeWrapper + + @property + def Embed(self) -> Axis: + return Axis("embed", self.hidden_dim) + + def build(self, Vocab: Axis, *, key: PRNGKeyArray) -> GrugMoeWrapper: + cfg = GrugMoeModelConfig( + vocab_size=Vocab.size, + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + num_layers=self.num_layers, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_seq_len=self.max_seq_len, + n_routed_experts=self.n_routed_experts, + num_experts_per_tok=self.num_experts_per_tok, + lbl_coef=self.lbl_coef, + rzl_coef=self.rzl_coef, + router_fp32=self.router_fp32, + router_topk_then_softmax=self.router_topk_then_softmax, + use_gmm=self.use_gmm, + cross_entropy_block_size=self.cross_entropy_block_size, + cross_entropy_implementation=self.cross_entropy_implementation, + ) + return GrugMoeWrapper.init(Vocab, 
cfg, key=key)
+
+    def flops_per_token(self, vocab_size: int, context_length: int) -> float | None:
+        # Rough FLOP estimate. NOTE(review): passes a single expert's intermediate_dim, so MLP FLOPs are NOT scaled by the K active experts (num_experts_per_tok) — confirm this is intended before comparing MFU across runs.
+        return lm_flops_per_token(
+            hidden_dim=self.hidden_dim,
+            intermediate_dim=self.intermediate_dim,
+            num_layers=self.num_layers,
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            seq_len=context_length,
+            vocab_size=vocab_size,
+            glu=True,
+        )
+
+    def total_trainable_params(self, vocab_size: int) -> int:
+        head_dim = self.head_dim or (self.hidden_dim // self.num_heads)
+        token_embedding = vocab_size * self.hidden_dim
+        attn = (
+            self.hidden_dim * head_dim * self.num_heads
+            + 2 * self.hidden_dim * head_dim * self.num_kv_heads
+            + head_dim * self.num_heads * self.hidden_dim
+        )
+        router = self.hidden_dim * self.n_routed_experts
+        experts = 3 * self.n_routed_experts * self.hidden_dim * self.intermediate_dim
+        moe = router + experts
+        transformer = self.num_layers * (attn + moe + 2 * self.hidden_dim) + self.hidden_dim
+        return int(transformer + 2 * token_embedding)
+
+
+def build_run(size: str, *, use_tpu: bool = False) -> tuple[str, SpeedrunConfig]:
+    sizes = _size_presets()
+    if size not in sizes:
+        raise ValueError(f"Unknown size: {size}")
+    model_cfg = sizes[size]
+
+    batch = _batch_sizes()[size]
+    max_seq_len = model_cfg.max_seq_len
+    params = int(model_cfg.total_trainable_params(llama3_tokenizer_vocab_size))
+    steps = _get_num_train_steps(params, batch, max_seq_len, tpp=20)
+    resources = _resource_presets(use_tpu=use_tpu)[size]
+
+    train = SimpleTrainConfig(
+        resources,
+        train_seq_len=max_seq_len,
+        train_batch_size=batch,
+        num_train_steps=steps,
+        learning_rate=3e-3,
+        weight_decay=0.1,
+        steps_per_eval=500,
+        steps_per_hf_export=-1,
+        explicit_mesh_axes=True,
+    )
+
+    run_name = f"grugformer_moe_{size}"
+    desc = f"Grugformer MoE experiment (Mixtral-style router/dispatch) ({size})."
+ cfg = SpeedrunConfig( + author=Author( + name="__YOUR_NAME__", + affiliation="__YOUR_AFFILIATION__", + url="__YOUR_URL__", + ), + description=desc, + model_config=model_cfg, + train_config=train, + ) + return run_name, cfg + + +def main() -> None: + sizes = ["130m", "300m", "520m"] + use_tpu = bool(int(os.environ.get("SR_USE_TPU", "0"))) + + steps = [] + for s in sizes: + name, cfg = build_run(s, use_tpu=use_tpu) + if cfg.vocab_size != llama3_tokenizer_vocab_size: + raise AssertionError("Speedrun vocab_size mismatch; expected llama3_tokenizer_vocab_size") + cfg.print_run_info() + steps.extend(default_speedrun(name, cfg)) + + executor_main(steps=steps, description="Grugformer MoE experiment (Mixtral-style router/dispatch).") + + +if __name__ == "__main__": + main() diff --git a/experiments/speedrun/grugformer_moe/grugformer_moe_nemotron_dclm_fineweb_10b.py b/experiments/speedrun/grugformer_moe/grugformer_moe_nemotron_dclm_fineweb_10b.py new file mode 100644 index 0000000000..558be3fa7a --- /dev/null +++ b/experiments/speedrun/grugformer_moe/grugformer_moe_nemotron_dclm_fineweb_10b.py @@ -0,0 +1,449 @@ +# Copyright 2025 The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +# nodryrun +"""Grugformer+MoE launcher with Nemotron+DCLM+FineWeb (10B) mixture and eval-harness. + +Intended usage: match the runtime knobs from `experiments/speedrun/olmoe_1b7b_nemotron_40b.py`, but +train the experiment-only Grugformer MoE model. + +Defaults (requested): +- Dataset: nemotron_dclm_fineweb_10b +- Shuffle: feistel +- TPU: v5p-32 +- Global batch size: 64 +- Seq len: 4096 +- Model shape: OLMoE 1B/7B geometry (D=2048, I=1024, L=16, heads=16, kv_heads=8, E=64, K=8) +- Eval-harness: core_plus_leaderboard, both during and post training + +Important: `--dataset-tokenizer` does not retokenize data; it only controls which tokenizer is loaded +for vocab size / special ids / eval decoding. It MUST match the tokenizer used to pretokenize the dataset. 
+""" + +from __future__ import annotations + +import argparse +import dataclasses +import math +import os +from datetime import timedelta + +from experiments.defaults import default_train +from experiments.evals.task_configs import CORE_TASKS, CORE_TASKS_PLUS_LEADERBOARD, CORE_TASKS_PLUS_MMLU +from experiments.speedrun.grugformer_moe.grugformer_moe import GrugformerMoeConfig +from experiments.speedrun.olmoe_1b7b_nemotron_40b import COMPOSITE_TOKEN_TARGET, DATASET_OPTIONS, DEFAULT_TOKEN_TARGET +from experiments.simple_train_config import SimpleTrainConfig +from fray.cluster import ResourceConfig +from levanter.data.text import LMMixtureDatasetConfig +from marin.execution.executor import ExecutorMainConfig, ExecutorStep, executor_main, output_path_of + +from experiments.speedrun.olmoe_1b7b_nemotron_40b import ( + LevanterEvalHarnessStepConfig, + run_levanter_checkpoint_eval_harness, +) + +LEARNING_RATE = 4e-4 +WEIGHT_DECAY = 0.1 +BETA1 = 0.9 +BETA2 = 0.95 +EPSILON = 1e-8 +MAX_GRAD_NORM = 1.0 +WARMUP_STEPS = 2000 +LR_SCHEDULE = "cosine" +MIN_LR_RATIO = 0.125 +Z_LOSS_WEIGHT = 1e-4 +STEPS_PER_EVAL = 5000 +STEPS_PER_EXPORT = 20_000 + +_EVAL_SUITES: dict[str, tuple] = { + "none": (), + "core": CORE_TASKS, + "core_plus_mmlu": CORE_TASKS_PLUS_MMLU, + "core_plus_leaderboard": CORE_TASKS_PLUS_LEADERBOARD, +} + +_DEFAULT_TPU_TYPE = "v5p-32" +_DEFAULT_SEQ_LEN = 4096 +_DEFAULT_GLOBAL_BATCH_SIZE = 64 +_DEFAULT_EVAL_SUITE = "core_plus_leaderboard" +_DEFAULT_EVAL_SUITE_MODE = "both" + +_SMOKE_TPU_TYPE = "v5p-16" +# Keep smoke tests small enough to avoid fused CE vmem regressions on small slices. 
+_SMOKE_SEQ_LEN = 1024 +_SMOKE_GLOBAL_BATCH_SIZE = 32 +_SMOKE_NUM_TRAIN_STEPS = 5 +_SMOKE_EVAL_SUITE = "none" +_SMOKE_EVAL_SUITE_MODE = "post_train" + + +def _steps_for_token_target(token_target: int, global_batch_size: int, seq_len: int) -> int: + return math.ceil(token_target / (global_batch_size * seq_len)) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--smoke", + action="store_true", + help=( + "Run a fast TPU sanity check (defaults: v5p-16, bs=32, seq=1024, steps=5, eval-suite=none). " + "If you also pass an explicit flag (e.g., --seq-len), that value is kept." + ), + ) + parser.add_argument( + "--dataset", + choices=DATASET_OPTIONS.keys(), + default="nemotron_dclm_fineweb_10b", + help="Which tokenized dataset preset to train on.", + ) + parser.add_argument("--tpu-type", default=_DEFAULT_TPU_TYPE) + parser.add_argument("--seq-len", type=int, default=_DEFAULT_SEQ_LEN) + parser.add_argument("--global-batch-size", type=int, default=_DEFAULT_GLOBAL_BATCH_SIZE) + parser.add_argument( + "--per-device-parallelism", + type=int, + default=-1, + help=( + "How many examples to process in parallel on each device. -1 (default) chooses a value based on " + "global batch size and device count. Set explicitly (e.g. 8) to reproduce high-MFU legacy-axis runs; " + "train_batch_size must be divisible by per_device_parallelism * data_axis_size." + ), + ) + parser.add_argument( + "--token-target", + type=int, + default=None, + help=( + "Total token budget used to compute default --num-train-steps when that flag is omitted. " + f"Defaults to {DEFAULT_TOKEN_TARGET} for single-corpus runs and {COMPOSITE_TOKEN_TARGET} for the composite " + "mixture." 
+ ), + ) + parser.add_argument( + "--num-train-steps", + type=int, + default=None, + help="Number of training steps to run (default: computed from --token-target, --global-batch-size, --seq-len).", + ) + parser.add_argument("--warmup-steps", type=int, default=WARMUP_STEPS) + parser.add_argument("--run-suffix", type=str, default=None) + parser.add_argument("--wandb-group", type=str, default=None) + parser.add_argument("--use-default-validation", action="store_true") + parser.add_argument( + "--eval-suite", + choices=tuple(_EVAL_SUITES.keys()), + default=_DEFAULT_EVAL_SUITE, + help="Eval-harness suite to run (during training, post-training, or both).", + ) + parser.add_argument( + "--eval-suite-mode", + choices=("post_train", "during_train", "both"), + default=_DEFAULT_EVAL_SUITE_MODE, + help="When to run eval-harness: post_train, during_train, or both.", + ) + parser.add_argument( + "--steps-per-task-eval", + type=int, + default=STEPS_PER_EVAL, + help="How often to run eval-harness tasks during training when eval-suite-mode includes during_train.", + ) + parser.add_argument( + "--permutation-type", + choices=("feistel", "linear"), + default="feistel", + help="Shuffle permutation type for mixture datasets.", + ) + parser.add_argument( + "--dataset-tokenizer", + type=str, + default="stanford-crfm/marin-tokenizer", + help=( + "Tokenizer spec for loading vocab size/special ids (does not retokenize the dataset). " + "Must match the tokenizer used when pretokenizing." + ), + ) + parser.add_argument("--single-checkpoint", action="store_true") + parser.add_argument("--checkpoint-save-minutes", type=int, default=60) + parser.add_argument( + "--cross-entropy-block-size", + type=int, + default=512, + help=( + "Vocab block size for the fused next-token loss. Smaller blocks reduce peak memory at the cost of " + "more blocks. Use a multiple of 128 when using the Pallas kernel." 
+ ), + ) + parser.add_argument( + "--cross-entropy-implementation", + choices=("auto", "xla", "pallas_tpu", "reference"), + default="auto", + help=( + "Cross-entropy backend. 'auto' tries Pallas on TPU v5+ and falls back to XLA when unsupported (e.g. TPU v4)." + ), + ) + parser.set_defaults(explicit_mesh_axes=True) + parser.add_argument( + "--explicit-mesh-axes", + dest="explicit_mesh_axes", + action="store_true", + help="Use explicit mesh axes in TrainerConfig (default).", + ) + parser.add_argument( + "--no-explicit-mesh-axes", + dest="explicit_mesh_axes", + action="store_false", + help="Disable explicit mesh axes in TrainerConfig.", + ) + + parser.set_defaults(legacy_axis_resources=False) + parser.set_defaults(use_gmm=True) + parser.add_argument( + "--use-gmm", + dest="use_gmm", + action="store_true", + help="Use Megablox GMM for expert matmuls (default).", + ) + parser.add_argument( + "--no-use-gmm", + dest="use_gmm", + action="store_false", + help="Use ragged_dot_general for expert matmuls (debug/ablation).", + ) + + parser.add_argument( + "--legacy-axis-resources", + dest="legacy_axis_resources", + action="store_true", + help=( + "Use a December-style axis mapping equivalent to axis_resources with " + "token/token_repeat/batch -> (replica, data) and embed -> data." 
+ ), + ) + + return parser.parse_args() + + +def _patch_trainer_sharding_ablations( + train_step: ExecutorStep, + *, + explicit_mesh_axes: bool, + legacy_axis_resources: bool, +) -> ExecutorStep: + config = train_step.config + inner = config.train_config + trainer = inner.trainer + mesh = trainer.mesh + + if legacy_axis_resources: + mesh = dataclasses.replace( + mesh, + compute_mapping={ + "batch": ("replica", "data"), + "token": ("replica", "data"), + "token_repeat": ("replica", "data"), + }, + param_mapping={"embed": "data"}, + ) + + trainer = dataclasses.replace(trainer, mesh=mesh, use_explicit_mesh_axes=explicit_mesh_axes) + inner = dataclasses.replace(inner, trainer=trainer) + config = dataclasses.replace(config, train_config=inner) + return dataclasses.replace(train_step, config=config) + + +def main() -> None: + args = _parse_args() + tpu_type = args.tpu_type + seq_len = args.seq_len + global_batch_size = args.global_batch_size + num_train_steps_override = args.num_train_steps + eval_suite = args.eval_suite + eval_suite_mode = args.eval_suite_mode + + if args.smoke: + if tpu_type == _DEFAULT_TPU_TYPE: + tpu_type = _SMOKE_TPU_TYPE + if seq_len == _DEFAULT_SEQ_LEN: + seq_len = _SMOKE_SEQ_LEN + if global_batch_size == _DEFAULT_GLOBAL_BATCH_SIZE: + global_batch_size = _SMOKE_GLOBAL_BATCH_SIZE + if eval_suite == _DEFAULT_EVAL_SUITE: + eval_suite = _SMOKE_EVAL_SUITE + if eval_suite_mode == _DEFAULT_EVAL_SUITE_MODE: + eval_suite_mode = _SMOKE_EVAL_SUITE_MODE + if num_train_steps_override is None: + num_train_steps_override = _SMOKE_NUM_TRAIN_STEPS + + if args.cross_entropy_implementation in ("auto", "pallas_tpu") and args.cross_entropy_block_size % 128 != 0: + raise ValueError( + "--cross-entropy-block-size must be a multiple of 128 when using the Pallas kernel " + f"(got {args.cross_entropy_block_size})." 
+ ) + + cross_entropy_implementation = ( + None if args.cross_entropy_implementation == "auto" else args.cross_entropy_implementation + ) + + token_target = args.token_target + if token_target is None: + token_target = COMPOSITE_TOKEN_TARGET if args.dataset == "nemotron_dclm_fineweb_10b" else DEFAULT_TOKEN_TARGET + + num_train_steps = ( + num_train_steps_override + if num_train_steps_override is not None + else _steps_for_token_target(token_target, global_batch_size, seq_len) + ) + + warmup_steps = max(0, int(args.warmup_steps)) + if LR_SCHEDULE == "cosine" and warmup_steps >= num_train_steps: + warmup_steps = max(0, num_train_steps - 1) + + if args.smoke: + # Keep smoke tests fast: the full OLMoE-1B7B geometry has very long compile times and can + # be sensitive to TPU slice health. This smaller shape still exercises the MoE routing + # + ragged expert MLP path. + model_cfg = GrugformerMoeConfig( + max_seq_len=seq_len, + hidden_dim=512, + intermediate_dim=1024, + num_layers=4, + num_heads=8, + num_kv_heads=8, + head_dim=None, + n_routed_experts=8, + num_experts_per_tok=2, + lbl_coef=None, + rzl_coef=None, + router_fp32=True, + router_topk_then_softmax=True, + use_gmm=bool(args.use_gmm), + cross_entropy_block_size=int(args.cross_entropy_block_size), + cross_entropy_implementation=cross_entropy_implementation, + ) + else: + model_cfg = GrugformerMoeConfig( + max_seq_len=seq_len, + hidden_dim=2048, + intermediate_dim=1024, + num_layers=16, + num_heads=16, + num_kv_heads=8, + head_dim=None, + n_routed_experts=64, + num_experts_per_tok=8, + use_gmm=bool(args.use_gmm), + # Avoid XLA allocation-size overflow on large (tokens x vocab) logits tiles. + # 32k blocks are fine for smaller local token counts, but can exceed XLA's 32-bit + # allocation checks at large global batch/seq settings. 
+ cross_entropy_block_size=int(args.cross_entropy_block_size), + cross_entropy_implementation=cross_entropy_implementation, + ) + + tokenized = DATASET_OPTIONS[args.dataset] + if not isinstance(tokenized, LMMixtureDatasetConfig): + raise ValueError( + f"--dataset {args.dataset} is not a mixture dataset; cannot set permutation_type={args.permutation_type}" + ) + tokenized = dataclasses.replace( + tokenized, + permutation_type=args.permutation_type, + tokenizer=args.dataset_tokenizer, + ) + + evals = _EVAL_SUITES[eval_suite] + eval_harness_tasks = () + if eval_suite_mode in ("during_train", "both"): + eval_harness_tasks = evals + + train = SimpleTrainConfig( + resources=ResourceConfig.with_tpu(tpu_type=tpu_type), + train_seq_len=seq_len, + train_batch_size=global_batch_size, + num_train_steps=num_train_steps, + learning_rate=LEARNING_RATE, + weight_decay=WEIGHT_DECAY, + beta1=BETA1, + beta2=BETA2, + epsilon=EPSILON, + max_grad_norm=MAX_GRAD_NORM, + warmup=warmup_steps, + lr_schedule=LR_SCHEDULE, + min_lr_ratio=MIN_LR_RATIO, + z_loss_weight=Z_LOSS_WEIGHT, + steps_per_eval=STEPS_PER_EVAL, + steps_per_export=STEPS_PER_EXPORT, + steps_per_task_eval=args.steps_per_task_eval, + steps_per_hf_export=-1, + per_device_parallelism=int(args.per_device_parallelism), + explicit_mesh_axes=bool(args.explicit_mesh_axes), + ) + + default_suffix = f"grugformer_moe_olmoe1b7b_{tpu_type}_bs{global_batch_size}_{args.dataset}_seq{seq_len}" + run_suffix = args.run_suffix or default_suffix + wandb_group = args.wandb_group if args.wandb_group is not None else os.environ.get("WANDB_GROUP") + + train_step = default_train( + name=f"speedrun/{run_suffix}", + tokenized=tokenized, + model_config=model_cfg, + train_config=train, + tags=[ + "speedrun", + "grugformer_moe", + "olmoe_1b7b", + tpu_type, + f"b{global_batch_size}", + f"s{seq_len}", + f"perm={args.permutation_type}", + f"pdp={int(args.per_device_parallelism)}", + f"explicit_mesh_axes={int(bool(args.explicit_mesh_axes))}", + 
f"legacy_axis_resources={int(bool(args.legacy_axis_resources))}", + ], + eval_harness_tasks=eval_harness_tasks, + wandb_name=run_suffix, + wandb_group=wandb_group, + use_default_validation=args.use_default_validation, + checkpointer_save_interval=timedelta(minutes=int(args.checkpoint_save_minutes)), + checkpointer_keep=[] if args.single_checkpoint else None, + ) + train_step = _patch_trainer_sharding_ablations( + train_step, + explicit_mesh_axes=bool(args.explicit_mesh_axes), + legacy_axis_resources=bool(args.legacy_axis_resources), + ) + + steps: list[ExecutorStep] = [train_step] + if eval_suite_mode in ("post_train", "both") and eval_suite != "none": + steps.append( + ExecutorStep( + name=f"evaluation/levanter_eval_harness/{run_suffix}/{eval_suite}", + fn=run_levanter_checkpoint_eval_harness, + config=LevanterEvalHarnessStepConfig( + model_name=f"{run_suffix}_{eval_suite}", + model_config=model_cfg, + tokenizer=args.dataset_tokenizer, + checkpoint_root=train_step / "checkpoints", + evals=evals, + max_eval_instances=None, + output_path=output_path_of(train_step, f"eval_harness/{eval_suite}"), + wandb_project=os.environ.get("WANDB_PROJECT") or "marin", + wandb_group=wandb_group, + ), + resources=ResourceConfig.with_tpu(tpu_type), + pip_dependency_groups=["tpu", "eval"], + ) + ) + + # `executor_main` is draccus-wrapped to provide a standardized CLI. This experiment uses argparse + # for its own runtime flags, so we call the undecorated function to avoid draccus attempting to + # parse the training arguments. 
+ executor_main.__wrapped__( + ExecutorMainConfig(prefix=os.environ.get("MARIN_PREFIX")), + steps=steps, + description="Grugformer MoE (OLMoE-1B7B shape) on Nemotron+DCLM+FineWeb (feistel + eval harness).", + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/speedrun/mixtral_vs_dense_nemotron_dclm_fineweb_edu_100b.py b/experiments/speedrun/mixtral_vs_dense_nemotron_dclm_fineweb_edu_100b.py new file mode 100644 index 0000000000..03d4562d52 --- /dev/null +++ b/experiments/speedrun/mixtral_vs_dense_nemotron_dclm_fineweb_edu_100b.py @@ -0,0 +1,836 @@ +# Copyright 2025 The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +# nodryrun +"""Launch a Mixtral-8x7B vs Llama-13B comparison run (100B tokens). + +This script trains both models from scratch on the composite Nemotron + DCLM + FineWeb-Edu mixture, +with: +- feistel shuffling (default), +- Levanter default validation enabled (Paloma + uncheatable), +- eval-harness "core" suite (during + post train by default), +- AdamW-style optimizer (Adam + weight decay), lr=1e-4 (default), +- microbatching via `trainer.per_device_parallelism` (default 1) to reduce peak memory. + +Intended usage is via: +`python -m marin.run.ray_run ... -- python -m experiments.speedrun.mixtral_vs_dense_nemotron_dclm_fineweb_edu_100b ...`. 
+""" + +from __future__ import annotations + +import argparse +import dataclasses +import json +import os +from datetime import timedelta + +import fsspec +import jmp +from experiments.defaults import default_train +from experiments.evals.task_configs import CORE_TASKS, CORE_TASKS_PLUS_LEADERBOARD, CORE_TASKS_PLUS_MMLU +from experiments.llama import llama_13b +from experiments.simple_train_config import SimpleTrainConfig +from experiments.speedrun.custom_mixtral import MixtralConfig +from experiments.speedrun.olmoe_1b7b_nemotron_40b import DATASET_OPTIONS +from fray.cluster import ResourceConfig +from levanter.checkpoint import discover_latest_checkpoint +from levanter.data.text import LMMixtureDatasetConfig +from levanter.distributed import RayConfig +from levanter.eval_harness import EvalHarnessMainConfig, LmEvalHarnessConfig, run_eval_harness_main +from levanter.optim import AdamConfig +from levanter.tracker.wandb import WandbConfig +from levanter.trainer import TrainerConfig +from marin.execution.executor import ExecutorMainConfig, ExecutorStep, executor_main, output_path_of + +from experiments.evals.task_configs import convert_to_levanter_task_config + +DEFAULT_TPU_TYPE = "v5p-32" +DEFAULT_SEQ_LEN = 4096 +DEFAULT_GLOBAL_BATCH_SIZE = 192 +DEFAULT_TOKEN_TARGET = 100_000_000_000 # 100B tokens + +DEFAULT_LEARNING_RATE = 1e-4 +DEFAULT_WEIGHT_DECAY = 0.1 +DEFAULT_BETA1 = 0.9 +DEFAULT_BETA2 = 0.95 +DEFAULT_EPSILON = 1e-8 +DEFAULT_MAX_GRAD_NORM = 1.0 +DEFAULT_WARMUP_STEPS = 2000 +LR_SCHEDULE = "cosine" +MIN_LR_RATIO = 0.125 +Z_LOSS_WEIGHT = 1e-4 +DEFAULT_MIXTRAL_LBL_COEF = 0.01 +DEFAULT_MIXTRAL_RZL_COEF = 0.001 +STEPS_PER_EVAL = 5000 +STEPS_PER_EXPORT = 20_000 + +DEFAULT_EVAL_SUITE = "core" +DEFAULT_EVAL_SUITE_MODE = "both" +DEFAULT_STEPS_PER_TASK_EVAL = 5000 + +MODEL_LLAMA_13B = "llama_13b" +MODEL_MIXTRAL_8X7B = "mixtral_8x7b" + +_FORWARDED_ENV_PREFIXES = ("JAX_", "LIBTPU_", "XLA_", "WANDB_", "HF_") +_FORWARDED_ENV_KEYS = ( + "PIP_IGNORE_INSTALLED", + 
"PIP_NO_CACHE_DIR", + "RAY_TMPDIR", + "TMPDIR", +) +_MAXTEXT_V5P_LIBTPU_INIT_ARGS = ( + "--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true " + "--xla_tpu_megacore_fusion_allow_ags=false " + "--xla_enable_async_collective_permute=true " + "--xla_tpu_enable_ag_backward_pipelining=true " + "--xla_tpu_enable_data_parallel_all_reduce_opt=true " + "--xla_tpu_data_parallel_opt_different_sized_ops=true " + "--xla_tpu_enable_async_collective_fusion=true " + "--xla_tpu_enable_async_collective_fusion_multiple_steps=true " + "--xla_tpu_overlap_compute_collective_tc=true " + "--xla_enable_async_all_gather=true" +) + +_EVAL_SUITES: dict[str, tuple] = { + "none": (), + "core": CORE_TASKS, + "core_plus_mmlu": CORE_TASKS_PLUS_MMLU, + "core_plus_leaderboard": CORE_TASKS_PLUS_LEADERBOARD, +} + + +@dataclasses.dataclass(frozen=True) +class LevanterEvalHarnessStepConfig: + """Config for running Levanter's eval-harness on a Levanter (non-HF) checkpoint.""" + + model_name: str + model_config: object + tokenizer: str + checkpoint_root: str + evals: tuple + max_eval_instances: int | None + output_path: str + wandb_project: str + apply_chat_template: bool = False + wandb_group: str | None = None + + +def run_levanter_checkpoint_eval_harness(config: LevanterEvalHarnessStepConfig) -> None: + checkpoint_path = discover_latest_checkpoint(config.checkpoint_root) + if checkpoint_path is None: + raise ValueError(f"No checkpoints found under {config.checkpoint_root}") + + trainer_config = TrainerConfig( + tracker=WandbConfig( + entity=os.environ.get("WANDB_ENTITY"), + project=config.wandb_project, + tags=["eval_harness"], + name=config.model_name, + group=config.wandb_group, + mode=os.environ.get("WANDB_MODE"), + ), + mp=jmp.get_policy("p=f32,c=bfloat16"), + per_device_eval_parallelism=1, + ray=RayConfig(auto_start_cluster=False), + ) + + eval_config = EvalHarnessMainConfig( + eval_harness=LmEvalHarnessConfig( + task_spec=convert_to_levanter_task_config(config.evals), + 
max_examples=config.max_eval_instances, + log_samples=False, + confirm_run_unsafe_code=True, + ), + tokenizer=config.tokenizer, + checkpoint_path=checkpoint_path, + checkpoint_is_hf=False, + apply_chat_template=config.apply_chat_template, + trainer=trainer_config, + model=config.model_config, # type: ignore[arg-type] + ) + + results = run_eval_harness_main(eval_config) + + fs = fsspec.filesystem("gcs") if config.output_path.startswith("gs://") else fsspec.filesystem("file") + output_path = config.output_path.rstrip("/") + "/results.json" + with fs.open(output_path, "w") as f: + json.dump(results, f, indent=2, default=str) + + +def _ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + +def _steps_for_token_target(token_target: int, global_batch_size: int, seq_len: int) -> int: + return _ceil_div(token_target, global_batch_size * seq_len) + + +def _build_mixtral_8x7b_config( + *, + seq_len: int, + max_seq_len: int, + use_gmm: bool, + use_qk_norm: bool, + router_topk_then_softmax: bool, + router_fp32: bool, + flash_attention_block_size: int | None, + cross_entropy_block_size: int | None, + cross_entropy_b_block_size: int | None, + cross_entropy_h_block_size: int | None, + cross_entropy_implementation: str | None, + lbl_coef: float | None, + rzl_coef: float | None, +) -> MixtralConfig: + return MixtralConfig( + seq_len=seq_len, + max_seq_len=max_seq_len, + hidden_dim=4096, + intermediate_dim=14336, + num_layers=32, + num_heads=32, + num_kv_heads=8, + n_routed_experts=8, + num_experts_per_tok=2, + layer_norm_epsilon=1e-5, + gradient_checkpointing=True, + scan_layers=True, + use_gmm=use_gmm, + use_qk_norm=use_qk_norm, + router_topk_then_softmax=router_topk_then_softmax, + router_fp32=router_fp32, + flash_attention_block_size=flash_attention_block_size, + lbl_coef=lbl_coef, + rzl_coef=rzl_coef, + cross_entropy_block_size=cross_entropy_block_size, + cross_entropy_b_block_size=cross_entropy_b_block_size, + cross_entropy_h_block_size=cross_entropy_h_block_size, + 
cross_entropy_implementation=cross_entropy_implementation, + ) + + +def _patch_per_device_parallelism( + train_step: ExecutorStep, + *, + per_device_parallelism: int | None, +) -> ExecutorStep: + if per_device_parallelism is None: + return train_step + if per_device_parallelism <= 0: + raise ValueError("--per-device-parallelism must be >= 1") + + cfg = train_step.config + inner = cfg.train_config + trainer = dataclasses.replace(inner.trainer, per_device_parallelism=per_device_parallelism) + inner = dataclasses.replace(inner, trainer=trainer) + cfg = dataclasses.replace(cfg, train_config=inner) + return dataclasses.replace(train_step, config=cfg) + + +def _collect_forwarded_runtime_env() -> dict[str, str]: + forwarded: dict[str, str] = {} + for key, value in os.environ.items(): + if key in _FORWARDED_ENV_KEYS or any(key.startswith(prefix) for prefix in _FORWARDED_ENV_PREFIXES): + if value: + forwarded[key] = value + return forwarded + + +def _patch_train_step_env_vars(train_step: ExecutorStep, *, env_vars: dict[str, str]) -> ExecutorStep: + if not env_vars: + return train_step + + config = train_step.config + patched_env = dict(getattr(config, "env_vars", None) or {}) + changed = False + for key, value in env_vars.items(): + if patched_env.get(key) != value: + patched_env[key] = value + changed = True + + if not changed: + return train_step + + return dataclasses.replace(train_step, config=dataclasses.replace(config, env_vars=patched_env)) + + +def _default_libtpu_init_args_for_tpu(tpu_type: str) -> str | None: + if tpu_type.startswith("v5p-"): + return _MAXTEXT_V5P_LIBTPU_INIT_ARGS + return None + + +def _parse_bool(value: str) -> bool: + lowered = value.strip().lower() + if lowered in ("1", "true", "t", "yes", "y", "on"): + return True + if lowered in ("0", "false", "f", "no", "n", "off"): + return False + raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}") + + +def _patch_trainer_profiler_perfetto_link( + train_step: ExecutorStep, + *, + 
profiler_perfetto_link: bool, +) -> ExecutorStep: + if not profiler_perfetto_link: + return train_step + config = train_step.config + inner = config.train_config + trainer = dataclasses.replace(inner.trainer, profiler_perfetto_link=True) + inner = dataclasses.replace(inner, trainer=trainer) + config = dataclasses.replace(config, train_config=inner) + return dataclasses.replace(train_step, config=config) + + +def _patch_trainer_sharding_ablations( + train_step: ExecutorStep, + *, + explicit_mesh_axes: bool, + legacy_axis_resources: bool, +) -> ExecutorStep: + config = train_step.config + inner = config.train_config + trainer = inner.trainer + mesh = trainer.mesh + + if legacy_axis_resources: + mesh = dataclasses.replace( + mesh, + compute_mapping={ + "batch": ("replica", "data"), + "token": ("replica", "data"), + "token_repeat": ("replica", "data"), + }, + param_mapping={"embed": "data"}, + ) + + trainer = dataclasses.replace(trainer, mesh=mesh, use_explicit_mesh_axes=explicit_mesh_axes) + inner = dataclasses.replace(inner, trainer=trainer) + config = dataclasses.replace(config, train_config=inner) + return dataclasses.replace(train_step, config=config) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--tpu-type", default=DEFAULT_TPU_TYPE) + parser.add_argument("--seq-len", type=int, default=DEFAULT_SEQ_LEN) + parser.add_argument( + "--train-seq-len", + type=int, + default=None, + help=( + "Training sequence length. Defaults to --seq-len. This is useful when using sliding-window attention " + "where the model attention window (--seq-len) can be smaller than the packed training sequences." + ), + ) + parser.add_argument("--global-batch-size", type=int, default=DEFAULT_GLOBAL_BATCH_SIZE) + parser.add_argument("--token-target", type=int, default=DEFAULT_TOKEN_TARGET) + parser.add_argument( + "--models", + choices=(MODEL_LLAMA_13B, MODEL_MIXTRAL_8X7B, "both"), + default="both", + help="Which model(s) to run. 
Use this to submit separate Ray jobs per model for stability.", + ) + parser.add_argument( + "--num-train-steps", + type=int, + default=None, + help="If omitted, computed from --token-target / (--global-batch-size * --seq-len).", + ) + parser.add_argument("--learning-rate", type=float, default=DEFAULT_LEARNING_RATE) + parser.add_argument("--weight-decay", type=float, default=DEFAULT_WEIGHT_DECAY) + parser.add_argument("--beta1", type=float, default=DEFAULT_BETA1) + parser.add_argument("--beta2", type=float, default=DEFAULT_BETA2) + parser.add_argument("--epsilon", type=float, default=DEFAULT_EPSILON) + parser.add_argument("--max-grad-norm", type=float, default=DEFAULT_MAX_GRAD_NORM) + parser.add_argument("--warmup-steps", type=int, default=DEFAULT_WARMUP_STEPS) + parser.add_argument( + "--optimizer", + choices=("adamw", "adamc"), + default="adamw", + help="Optimizer mode. `adamc` enables AdamC-corrected weight decay.", + ) + parser.add_argument( + "--per-device-parallelism", + type=int, + default=1, + help=( + "Microbatch size per device. Use 1 to reduce peak memory; Levanter will use gradient accumulation to " + "reach the requested global batch size." + ), + ) + parser.set_defaults(explicit_mesh_axes=False) + parser.add_argument( + "--explicit-mesh-axes", + dest="explicit_mesh_axes", + action="store_true", + help="Enable explicit mesh axes in TrainerConfig.", + ) + parser.add_argument( + "--no-explicit-mesh-axes", + dest="explicit_mesh_axes", + action="store_false", + help="Disable explicit mesh axes in TrainerConfig (default).", + ) + parser.set_defaults(legacy_axis_resources=True) + parser.add_argument( + "--legacy-axis-resources", + dest="legacy_axis_resources", + action="store_true", + help=( + "Use a December-style axis mapping equivalent to axis_resources with " + "token/token_repeat/batch -> (replica, data) and embed -> data." 
+ ), + ) + parser.add_argument( + "--no-legacy-axis-resources", + dest="legacy_axis_resources", + action="store_false", + help="Use the current mesh compute mapping path (default uses legacy mapping).", + ) + + parser.add_argument( + "--permutation-type", + choices=("feistel", "linear"), + default="feistel", + help="Shuffle permutation type for the mixture dataset.", + ) + parser.add_argument( + "--dataset", + choices=tuple(DATASET_OPTIONS.keys()), + default="nemotron_dclm_fineweb_10b", + help="Dataset preset from DATASET_OPTIONS (e.g. nemotron_cc).", + ) + parser.add_argument( + "--dataset-tokenizer", + type=str, + default="stanford-crfm/marin-tokenizer", + help=( + "Tokenizer spec for vocab size / special ids / eval decoding (does not retokenize). Must match the " + "tokenizer used when building the tokenized dataset." + ), + ) + + parser.add_argument("--wandb-project", type=str, default="mixtral_vs_dense") + parser.add_argument("--wandb-group", type=str, default=None) + parser.add_argument("--run-suffix", type=str, default=None) + parser.add_argument("--extra-tag", action="append", default=[], help="Additional W&B tag (repeatable).") + + parser.add_argument( + "--disable-default-validation", + action="store_true", + help="Disable default Levanter validation losses (Paloma + uncheatable).", + ) + parser.add_argument( + "--steps-per-eval", + type=int, + default=STEPS_PER_EVAL, + help="How often (in steps) to run default Levanter validation losses when enabled.", + ) + parser.add_argument( + "--eval-suite", + choices=tuple(_EVAL_SUITES.keys()), + default=DEFAULT_EVAL_SUITE, + help="Eval-harness suite to run (during training, post-training, or both).", + ) + parser.add_argument( + "--eval-suite-mode", + choices=("post_train", "during_train", "both"), + default=DEFAULT_EVAL_SUITE_MODE, + help="When to run eval-harness: post_train, during_train, or both.", + ) + parser.add_argument( + "--steps-per-task-eval", + type=int, + default=DEFAULT_STEPS_PER_TASK_EVAL, + 
help="How often to run eval-harness tasks during training when eval-suite-mode includes during_train.", + ) + + parser.add_argument( + "--trainer.profiler", + dest="trainer_profiler", + type=_parse_bool, + default=False, + help="Enable the JAX profiler (writes traces under ./logs//profiler and uploads to W&B).", + ) + parser.add_argument( + "--trainer.profiler_start_step", + dest="trainer_profiler_start_step", + type=int, + default=5, + help="Step to start profiling (Levanter TrainerConfig.profiler_start_step).", + ) + parser.add_argument( + "--trainer.profiler_num_steps", + dest="trainer_profiler_num_steps", + type=int, + default=100, + help="Number of steps to capture once profiling starts (Levanter TrainerConfig.profiler_num_steps).", + ) + parser.add_argument( + "--trainer.profiler_perfetto_link", + dest="trainer_profiler_perfetto_link", + type=_parse_bool, + default=False, + help="Generate a Perfetto link when stopping the profiler (see lib/levanter/docs/Performance-Guide.md).", + ) + + parser.set_defaults(mixtral_use_gmm=True) + parser.add_argument( + "--mixtral-use-gmm", + dest="mixtral_use_gmm", + action="store_true", + help="Use Megablox/GMM MoE kernels for Mixtral (default).", + ) + parser.add_argument( + "--mixtral-no-gmm", + dest="mixtral_use_gmm", + action="store_false", + help="Disable Megablox/GMM MoE kernels for Mixtral.", + ) + parser.add_argument( + "--mixtral-flash-attention-block-size", + type=int, + default=None, + help="Flash-attention block size for Mixtral. Set to <=0 to use auto block sizing.", + ) + parser.add_argument( + "--mixtral-cross-entropy-block-size", + type=int, + default=1024, + help=( + "Vocab block size for Mixtral fused next-token loss (default: 1024). " "Set <=0 to disable fused block loss." 
+ ), + ) + parser.add_argument( + "--mixtral-cross-entropy-b-block-size", + type=int, + default=None, + help="Batch tile size for Mixtral fused CE Pallas kernel (multiple of 128; TPU v5p typically needs >=1024).", + ) + parser.add_argument( + "--mixtral-cross-entropy-h-block-size", + type=int, + default=None, + help="Hidden tile size for Mixtral fused CE Pallas kernel (multiple of 128).", + ) + parser.add_argument( + "--mixtral-cross-entropy-implementation", + choices=("auto", "legacy", "xla", "pallas_tpu", "reference"), + default="legacy", + help=( + "Backend for Mixtral next-token loss. `legacy` uses the December-era blockwise CE (custom_vjp); " + "`auto` tries Pallas first." + ), + ) + parser.add_argument( + "--mixtral-use-qk-norm", + action="store_true", + help="Enable Mixtral QK normalization.", + ) + parser.add_argument( + "--mixtral-router-topk-then-softmax", + action="store_true", + help="Enable top-k-then-softmax routing in Mixtral.", + ) + parser.add_argument( + "--mixtral-router-fp32", + action="store_true", + help="Compute Mixtral router/gating math in fp32.", + ) + parser.add_argument( + "--mixtral-lbl-coef", + type=float, + default=DEFAULT_MIXTRAL_LBL_COEF, + help="Mixtral router auxiliary load-balancing loss coefficient.", + ) + parser.add_argument( + "--mixtral-rzl-coef", + type=float, + default=DEFAULT_MIXTRAL_RZL_COEF, + help="Mixtral router z-loss coefficient.", + ) + parser.set_defaults(use_maxtext_libtpu_flags=False) + parser.add_argument( + "--use-maxtext-libtpu-flags", + dest="use_maxtext_libtpu_flags", + action="store_true", + help="Use MaxText-style LIBTPU_INIT_ARGS on v5p when not explicitly set.", + ) + parser.add_argument( + "--no-maxtext-libtpu-flags", + dest="use_maxtext_libtpu_flags", + action="store_false", + help="Do not auto-set MaxText-style LIBTPU_INIT_ARGS.", + ) + parser.add_argument( + "--libtpu-init-args", + type=str, + default=None, + help="Explicit LIBTPU_INIT_ARGS override for train steps.", + ) + + 
parser.add_argument("--single-checkpoint", action="store_true") + parser.add_argument("--checkpoint-save-minutes", type=int, default=60) + parser.add_argument( + "--max-concurrent", + type=int, + default=None, + help="Maximum number of experiment steps to run concurrently. Use 1 for sequential smoke runs.", + ) + args = parser.parse_args() + forwarded_env = _collect_forwarded_runtime_env() + if args.libtpu_init_args: + forwarded_env["LIBTPU_INIT_ARGS"] = args.libtpu_init_args.strip() + elif args.use_maxtext_libtpu_flags and "LIBTPU_INIT_ARGS" not in forwarded_env: + default_libtpu_init_args = _default_libtpu_init_args_for_tpu(args.tpu_type) + if default_libtpu_init_args is not None: + forwarded_env["LIBTPU_INIT_ARGS"] = default_libtpu_init_args + + use_default_validation = not args.disable_default_validation + + train_seq_len = int(args.train_seq_len) if args.train_seq_len is not None else int(args.seq_len) + + num_train_steps = ( + int(args.num_train_steps) + if args.num_train_steps is not None + else _steps_for_token_target(args.token_target, args.global_batch_size, train_seq_len) + ) + + warmup_steps = max(0, int(args.warmup_steps)) + if LR_SCHEDULE == "cosine" and warmup_steps >= num_train_steps: + warmup_steps = max(0, num_train_steps - 1) + + tokenized = DATASET_OPTIONS[args.dataset] + if not isinstance(tokenized, LMMixtureDatasetConfig): + raise ValueError(f"Expected {args.dataset} to be a mixture dataset config") + tokenized = dataclasses.replace( + tokenized, + permutation_type=args.permutation_type, + tokenizer=args.dataset_tokenizer, + ) + + evals = _EVAL_SUITES[args.eval_suite] + eval_harness_tasks = () + if args.eval_suite_mode in ("during_train", "both"): + eval_harness_tasks = evals + + optimizer_cfg = AdamConfig( + learning_rate=float(args.learning_rate), + weight_decay=float(args.weight_decay), + beta1=float(args.beta1), + beta2=float(args.beta2), + epsilon=float(args.epsilon), + max_grad_norm=float(args.max_grad_norm), + 
warmup=float(warmup_steps), + lr_schedule=LR_SCHEDULE, + min_lr_ratio=float(MIN_LR_RATIO), + adamc_weight_decay=bool(args.optimizer == "adamc"), + ) + + base_train_config = SimpleTrainConfig( + resources=ResourceConfig.with_tpu(tpu_type=args.tpu_type), + train_seq_len=train_seq_len, + train_batch_size=args.global_batch_size, + num_train_steps=num_train_steps, + learning_rate=float(args.learning_rate), + optimizer_config=optimizer_cfg, + z_loss_weight=Z_LOSS_WEIGHT, + steps_per_eval=int(args.steps_per_eval), + steps_per_export=STEPS_PER_EXPORT, + steps_per_task_eval=int(args.steps_per_task_eval), + steps_per_hf_export=-1, + explicit_mesh_axes=bool(args.explicit_mesh_axes), + profiler=bool(args.trainer_profiler), + profiler_start_step=int(args.trainer_profiler_start_step), + profiler_num_steps=int(args.trainer_profiler_num_steps), + ) + + run_suffix = args.run_suffix + if not run_suffix: + raise ValueError( + "--run-suffix is required to ensure a fresh output path (avoids accidentally resuming prior runs)." 
+ ) + + wandb_group = args.wandb_group if args.wandb_group is not None else os.environ.get("WANDB_GROUP") + + def _make_tags(*, model_name: str) -> list[str]: + return [ + "exp=mixtral_vs_dense", + f"data={args.dataset}", + f"model={model_name}", + f"token_target={args.token_target}", + f"perm={args.permutation_type}", + f"seq={args.seq_len}", + f"bs={args.global_batch_size}", + f"pdp={args.per_device_parallelism}", + f"explicit_mesh_axes={int(bool(args.explicit_mesh_axes))}", + f"legacy_axis_resources={int(bool(args.legacy_axis_resources))}", + f"eval_suite={args.eval_suite}", + f"eval_mode={args.eval_suite_mode}", + f"mixtral_use_gmm={int(args.mixtral_use_gmm)}", + f"mixtral_ce_impl={args.mixtral_cross_entropy_implementation}", + f"mixtral_ce_block={args.mixtral_cross_entropy_block_size}", + f"mixtral_ce_b_block={args.mixtral_cross_entropy_b_block_size}", + f"mixtral_ce_h_block={args.mixtral_cross_entropy_h_block_size}", + f"mixtral_qk_norm={int(args.mixtral_use_qk_norm)}", + f"mixtral_router_topk_then_softmax={int(args.mixtral_router_topk_then_softmax)}", + f"mixtral_router_fp32={int(args.mixtral_router_fp32)}", + f"mixtral_lbl_coef={args.mixtral_lbl_coef:.3g}", + f"mixtral_rzl_coef={args.mixtral_rzl_coef:.3g}", + ( + f"opt=adamc_b{args.beta1:.2f}_{args.beta2:.2f}" + if args.optimizer == "adamc" + else f"opt=adamw_b{args.beta1:.2f}_{args.beta2:.2f}" + ), + f"lr={args.learning_rate:.2e}", + *list(args.extra_tag), + ] + + llama_cfg = dataclasses.replace(llama_13b, max_seq_len=args.seq_len) + mixtral_ce_impl = ( + None if args.mixtral_cross_entropy_implementation == "auto" else args.mixtral_cross_entropy_implementation + ) + mixtral_ce_block_size = ( + int(args.mixtral_cross_entropy_block_size) + if args.mixtral_cross_entropy_block_size is not None and int(args.mixtral_cross_entropy_block_size) > 0 + else None + ) + if mixtral_ce_block_size is None and ( + args.mixtral_cross_entropy_b_block_size is not None or args.mixtral_cross_entropy_h_block_size is not None + 
): + raise ValueError( + "--mixtral-cross-entropy-b-block-size/--mixtral-cross-entropy-h-block-size require " + "--mixtral-cross-entropy-block-size > 0." + ) + mixtral_cfg = _build_mixtral_8x7b_config( + seq_len=args.seq_len, + max_seq_len=train_seq_len, + use_gmm=bool(args.mixtral_use_gmm), + use_qk_norm=bool(args.mixtral_use_qk_norm), + router_topk_then_softmax=bool(args.mixtral_router_topk_then_softmax), + router_fp32=bool(args.mixtral_router_fp32), + flash_attention_block_size=( + int(args.mixtral_flash_attention_block_size) + if args.mixtral_flash_attention_block_size is not None and int(args.mixtral_flash_attention_block_size) > 0 + else None + ), + cross_entropy_block_size=mixtral_ce_block_size, + cross_entropy_b_block_size=( + int(args.mixtral_cross_entropy_b_block_size) + if args.mixtral_cross_entropy_b_block_size is not None and int(args.mixtral_cross_entropy_b_block_size) > 0 + else None + ), + cross_entropy_h_block_size=( + int(args.mixtral_cross_entropy_h_block_size) + if args.mixtral_cross_entropy_h_block_size is not None and int(args.mixtral_cross_entropy_h_block_size) > 0 + else None + ), + cross_entropy_implementation=mixtral_ce_impl, + lbl_coef=float(args.mixtral_lbl_coef) if args.mixtral_lbl_coef > 0 else None, + rzl_coef=float(args.mixtral_rzl_coef) if args.mixtral_rzl_coef > 0 else None, + ) + + llama_name = f"{MODEL_LLAMA_13B}_{run_suffix}" + mixtral_name = f"{MODEL_MIXTRAL_8X7B}_{run_suffix}" + + selected = [] + if args.models in ("both", MODEL_LLAMA_13B): + llama_train_step = default_train( + name=f"mixtral_vs_dense/{MODEL_LLAMA_13B}/{run_suffix}", + tokenized=tokenized, + model_config=llama_cfg, + train_config=base_train_config, + tags=_make_tags(model_name=MODEL_LLAMA_13B), + eval_harness_tasks=eval_harness_tasks, + wandb_name=llama_name, + wandb_group=wandb_group, + wandb_project=args.wandb_project, + use_default_validation=use_default_validation, + checkpointer_save_interval=timedelta(minutes=int(args.checkpoint_save_minutes)), + 
checkpointer_keep=[] if args.single_checkpoint else None, + ) + llama_train_step = _patch_per_device_parallelism( + llama_train_step, + per_device_parallelism=args.per_device_parallelism, + ) + llama_train_step = _patch_train_step_env_vars(llama_train_step, env_vars=forwarded_env) + llama_train_step = _patch_trainer_profiler_perfetto_link( + llama_train_step, + profiler_perfetto_link=bool(args.trainer_profiler_perfetto_link), + ) + llama_train_step = _patch_trainer_sharding_ablations( + llama_train_step, + explicit_mesh_axes=bool(args.explicit_mesh_axes), + legacy_axis_resources=bool(args.legacy_axis_resources), + ) + selected.append((MODEL_LLAMA_13B, llama_cfg, llama_train_step)) + + if args.models in ("both", MODEL_MIXTRAL_8X7B): + mixtral_train_step = default_train( + name=f"mixtral_vs_dense/{MODEL_MIXTRAL_8X7B}/{run_suffix}", + tokenized=tokenized, + model_config=mixtral_cfg, + train_config=base_train_config, + tags=_make_tags(model_name=MODEL_MIXTRAL_8X7B), + eval_harness_tasks=eval_harness_tasks, + wandb_name=mixtral_name, + wandb_group=wandb_group, + wandb_project=args.wandb_project, + use_default_validation=use_default_validation, + checkpointer_save_interval=timedelta(minutes=int(args.checkpoint_save_minutes)), + checkpointer_keep=[] if args.single_checkpoint else None, + ) + mixtral_train_step = _patch_per_device_parallelism( + mixtral_train_step, + per_device_parallelism=args.per_device_parallelism, + ) + mixtral_train_step = _patch_train_step_env_vars(mixtral_train_step, env_vars=forwarded_env) + mixtral_train_step = _patch_trainer_profiler_perfetto_link( + mixtral_train_step, + profiler_perfetto_link=bool(args.trainer_profiler_perfetto_link), + ) + mixtral_train_step = _patch_trainer_sharding_ablations( + mixtral_train_step, + explicit_mesh_axes=bool(args.explicit_mesh_axes), + legacy_axis_resources=bool(args.legacy_axis_resources), + ) + selected.append((MODEL_MIXTRAL_8X7B, mixtral_cfg, mixtral_train_step)) + + steps: list[ExecutorStep] = [train_step 
for _, _, train_step in selected] + if args.eval_suite_mode in ("post_train", "both") and args.eval_suite != "none": + for model_name, model_cfg, train_step in selected: + steps.append( + ExecutorStep( + name=f"evaluation/levanter_eval_harness/{model_name}/{run_suffix}/{args.eval_suite}", + fn=run_levanter_checkpoint_eval_harness, + config=LevanterEvalHarnessStepConfig( + model_name=f"{model_name}_{run_suffix}_{args.eval_suite}", + model_config=model_cfg, + tokenizer=args.dataset_tokenizer, + checkpoint_root=train_step / "checkpoints", + evals=evals, + max_eval_instances=None, + output_path=output_path_of(train_step, f"eval_harness/{args.eval_suite}"), + wandb_project=args.wandb_project, + wandb_group=wandb_group, + ), + resources=ResourceConfig.with_tpu(args.tpu_type), + pip_dependency_groups=["tpu", "eval"], + ) + ) + + executor_main.__wrapped__( + ExecutorMainConfig(prefix=os.environ.get("MARIN_PREFIX"), max_concurrent=args.max_concurrent), + steps=steps, + description=( + "Mixtral 8x7B vs Llama 13B " + f"(dataset={args.dataset}, perm={args.permutation_type}, eval_suite={args.eval_suite}, " + f"eval_mode={args.eval_suite_mode}, token_target={args.token_target})." 
+ ), + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/speedrun/olmoe_1b7b_nemotron_40b.py b/experiments/speedrun/olmoe_1b7b_nemotron_40b.py index ebace8695a..3f705a7cf8 100644 --- a/experiments/speedrun/olmoe_1b7b_nemotron_40b.py +++ b/experiments/speedrun/olmoe_1b7b_nemotron_40b.py @@ -8,17 +8,32 @@ import math import os import sys +import json + +import fsspec +import jmp from collections.abc import Sequence from experiments.defaults import default_train +from experiments.evals.task_configs import convert_to_levanter_task_config from experiments.pretraining_datasets import NEMOTRON_WEIGHTS, tokenize_nemotron -from experiments.pretraining_datasets.dclm import dclm_mixture_config_llama3 +from experiments.pretraining_datasets.dclm import ( + DCLM_MIXTURE_WEIGHTS, + dclm_components_llama3, + dclm_mixture_config_llama3, +) from experiments.llama import llama3_tokenizer from experiments.speedrun.custom_mixtral import MixtralConfig +from experiments.speedrun.prebuilt_caches import fineweb_edu_subcache_10B from experiments.simple_train_config import SimpleTrainConfig from fray.cluster import ResourceConfig from levanter.callbacks.profiler import ProfilerConfig from levanter.infra.cli_helpers import load_config +from levanter.checkpoint import discover_latest_checkpoint +from levanter.distributed import RayConfig +from levanter.eval_harness import EvalHarnessMainConfig, LmEvalHarnessConfig, run_eval_harness_main +from levanter.tracker.wandb import WandbConfig +from levanter.trainer import TrainerConfig from marin.execution.executor import ExecutorStep, InputName, executor_main, output_path_of from marin.processing.tokenize import lm_data_config, lm_mixture_data_config from marin.speedrun.speedrun import Author, SpeedrunConfig, SpeedrunResultsConfig, speedrun_results @@ -26,12 +41,78 @@ logger = logging.getLogger("ray") +# --------------------------------------------------------------------------- +# Levanter eval-harness helper (shared by multiple 
launchers). +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True) +class LevanterEvalHarnessStepConfig: + """Config for running Levanter's eval-harness on a Levanter (non-HF) checkpoint.""" + + model_name: str + model_config: object + tokenizer: str + checkpoint_root: str + evals: tuple + max_eval_instances: int | None + output_path: str + wandb_project: str + apply_chat_template: bool = False + wandb_group: str | None = None + + +def run_levanter_checkpoint_eval_harness(config: LevanterEvalHarnessStepConfig) -> None: + checkpoint_path = discover_latest_checkpoint(config.checkpoint_root) + if checkpoint_path is None: + raise ValueError(f"No checkpoints found under {config.checkpoint_root}") + + trainer_config = TrainerConfig( + tracker=WandbConfig( + entity=os.environ.get("WANDB_ENTITY"), + project=config.wandb_project, + tags=["eval_harness"], + name=config.model_name, + group=config.wandb_group, + mode=os.environ.get("WANDB_MODE"), + ), + mp=jmp.get_policy("p=f32,c=bfloat16"), + per_device_eval_parallelism=1, + ray=RayConfig(auto_start_cluster=False), + ) + + eval_config = EvalHarnessMainConfig( + eval_harness=LmEvalHarnessConfig( + task_spec=convert_to_levanter_task_config(config.evals), + max_examples=config.max_eval_instances, + log_samples=False, + confirm_run_unsafe_code=True, + ), + tokenizer=config.tokenizer, + checkpoint_path=checkpoint_path, + checkpoint_is_hf=False, + apply_chat_template=config.apply_chat_template, + trainer=trainer_config, + model=config.model_config, # type: ignore[arg-type] + ) + + results = run_eval_harness_main(eval_config) + + fs = fsspec.filesystem("gcs") if config.output_path.startswith("gs://") else fsspec.filesystem("file") + output_path = config.output_path.rstrip("/") + "/results.json" + with fs.open(output_path, "w") as f: + json.dump(results, f, indent=2, default=str) + + # --------------------------------------------------------------------------- # Shared 
experiment knobs (mirrors the dense baseline for flop matching) # --------------------------------------------------------------------------- SEQ_LEN = 2048 DEFAULT_GLOBAL_BATCH_SIZE = 64 TOKEN_TARGET = 40_000_000_000 # 40B tokens +DEFAULT_TOKEN_TARGET = TOKEN_TARGET +# Composite Nemotron+DCLM+FineWeb mixture budget (used by other launchers). +COMPOSITE_TOKEN_TARGET = 100_000_000_000 # 100B tokens NUM_TRAIN_STEPS = math.ceil(TOKEN_TARGET / (DEFAULT_GLOBAL_BATCH_SIZE * SEQ_LEN)) LEARNING_RATE = 1e-4 WEIGHT_DECAY = 0.1 @@ -105,9 +186,28 @@ def build_model_config(*, model: str, seq_len: int) -> MixtralConfig: weights=NEMOTRON_WEIGHTS, ) + +# Composite mixture: Nemotron CC + DCLM + FineWeb-Edu 10B subcache. +# Weights are expressed in approximate corpus TiB sizes (mirrors existing per-corpus weights). +nemotron_dclm_fineweb_10b_components = { + **nemotron_cc_steps, + **dclm_components_llama3, + "fineweb_edu/10b": fineweb_edu_subcache_10B, +} +nemotron_dclm_fineweb_10b_weights = { + **NEMOTRON_WEIGHTS, + **DCLM_MIXTURE_WEIGHTS, + "fineweb_edu/10b": 0.01, +} +nemotron_dclm_fineweb_10b_mixture = lm_mixture_data_config( + components=nemotron_dclm_fineweb_10b_components, + weights=nemotron_dclm_fineweb_10b_weights, +) + DATASET_OPTIONS = { "nemotron_cc": nemotron_cc_mixture, "dclm": dclm_mixture_config_llama3, + "nemotron_dclm_fineweb_10b": nemotron_dclm_fineweb_10b_mixture, } DEFAULT_DATASET = "nemotron_cc" diff --git a/experiments/speedrun/olmoe_m_nemotron_dclm_fineweb_40b_lr_sweep.py b/experiments/speedrun/olmoe_m_nemotron_dclm_fineweb_40b_lr_sweep.py new file mode 100644 index 0000000000..59ff8e38cf --- /dev/null +++ b/experiments/speedrun/olmoe_m_nemotron_dclm_fineweb_40b_lr_sweep.py @@ -0,0 +1,409 @@ +# Copyright 2025 The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +""" +OLMoE-M AdamC sweep on configurable Nemotron-family datasets (fixed token budget). 
+ +This mirrors `experiments/speedrun/olmoe_s_dclm10b_moe_vs_dense_lr_sweep.py`, but: +- uses tokenized datasets from `olmoe_1b7b_nemotron_40b.py`: + - `nemotron_cc` (full Nemotron-CC) + - `nemotron_dclm_fineweb_10b` (composite Nemotron+DCLM+FineWeb mixture) +- runs OLMoE-M geometry (16 experts, top-2 routing) +- compares 5 MoE variants across 4 learning rates: + 1) `olmoe_m` (vanilla) + 2) `olmoe_m_bilinear` (bilinear expert MLPs; SwiGLU -> (W1 x) * (W3 x)) + 3) `olmoe_m_stab2` (two stability measures): + - auxiliary-free load balancing (ALF-LB) + - fp32 router compute + 4) `olmoe_m_stab3` (three stability measures): + - QK-norm + - topk-then-softmax routing + - fp32 router compute + 5) `olmoe_m_stab5` (five stability measures, including fp32 router compute): + - QK-norm + - topk-then-softmax routing + - auxiliary-free load balancing (ALF-LB) + - dense routing for first 2 blocks + - fp32 router compute + +W&B: +- choose a project with `--wandb-project` (for example `olmoe_m` or `olmoe_m_nemotron`). +- default learning-rate sweep is `[8e-4, 1e-3, 2e-3, 4e-3]`. 
+""" + +# nodryrun + +from __future__ import annotations + +import argparse +import dataclasses +import logging +import os +from datetime import timedelta + +from experiments.defaults import default_train +from experiments.simple_train_config import SimpleTrainConfig +from experiments.speedrun.custom_mixtral import MixtralConfig +from experiments.speedrun.olmoe_1b7b_nemotron_40b import DATASET_OPTIONS +from fray.cluster import ResourceConfig +from levanter.data.text import LMMixtureDatasetConfig +from levanter.optim import AdamConfig +from marin.execution.executor import ExecutorMainConfig, ExecutorStep, executor_main + +logger = logging.getLogger("ray") + +OLMOE_1B7B_REFERENCE_CHECKPOINT = "allenai/OLMoE-1B-7B-0125" + + +def _ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + +def _format_lr_tag(lr: float) -> str: + s = f"{lr:.2e}" + mantissa, exp = s.split("e", 1) + mantissa = mantissa.replace(".", "p") + exp_i = int(exp) + exp_tag = f"em{abs(exp_i):02d}" if exp_i < 0 else f"e{exp_i:02d}" + return f"{mantissa}{exp_tag}" + + +def _identity_activation(x): + return x + + +def _build_olmoe_m_config(seq_len: int, *, cross_entropy_block_size: int) -> MixtralConfig: + # Keep expert granularity fixed: topk/n_experts = 2/16 = 1/8. + return MixtralConfig( + seq_len=seq_len, + hidden_dim=1024, + intermediate_dim=512, + num_layers=12, + num_heads=8, + num_kv_heads=4, + n_routed_experts=16, + num_experts_per_tok=2, + layer_norm_epsilon=1e-5, + gradient_checkpointing=True, + scan_layers=True, + use_gmm=True, + # Keep the CE vocab block size modest: very large blocks can trigger TPU/XLA allocation-size overflows for + # (tokens x vocab_block) intermediates at long seq / large batch settings. 
+ cross_entropy_block_size=cross_entropy_block_size, + cross_entropy_implementation="xla", + flash_attention_block_size=None, + reference_checkpoint=OLMOE_1B7B_REFERENCE_CHECKPOINT, + tokenizer=OLMOE_1B7B_REFERENCE_CHECKPOINT, + ) + + +def _dataset_tag(dataset_name: str) -> str: + if dataset_name == "nemotron_dclm_fineweb_10b": + return "nemotron_dclm_fineweb" + return dataset_name + + +def _experiment_tag(dataset_name: str) -> str: + if dataset_name == "nemotron_cc": + return "exp=olmoe_m_nemotron_lr_sweep" + return "exp=olmoe_m_lr_sweep" + + +def _make_tags( + *, + variant: str, + lr: float, + seq_len: int, + global_batch_size: int, + token_target: int, + permutation_type: str, + use_qk_norm: bool, + router_topk_then_softmax: bool, + alf_lb_loss_scale: float, + dense_first_n_layers: int, + router_fp32: bool, + dataset_name: str, + extra_tags: list[str], +) -> list[str]: + return [ + _experiment_tag(dataset_name), + f"data={_dataset_tag(dataset_name)}", + f"token_target={token_target}", + f"perm={permutation_type}", + f"seq={seq_len}", + f"bs={global_batch_size}", + "opt=adamc_b0.9_0.95", + f"lr={lr:.2e}", + f"variant={variant}", + f"stab_qk_norm={int(use_qk_norm)}", + f"stab_topk_then_softmax={int(router_topk_then_softmax)}", + f"stab_alf_lb={int(alf_lb_loss_scale > 0)}", + f"stab_dense_first2={int(dense_first_n_layers >= 2)}", + f"stab_router_fp32={int(router_fp32)}", + *extra_tags, + ] + + +def main() -> None: + parser = argparse.ArgumentParser( + description=( + "OLMoE-M LR sweep on configurable Nemotron-family datasets " + "(token budget; configurable variants x 4 learning rates)." + ) + ) + parser.add_argument("--tpu-type", default="v5p-16") + parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--global-batch-size", type=int, default=256) + parser.add_argument("--token-target", type=int, default=40_000_000_000) + parser.add_argument( + "--cross-entropy-block-size", + type=int, + default=2048, + help="Vocab block size for fused CE. 
Smaller values reduce TPU/XLA allocation pressure.", + ) + + parser.add_argument( + "--learning-rates", + type=float, + nargs="+", + default=[8e-4, 1e-3, 2e-3, 4e-3], + help="Explicit learning-rate sweep values.", + ) + parser.add_argument("--weight-decay", type=float, default=0.1) + parser.add_argument("--beta1", type=float, default=0.9) + parser.add_argument("--beta2", type=float, default=0.95) + parser.add_argument("--epsilon", type=float, default=1e-8) + parser.add_argument("--max-grad-norm", type=float, default=1.0) + parser.add_argument("--warmup-steps", type=float, default=2000) + parser.add_argument("--min-lr-ratio", type=float, default=0.125) + parser.add_argument( + "--steps-per-eval", + type=int, + default=5000, + help="How often to run Levanter-native validation losses during training.", + ) + parser.add_argument( + "--disable-default-validation", + action="store_true", + help="Disable default Levanter validation losses (Paloma + uncheatable).", + ) + + parser.add_argument( + "--permutation-type", + choices=("feistel", "linear"), + default="feistel", + help="Shuffle permutation type for the selected dataset.", + ) + parser.add_argument( + "--dataset", + choices=("nemotron_cc", "nemotron_dclm_fineweb_10b"), + default="nemotron_cc", + help="Tokenized dataset preset to use for training.", + ) + parser.add_argument( + "--dataset-tokenizer", + type=str, + default="stanford-crfm/marin-tokenizer", + help=( + "Optional tokenizer name/path used for vocab size / special ids. " + "Must match the tokenizer used to pretokenize the dataset." + ), + ) + + parser.add_argument( + "--single-checkpoint", + action="store_true", + help=( + "Only keep one (temporary) checkpoint at a time to reduce disk pressure. " + "This disables permanent step-based checkpoints." 
+ ), + ) + parser.add_argument("--checkpoint-save-minutes", type=int, default=60) + + parser.add_argument("--wandb-project", type=str, default="olmoe_m") + parser.add_argument("--wandb-name-suffix", type=str, default=None) + parser.add_argument("--run-suffix", type=str, default=None) + parser.add_argument("--extra-tag", action="append", default=[], help="Additional W&B tag (repeatable).") + + parser.add_argument("--stab-alf-lb-loss-scale", type=float, default=0.01) + parser.add_argument( + "--variants", + nargs="+", + choices=("olmoe_m", "olmoe_m_bilinear", "olmoe_m_stab2", "olmoe_m_stab3", "olmoe_m_stab5"), + default=["olmoe_m", "olmoe_m_bilinear", "olmoe_m_stab3", "olmoe_m_stab5"], + help=("Which variants to run (default: " "olmoe_m olmoe_m_bilinear olmoe_m_stab3 olmoe_m_stab5)."), + ) + + # Executor controls (so this script can be run under ray_run without draccus CLI conflicts). + parser.add_argument("--prefix", default=os.getenv("MARIN_PREFIX")) + parser.add_argument("--executor-info-base-path", default=None) + parser.add_argument("--dry-run", action="store_true") + parser.add_argument( + "--max-concurrent", + type=int, + default=1, + help=( + "Maximum number of training steps to run concurrently within this sweep driver. " + "Set to 1 to use a single TPU slice per submission (sequential LR runs)." 
+ ), + ) + parser.set_defaults(force_run_failed=True) + parser.add_argument( + "--no-force-run-failed", + dest="force_run_failed", + action="store_false", + help="If set, do not retry steps that failed previously (executor will stop on FAILED status).", + ) + parser.add_argument("--run-only", nargs="*", default=None) + args = parser.parse_args() + if args.cross_entropy_block_size <= 0: + raise ValueError("--cross-entropy-block-size must be > 0") + if not args.learning_rates: + raise ValueError("--learning-rates must include at least one value.") + + use_default_validation = not args.disable_default_validation + + tokens_per_step = args.global_batch_size * args.seq_len + num_train_steps = max(1, _ceil_div(args.token_target, tokens_per_step)) + logger.info( + "Token budget=%d, tokens/step=%d => num_train_steps=%d", args.token_target, tokens_per_step, num_train_steps + ) + + base_optimizer = AdamConfig( + learning_rate=float(args.learning_rates[0]), + weight_decay=args.weight_decay, + beta1=args.beta1, + beta2=args.beta2, + epsilon=args.epsilon, + max_grad_norm=args.max_grad_norm, + warmup=args.warmup_steps, + min_lr_ratio=args.min_lr_ratio, + lr_schedule="cosine", + adamc_weight_decay=True, + ) + base_train_config = SimpleTrainConfig( + resources=ResourceConfig.with_tpu(tpu_type=args.tpu_type), + train_batch_size=args.global_batch_size, + num_train_steps=num_train_steps, + learning_rate=float(args.learning_rates[0]), + train_seq_len=args.seq_len, + optimizer_config=base_optimizer, + steps_per_eval=args.steps_per_eval, + steps_per_export=100_000_000, + steps_per_hf_export=-1, + ) + + tokenized = DATASET_OPTIONS[args.dataset] + if not isinstance(tokenized, LMMixtureDatasetConfig): + raise ValueError(f"Expected {args.dataset} to resolve to a mixture dataset config.") + tokenized = dataclasses.replace( + tokenized, + permutation_type=args.permutation_type, + tokenizer=args.dataset_tokenizer, + ) + + olmoe_m = _build_olmoe_m_config(args.seq_len, 
cross_entropy_block_size=args.cross_entropy_block_size) + olmoe_m_bilinear = dataclasses.replace(olmoe_m, activation_function=_identity_activation) + olmoe_m_stab2 = dataclasses.replace( + olmoe_m, + router_fp32=True, + alf_lb_loss_scale=args.stab_alf_lb_loss_scale, + ) + olmoe_m_stab3 = dataclasses.replace( + olmoe_m, + use_qk_norm=True, + router_topk_then_softmax=True, + router_fp32=True, + ) + olmoe_m_stab5 = dataclasses.replace( + olmoe_m, + use_qk_norm=True, + router_topk_then_softmax=True, + router_fp32=True, + alf_lb_loss_scale=args.stab_alf_lb_loss_scale, + dense_first_n_layers=2, + ) + + variants: list[tuple[str, MixtralConfig]] = [ + ("olmoe_m", olmoe_m), + ("olmoe_m_bilinear", olmoe_m_bilinear), + ("olmoe_m_stab2", olmoe_m_stab2), + ("olmoe_m_stab3", olmoe_m_stab3), + ("olmoe_m_stab5", olmoe_m_stab5), + ] + selected_variants = {v.strip() for v in args.variants} + variants = [v for v in variants if v[0] in selected_variants] + + steps: list[ExecutorStep] = [] + for lr in args.learning_rates: + lr = float(lr) + lr_tag = _format_lr_tag(lr) + optimizer_cfg = dataclasses.replace(base_optimizer, learning_rate=lr) + train_cfg = dataclasses.replace(base_train_config, learning_rate=lr, optimizer_config=optimizer_cfg) + + for variant, model_cfg in variants: + base_name = f"olmoe_m_40b/{variant}/lr_{lr_tag}/s{args.seq_len}_b{args.global_batch_size}" + + suffix = f"_{args.wandb_name_suffix}" if args.wandb_name_suffix else "" + wandb_name = f"olmoe_m_{variant}_s{args.seq_len}_b{args.global_batch_size}_lr{lr_tag}{suffix}" + + use_qk_norm = bool(getattr(model_cfg, "use_qk_norm", False)) + router_topk_then_softmax = bool(getattr(model_cfg, "router_topk_then_softmax", False)) + alf_lb_loss_scale = float(getattr(model_cfg, "alf_lb_loss_scale", 0.0) or 0.0) + dense_first_n_layers = int(getattr(model_cfg, "dense_first_n_layers", 0) or 0) + router_fp32 = bool(getattr(model_cfg, "router_fp32", False)) + + extra_tags = list(args.extra_tag) + if args.run_suffix: + 
extra_tags.append(f"run_suffix={args.run_suffix}") + + tags = _make_tags( + variant=variant, + lr=lr, + seq_len=args.seq_len, + global_batch_size=args.global_batch_size, + token_target=args.token_target, + permutation_type=args.permutation_type, + use_qk_norm=use_qk_norm, + router_topk_then_softmax=router_topk_then_softmax, + alf_lb_loss_scale=alf_lb_loss_scale, + dense_first_n_layers=dense_first_n_layers, + router_fp32=router_fp32, + dataset_name=args.dataset, + extra_tags=extra_tags, + ) + + run_suffix = f"_{args.run_suffix}" if args.run_suffix else "" + steps.append( + default_train( + name=f"{base_name}{run_suffix}", + tokenized=tokenized, + model_config=model_cfg, + train_config=train_cfg, + tags=tags, + use_default_validation=use_default_validation, + eval_harness_tasks=(), + wandb_name=wandb_name, + wandb_project=args.wandb_project, + checkpointer_save_interval=timedelta(minutes=int(args.checkpoint_save_minutes)), + checkpointer_keep=[] if args.single_checkpoint else None, + ) + ) + + executor_cfg = ExecutorMainConfig( + prefix=args.prefix, + executor_info_base_path=args.executor_info_base_path, + dry_run=args.dry_run, + force_run_failed=args.force_run_failed, + run_only=args.run_only, + max_concurrent=args.max_concurrent, + ) + executor_main.__wrapped__( + executor_cfg, + steps=steps, + description="OLMoE-M AdamC LR sweep (Nemotron-family datasets; 40B tokens)", + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/speedrun/prebuilt_caches.py b/experiments/speedrun/prebuilt_caches.py index 26e1864514..c1a5eb77f8 100644 --- a/experiments/speedrun/prebuilt_caches.py +++ b/experiments/speedrun/prebuilt_caches.py @@ -22,11 +22,11 @@ """ -from experiments.marin_models import marin_tokenizer +from experiments.llama import llama3_tokenizer from marin.processing.tokenize.download_pretokenized import download_pretokenized_cache fineweb_edu_10B_repo_id = "marin-community/fineweb-edu-pretokenized-10B" -fineweb_edu_subcache_10B = 
download_pretokenized_cache("fineweb-edu-10B", fineweb_edu_10B_repo_id, marin_tokenizer) +fineweb_edu_subcache_10B = download_pretokenized_cache("fineweb-edu-10B", fineweb_edu_10B_repo_id, llama3_tokenizer) fineweb_edu_10M_repo_id = "marin-community/fineweb-edu-pretokenized-10M" -fineweb_edu_subcache_10M = download_pretokenized_cache("fineweb-edu-10M", fineweb_edu_10M_repo_id, marin_tokenizer) +fineweb_edu_subcache_10M = download_pretokenized_cache("fineweb-edu-10M", fineweb_edu_10M_repo_id, llama3_tokenizer) diff --git a/lib/fray/src/fray/v1/cluster/ray/deps.py b/lib/fray/src/fray/v1/cluster/ray/deps.py index 6d37673ca2..242cd7997e 100644 --- a/lib/fray/src/fray/v1/cluster/ray/deps.py +++ b/lib/fray/src/fray/v1/cluster/ray/deps.py @@ -13,6 +13,8 @@ from ray.runtime_env import RuntimeEnv +import ray + logger = logging.getLogger("ray") # Packages to ignore when computing the runtime environment. @@ -163,8 +165,9 @@ def build_runtime_env_for_packages( package_spec = compute_frozen_packages(extra) requirements_txt = [ - """ + f""" # Generated by fray/cluster/ray/deps.py +# base_ray_version: {ray.__version__} """ ] # Add resiliparse custom index diff --git a/lib/fray/src/fray/v2/ray_backend/deps.py b/lib/fray/src/fray/v2/ray_backend/deps.py index bf38077d9b..2dfe3a1eb1 100644 --- a/lib/fray/src/fray/v2/ray_backend/deps.py +++ b/lib/fray/src/fray/v2/ray_backend/deps.py @@ -13,6 +13,8 @@ from ray.runtime_env import RuntimeEnv +import ray + logger = logging.getLogger("ray") # Packages to ignore when computing the runtime environment. 
@@ -162,8 +164,9 @@ def build_runtime_env_for_packages( package_spec = compute_frozen_packages(extra) requirements_txt = [ - """ + f""" # Generated by fray/v2/ray/deps.py +# base_ray_version: {ray.__version__} """ ] # Add resiliparse custom index diff --git a/lib/haliax/src/haliax/nn/normalization.py b/lib/haliax/src/haliax/nn/normalization.py index 5385b2635a..64055a9bd9 100644 --- a/lib/haliax/src/haliax/nn/normalization.py +++ b/lib/haliax/src/haliax/nn/normalization.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 + import dataclasses from abc import abstractmethod from typing import TypeVar @@ -105,18 +106,18 @@ class LayerNorm(LayerNormBase): dtype: jnp.dtype | None = eqx.field(default=None, static=True) def __call__(self, x: NamedArray) -> NamedArray: - dtype = x.dtype + in_dtype = x.dtype mean = x.mean(self.axis) var = x.var(self.axis) inv = hax.rsqrt(var + self.eps) out = (x - mean) * inv - out = out.astype(dtype) + out = out.astype(jnp.float32) if self.weight is not None: - out = self.weight.astype(out.dtype) * out + out = hax.auto_sharded(self.weight).astype(jnp.float32) * out if self.bias is not None: - out = out + self.bias.astype(out.dtype) - return out + out = out + hax.auto_sharded(self.bias).astype(jnp.float32) + return out.astype(in_dtype) class RmsNorm(LayerNormBase): @@ -130,13 +131,13 @@ def __call__(self, x: NamedArray) -> NamedArray: var = hax.mean(hax.square(x), axis=self.axis) inv = hax.rsqrt(var + self.eps) out = x * inv - out = out.astype(in_dtype) + out = out.astype(jnp.float32) if self.weight is not None: - out = self.weight.astype(out.dtype) * out + out = hax.auto_sharded(self.weight).astype(jnp.float32) * out if self.bias is not None: - out = out + self.bias.astype(out.dtype) - return out + out = out + hax.auto_sharded(self.bias).astype(jnp.float32) + return out.astype(in_dtype) def logsumexp(a: A, axis: AxisSelection | None = None) -> A: diff --git a/lib/levanter/docs/Performance-Guide.md 
b/lib/levanter/docs/Performance-Guide.md index 147ec912e8..337f38e0bb 100644 --- a/lib/levanter/docs/Performance-Guide.md +++ b/lib/levanter/docs/Performance-Guide.md @@ -102,3 +102,60 @@ but with some patience and work you can back those out by looking at the next se * `jvp(OP)` means the forward pass. (JVP stands for Jacobian-vector product.) * `transpose(jvp(OP))` means the backward pass. * `remat` (short for rematerialization) means that the operation is recomputed in the backward pass, i.e. gradient checkpointing. + + +## Mixtral (MoE) throughput on TPU + +### High-MFU sharding settings + +We have observed a large MFU gap for Mixtral-style MoE models on v5p when using an explicit mesh +(`AxisType.Explicit`) versus the default Auto/Manual mesh. + +Practical guidance: + +- Keep `trainer.use_explicit_mesh_axes=false` for standard Levanter/Haliax models (Mixtral/OLMoE/etc.) unless you are + calling `jax.sharding.reshard(...)` with named `PartitionSpec`s. +- Models that initialize weights via `jax.sharding.reshard` (for example, Grugformer) still need + `trainer.use_explicit_mesh_axes=true`. 
+ +In `experiments/speedrun/mixtral_vs_dense_nemotron_dclm_fineweb_edu_100b.py`, the relevant flags are described below. + +Example command (Mixtral-only, v5p-32, seq=4096, global_bs=384, pdp=8): + +```bash +GROUP="mixtral_vs_dense_40b_v5p32_$(git rev-parse --short HEAD)" && \ +SUFFIX="t$(date +%Y%m%d_%H%M%S)_$$" && \ +TMPDIR=/tmp RAY_TMPDIR=/tmp uv run python -m marin.run.ray_run \ + --cluster infra/marin-us-central1.yaml --extra tpu --no_wait \ + --env_vars WANDB_MODE online \ + --env_vars WANDB_PROJECT mixtral_vs_dense \ + --env_vars WANDB_GROUP "$GROUP" \ + -- python -m experiments.speedrun.mixtral_vs_dense_nemotron_dclm_fineweb_edu_100b \ + --models mixtral_8x7b \ + --tpu-type v5p-32 \ + --seq-len 4096 \ + --global-batch-size 384 \ + --per-device-parallelism 8 \ + --token-target 40000000000 \ + --no-explicit-mesh-axes \ + --mixtral-use-gmm \ + --mixtral-cross-entropy-block-size 1024 \ + --run-suffix "mixtral_base_${SUFFIX}" +``` + + +- `--no-explicit-mesh-axes` (default): high MFU +- `--explicit-mesh-axes`: can reduce MFU substantially +- `--legacy-axis-resources`: optional mapping that matches December-era `(replica, data)` sharding for + `batch/token/token_repeat` + +### Profiling artifacts + +Levanter writes jaxpr/HLO dumps under `logs/<run_id>/artifacts/` when `trainer.log_jaxprs` / `trainer.log_xla_hlo` are +enabled (they default to true): + +- `train_step.jaxpr.txt.gz` +- `train_step.hlo.txt` + +Perfetto traces live under `logs/<run_id>/profiler/plugins/profile/.../perfetto_trace.json.gz` and are uploaded to W&B +as the `jax_profile` artifact. 
diff --git a/lib/levanter/src/levanter/grad_accum.py b/lib/levanter/src/levanter/grad_accum.py index bd60e9361b..21fe73a9f9 100644 --- a/lib/levanter/src/levanter/grad_accum.py +++ b/lib/levanter/src/levanter/grad_accum.py @@ -13,7 +13,7 @@ from haliax.partitioning import ResourceAxis from haliax.util import is_named_array from jax.lax import with_sharding_constraint -from jax.sharding import PartitionSpec +from jax.sharding import NamedSharding, PartitionSpec from levanter.metrics import Metric from levanter.metrics import fold as fold_metric @@ -171,14 +171,35 @@ def loop(acc, microbatch_and_key): def _reshape_for_microbatch(Batch: Axis, Microbatch: Axis, AccumStep: Axis, inputs, axis_mapping): + def _reshape_with_out_sharding(array, *, batch_dim: int): + new_shape = array.shape[:batch_dim] + (AccumStep.size, Microbatch.size) + array.shape[batch_dim + 1 :] + sharding = getattr(array, "sharding", None) + if sharding is None: + aval = getattr(array, "aval", None) + sharding = getattr(aval, "sharding", None) if aval is not None else None + + if isinstance(sharding, NamedSharding): + in_spec = tuple(sharding.spec) + out_spec = (*in_spec[:batch_dim], None, in_spec[batch_dim], *in_spec[batch_dim + 1 :]) + out_sharding = NamedSharding(sharding.mesh, PartitionSpec(*out_spec)) + return jax.lax.reshape(array, new_sizes=new_shape, out_sharding=out_sharding) + + return array.reshape(new_shape) + def _reshape(x): if isinstance(x, hax.NamedArray): - if not x.has_axis(Batch.name): + batch_dim = x.axis_indices(Batch) + if batch_dim is None: return x - x = x.unflatten_axis(Batch, (AccumStep, Microbatch)) + new_array = _reshape_with_out_sharding(x.array, batch_dim=batch_dim) + new_axes = list(x.axes) + new_axes[batch_dim : batch_dim + 1] = [AccumStep, Microbatch] + x = hax.NamedArray(new_array, tuple(new_axes)) return hax.shard(x, axis_mapping) elif isinstance(x, jnp.ndarray): - x = x.reshape((AccumStep.size, Microbatch.size) + x.shape[1:]) + if not x.shape or x.shape[0] != 
Batch.size: + return x + x = _reshape_with_out_sharding(x, batch_dim=0) return with_sharding_constraint(x, PartitionSpec(None, ResourceAxis.DATA, *(None,) * (len(x.shape) - 2))) else: # assert jnp.isscalar(x) diff --git a/lib/levanter/src/levanter/grug/sharding.py b/lib/levanter/src/levanter/grug/sharding.py index 8d57df5bb8..df5bf209d0 100644 --- a/lib/levanter/src/levanter/grug/sharding.py +++ b/lib/levanter/src/levanter/grug/sharding.py @@ -11,5 +11,14 @@ Pvocab = P(None, None) +def Pbatch_moe() -> P: + """PartitionSpec for batch/token axes in MoE experiments. + + Shards the leading (batch or token) dimension over (`replica`, `data`) so it matches + the legacy axis_resources mapping used by high-MFU MoE runs. + """ + return P(("replica", "data")) + + def unshard(x: jax.Array) -> jax.Array: return reshard(x, P((None,))) diff --git a/lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/pallas_tpu.py b/lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/pallas_tpu.py index 8ff1d61091..6c98fe8a19 100644 --- a/lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/pallas_tpu.py +++ b/lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/pallas_tpu.py @@ -28,6 +28,12 @@ class PallasUnsupportedError(NotImplementedError): NUM_LANES = 128 +# TPU VMEM capacity is fixed-size on current generations (e.g. v5p). The fused CE +# forward kernel materializes an intermediate `xw_tiled` buffer in VMEM with dtype +# float32 and shape [b_block_size, v_block_size]. If this buffer cannot fit, the +# kernel will fail with `RESOURCE_EXHAUSTED` during compilation. +_TPU_VMEM_BYTES = 64 * 1024 * 1024 # 64MiB + def _fwd_cost_estimate( x: jax.Array, @@ -222,6 +228,20 @@ def _validate_inputs( f"got b_block_size={block_sizes.b_block_size}." ) + # VMEM scratch check: the kernel accumulates logits into a float32 [B,V] tile. 
+ b_block_size = int(block_sizes.b_block_size) + v_block_size = int(block_sizes.v_block_size) + if b_block_size > 0 and v_block_size > 0: + scratch_bytes = b_block_size * v_block_size * 4 # float32 + if scratch_bytes > _TPU_VMEM_BYTES: + raise PallasUnsupportedError( + "Pallas fused CE requires a float32 VMEM scratch tile of size " + f"[{b_block_size}, {v_block_size}] (~{scratch_bytes / (1024 * 1024):.1f}MiB), which exceeds " + f"TPU VMEM capacity ({_TPU_VMEM_BYTES / (1024 * 1024):.1f}MiB). " + "Use a smaller v_block_size (e.g. <= 16384 when b_block_size=1024), or select " + "`implementation='xla'`." + ) + def _infer_num_tensorcores() -> int: if jax.default_backend() != "tpu": diff --git a/lib/levanter/src/levanter/utils/jax_utils.py b/lib/levanter/src/levanter/utils/jax_utils.py index d90d1ed8f5..b50fc1fb97 100644 --- a/lib/levanter/src/levanter/utils/jax_utils.py +++ b/lib/levanter/src/levanter/utils/jax_utils.py @@ -15,12 +15,14 @@ import jax import numpy as np from haliax import is_named_array -from haliax._src.util import index_where from haliax.jax_utils import is_jax_array_like from haliax.partitioning import ResourceAxis, ResourceMapping from jax import numpy as jnp +from jax._src import mesh as jax_mesh from jax._src.mesh import get_concrete_mesh +from jax.experimental.multihost_utils import broadcast_one_to_all as _jax_broadcast_one_to_all from jax.experimental.multihost_utils import host_local_array_to_global_array +from jax.interpreters.pxla import thread_resources as pxla_thread_resources from jax.sharding import Mesh, NamedSharding, PartitionSpec from jaxtyping import PRNGKeyArray, PyTree @@ -399,47 +401,90 @@ def broadcast_shard(x: T, out_axis_specs: Any, source: int = 0) -> T: 2. 
Then, inside jit, we select the source'th element of the array, then reshard with the out_axis_specs """ - current_mesh: jax.sharding.Mesh = hax.partitioning._get_mesh() + # NOTE: Prior implementations attempted to use `jax.make_array_from_callback` with a `NamedSharding` constructed + # from the active training mesh. On multi-host TPU, that can crash with "not fully addressable" / "not addressable + # sharding" errors. + # + # We broadcast using a temporary mesh over ("processes", "local_devices") and a reduction across the sharded + # "processes" axis. Some callers (notably eval-harness control-plane messages) use `PartitionSpec()` (replicated) + # outputs; in that case we intentionally return host-local arrays to avoid device-order mismatches between the + # temporary mesh and the active training mesh. + + def _maybe_constrain(arr: jax.Array, spec: Any) -> jax.Array: + # `jax.jit` (and some multihost contexts) can error if we try to apply a sharding constraint with a sharding + # that isn't fully addressable from the current host. In practice, callers like eval-harness pass + # `NamedSharding` objects produced from a global mesh; for robustness we avoid constraining in that case and + # let downstream `named_jit` / pjitted functions reshard as needed. 
+ if spec is None: + return arr + if isinstance(spec, PartitionSpec): + resolved_mesh = haliax.partitioning._get_mesh() + return jax.lax.with_sharding_constraint(arr, NamedSharding(resolved_mesh, spec)) + return arr + + def _replicated_out_specs(specs: Any) -> bool: + if specs is None: + return True + if isinstance(specs, PartitionSpec): + return specs == PartitionSpec() + leaves = jax.tree_util.tree_leaves(specs, is_leaf=lambda s: isinstance(s, PartitionSpec)) + if not leaves: + return False + for leaf in leaves: + if leaf is None: + continue + if not isinstance(leaf, PartitionSpec) or leaf != PartitionSpec(): + return False + return True - axis_names = current_mesh.axis_names + if jax.process_count() == 1: - valid_device_for_process = index_where(lambda d: d.host_id == source, current_mesh.devices.flatten()) - sharding = NamedSharding( - current_mesh, - PartitionSpec( - axis_names, - ), - ) + def in_jit_single(x_leaf: Any, spec: Any) -> Any: + arr = x_leaf.array if isinstance(x_leaf, hax.NamedArray) else x_leaf + arr = _maybe_constrain(arr, spec) + return hax.named(arr, x_leaf.axis_names) if isinstance(x_leaf, hax.NamedArray) else arr + + return eqx.filter_jit(jax.tree.map)(in_jit_single, x, out_axis_specs, is_leaf=is_named_array) + + active_mesh = haliax.partitioning._get_mesh() + devices: np.ndarray = np.asarray(active_mesh.devices).reshape(jax.process_count(), jax.local_device_count()) + global_mesh = Mesh(devices, ("processes", "local_devices")) + in_pspec = PartitionSpec("processes") + + def pre_jit(x_leaf: Any) -> jax.Array: + arr = x_leaf.array if isinstance(x_leaf, hax.NamedArray) else x_leaf + # `host_local_array_to_global_array` expects the leading axis to be shardable by the requested mesh axis. + # In particular, for `PartitionSpec("processes")` it requires the leading dimension to be divisible by + # `jax.process_count()`. Construct a leading "processes" axis explicitly. 
+ payload = np.asarray(arr) if jax.process_index() == source else np.zeros(arr.shape, dtype=arr.dtype) - def pre_jit(x): + host = np.zeros((jax.process_count(),) + payload.shape, dtype=payload.dtype) if jax.process_index() == source: - inp = np.array(x) - else: - inp = jnp.zeros(x.shape, dtype=x.dtype) + host[source] = payload - shape = (len(jax.devices()),) + inp.shape - inp = jnp.expand_dims(inp, axis=0) - out = jax.make_array_from_callback(shape, sharding, lambda _: inp) + return host_local_array_to_global_array(host, global_mesh, in_pspec) - return out + if _replicated_out_specs(out_axis_specs): + # Broadcast to all hosts, then materialize host-local arrays. This avoids JAX errors when the active training + # mesh uses a device order that differs from `jax.devices()` order (common on multi-host TPU). + x_global = jax.tree.map(pre_jit, x, is_leaf=is_named_array) + with haliax.partitioning.set_mesh(global_mesh): + reduced = jax.jit(_psum, out_shardings=NamedSharding(global_mesh, PartitionSpec()))(x_global) - def in_jit(x, pspec): - if isinstance(x, hax.NamedArray): - arr = x.array - else: - arr = x - arr = jax.lax.with_sharding_constraint(arr[valid_device_for_process], pspec) + def post_jit(x_leaf: Any, x_leaf_orig: Any) -> Any: + host_arr = jax.device_get(x_leaf.addressable_data(0)) + arr = jnp.asarray(host_arr) + return hax.named(arr, x_leaf_orig.axis_names) if isinstance(x_leaf_orig, hax.NamedArray) else arr - if isinstance(x, hax.NamedArray): - return hax.named(arr, x.axis_names) - else: - return arr + return jax.tree.map(post_jit, reduced, x, is_leaf=is_named_array) - x = jax.tree.map(pre_jit, x) - # q = eqx.filter_jit(jax.tree.map).lower(in_jit, x, out_axis_specs, is_leaf=is_named_array).as_text() - out = eqx.filter_jit(jax.tree.map)(in_jit, x, out_axis_specs, is_leaf=is_named_array) + def in_jit(x_global: jax.Array, spec: Any, x_leaf: Any) -> Any: + arr = jnp.sum(x_global, axis=0) + arr = _maybe_constrain(arr, spec) + return hax.named(arr, 
x_leaf.axis_names) if isinstance(x_leaf, hax.NamedArray) else arr - return out + x_global = jax.tree.map(pre_jit, x, is_leaf=is_named_array) + return eqx.filter_jit(jax.tree.map)(in_jit, x_global, out_axis_specs, x, is_leaf=is_named_array) def tree_broadcast_to(prefix: PyTree[L], t: T, *, is_leaf: Optional[Callable[[Any], bool]] = None) -> T: @@ -456,14 +501,50 @@ def tree_broadcast_to(prefix: PyTree[L], t: T, *, is_leaf: Optional[Callable[[An ) -# Non-busted version of broadcast_one_to_all from jax.multihost_utils. (The issue is that if you use a non-contiguous -# mesh, their utility blows up because it makes a contiguous mesh.) +# Mesh-safe broadcast helper. +# +# `jax.experimental.multihost_utils` implementations rely on `jax.jit` + sharding, which can error if they're invoked +# under a non-trivial `pjit` mesh context (e.g. device order differs from `jax.local_devices()`). We keep this utility +# in Levanter because it's used during initialization/barriers and needs to be robust even when a mesh context is +# already active. def _psum(xs: Any) -> Any: return jax.tree.map(lambda x: jnp.sum(x, dtype=x.dtype, axis=0), xs) +@contextlib.contextmanager +def _without_mesh() -> Any: + """Temporarily clear the active `pjit` mesh context. + + Some JAX multi-host utilities (and `pmap` in newer JAX versions, which may be implemented via + `shard_map`) can fail if called while a non-trivial mesh context is active. This helper ensures + we can run simple host-control collectives (e.g. checkpoint save coordination) safely even when + the training loop has a mesh set. 
+ """ + old_jax_mesh_env = jax_mesh.thread_resources.env + old_pxla_mesh = pxla_thread_resources.env.physical_mesh + + cleared_jax_mesh = bool(old_jax_mesh_env.physical_mesh.axis_names) + cleared_pxla_mesh = old_pxla_mesh is not None and not old_pxla_mesh.empty + + if cleared_jax_mesh: + jax_mesh.thread_resources.env = jax_mesh.EMPTY_ENV + if cleared_pxla_mesh: + pxla_thread_resources.env.physical_mesh = None + + if cleared_jax_mesh or cleared_pxla_mesh: + try: + yield + finally: + if cleared_pxla_mesh: + pxla_thread_resources.env.physical_mesh = old_pxla_mesh + if cleared_jax_mesh: + jax_mesh.thread_resources.env = old_jax_mesh_env + else: + yield + + def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any: """Broadcast data from a source host (host 0 by default) to all other hosts. @@ -481,28 +562,11 @@ def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any: if jax.process_count() == 1: return jax.tree.map(np.asarray, in_tree) - if is_source is None: - is_source = jax.process_index() == 0 - - devices: np.ndarray = np.array(jax.devices()).reshape(jax.process_count(), jax.local_device_count()) - global_mesh = jax.sharding.Mesh(devices, ("processes", "local_devices")) - pspec = PartitionSpec("processes") - - def pre_jit(x): - if is_source: - inp = x - else: - inp = np.zeros_like(x) - inp = np.expand_dims(inp, axis=0) - return host_local_array_to_global_array(inp, global_mesh, pspec) - - def post_jit(x): - return jax.device_get(x.addressable_data(0)) - - with haliax.partitioning.set_mesh(global_mesh): - in_tree = jax.tree.map(pre_jit, in_tree) - out_tree = jax.jit(_psum, out_shardings=jax.sharding.NamedSharding(global_mesh, PartitionSpec()))(in_tree) - return jax.tree.map(post_jit, out_tree) + # JAX's upstream implementation creates a temporary `jax.set_mesh` context. That can fail if we're invoked while + # the training loop has a different mesh active. 
We temporarily clear the mesh context so we can safely run this + # host-control collective (used by checkpointing/eval harness coordination). + with _without_mesh(): + return _jax_broadcast_one_to_all(in_tree, is_source=is_source) def assert_equal(in_tree, fail_message: str = ""): From bd4b91fb2db8122b8a29f580d6c42c3a1e8178fb Mon Sep 17 00:00:00 2001 From: Pranshu Chaturvedi Date: Thu, 26 Feb 2026 20:58:30 -0800 Subject: [PATCH 2/7] grugformer_moe: avoid v4 auto-axis sharding failures --- .../speedrun/grugformer_moe/grugformer_moe.py | 116 +++++++++++++----- 1 file changed, 86 insertions(+), 30 deletions(-) diff --git a/experiments/speedrun/grugformer_moe/grugformer_moe.py b/experiments/speedrun/grugformer_moe/grugformer_moe.py index a58bae86b0..ac000114dd 100644 --- a/experiments/speedrun/grugformer_moe/grugformer_moe.py +++ b/experiments/speedrun/grugformer_moe/grugformer_moe.py @@ -9,8 +9,8 @@ Design goals: - "Grug simple": explicit tensor shapes, minimal abstractions. -- "Vanilla custom_mixtral logic": top-k routing + sort/permute dispatch + GMM (Megablox) or ragged_dot_general expert MLP, - with load-balancing loss and router z-loss. +- "Vanilla custom_mixtral logic": top-k routing + sort/permute dispatch + GMM (Megablox) or + ragged_dot_general expert MLP, with load-balancing loss and router z-loss. - Replicated experts (no expert-parallel all-to-all). 
""" @@ -31,14 +31,15 @@ from fray.cluster import ResourceConfig from haliax import Axis, NamedArray from haliax.nn.linear import gmm_sharded -from haliax.partitioning import _get_mesh +from haliax.partitioning import ResourceAxis, _get_mesh +from jax.lax import with_sharding_constraint from jax.experimental.shard_map import shard_map -from jax.sharding import NamedSharding, PartitionSpec as P, reshard +from jax.sharding import NamedSharding, PartitionSpec as P from jax.tree_util import register_dataclass from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree from levanter.grug.attention import AttentionMask, RotaryConfig, apply_rotary_embedding, attention -from levanter.grug.sharding import Pbatch_moe, Pvocab, unshard +from levanter.grug.sharding import Pbatch_moe from levanter.layers.attention import AttentionMask as LevanterAttentionMask from levanter.models.loss import maybe_fused_next_token_loss from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel @@ -77,15 +78,47 @@ def _print(_: None) -> jax.Array: return jax.lax.cond(is_finite, lambda _: x, _print, operand=None) + # Ruff/Pyflakes treats string literals in annotations as forward refs and checks that bare # identifiers (e.g. "D") are defined. These are jaxtyping dimension labels, not runtime symbols. D = TypeVar("D") def _pbatch() -> P: + """PartitionSpec for leading token/batch axes under the active Levanter mesh mapping. + + Grugformer MoE uses `shard_map` with explicit in/out specs. If these specs disagree with the + trainer's `MeshConfig` compute mapping (e.g. when `batch/token` include `replica_dcn`), XLA + inserts many tiny reshard/permute collectives that tank MFU. + + Use the current Haliax axis mapping when available so the specs always match the trainer config. 
+ """ + mapping = hax.partitioning.current_thread_local_mapping() + if mapping is not None: + for logical in ("token_repeat", "token", "batch"): + physical = mapping.get(logical) + if physical is None: + continue + if isinstance(physical, list): + physical = tuple(physical) + return P(physical) + return Pbatch_moe() +def _pbatch_for_get() -> None: + """`.get(out_sharding=...)` portability guard. + + On some TPU meshes (including v4 layouts used in shared central2), logical + axes like `data` are Auto/Manual under the hood, and JAX rejects them in + indexed `.get(out_sharding=...)`. + + Let JAX choose output sharding for indexed embedding gather and constrain + downstream compute via explicit `out_sharding`/`shard_map` sites instead. + """ + return None + + #### Conventions # # Mesh meanings: @@ -195,9 +228,13 @@ def init_parameters(cfg: GrugMoeModelConfig, *, key: PRNGKeyArray) -> GrugMoePar key, embed_key, out_key = jax.random.split(key, 3) layer_keys = jax.random.split(key, cfg.num_layers) - token_embed = reshard(_init_weight(embed_key, (cfg.vocab_size, cfg.hidden_dim), cfg.initializer_std), Pvocab) - output_proj = reshard(_init_weight(out_key, (cfg.hidden_dim, cfg.vocab_size), cfg.initializer_std), Pvocab) - final_norm = reshard(jnp.ones((cfg.hidden_dim,), dtype=jnp.float32), P(None)) + token_embed = with_sharding_constraint( + _init_weight(embed_key, (cfg.vocab_size, cfg.hidden_dim), cfg.initializer_std), P(None, ResourceAxis.DATA) + ) + output_proj = with_sharding_constraint( + _init_weight(out_key, (cfg.hidden_dim, cfg.vocab_size), cfg.initializer_std), P(ResourceAxis.DATA, None) + ) + final_norm = with_sharding_constraint(jnp.ones((cfg.hidden_dim,), dtype=jnp.float32), P(None)) blocks: list[GrugMoeBlockParams] = [] # extract shape sizes for brevity and consistency @@ -219,37 +256,44 @@ def init_parameters(cfg: GrugMoeModelConfig, *, key: PRNGKeyArray) -> GrugMoePar ) = jax.random.split(layer_keys[i], 8) attn = GrugAttentionParams( - 
w_q=reshard(_init_weight(k_q, (hidden_dim, num_heads * head_dim), cfg.initializer_std), P("data", "model")), - w_k=reshard( + w_q=with_sharding_constraint( + _init_weight(k_q, (hidden_dim, num_heads * head_dim), cfg.initializer_std), P("data", "model") + ), + w_k=with_sharding_constraint( _init_weight(k_k, (hidden_dim, num_kv_heads * head_dim), cfg.initializer_std), P("data", "model") ), - w_v=reshard( + w_v=with_sharding_constraint( _init_weight(k_v, (hidden_dim, num_kv_heads * head_dim), cfg.initializer_std), P("data", "model") ), - w_o=reshard(_init_weight(k_o, (num_heads * head_dim, hidden_dim), cfg.initializer_std), P("model", "data")), + w_o=with_sharding_constraint( + _init_weight(k_o, (num_heads * head_dim, hidden_dim), cfg.initializer_std), P("model", "data") + ), ) # Router maps D -> E. Keep the expert axis replicated (no expert-parallel sharding). - router_w = reshard(_init_weight(k_router, (hidden_dim, num_experts), cfg.initializer_std), P("data", None)) + router_w = with_sharding_constraint( + _init_weight(k_router, (hidden_dim, num_experts), cfg.initializer_std), P("data", None) + ) - # Expert weights are replicated over the data axis and sharded over the model axis (TP). - # This keeps the GMM pathway simple and avoids ragged-dot auto-sharding pathologies. - w1 = reshard( + # Expert weights keep the expert axis replicated (no expert-parallel all-to-all), but shard + # the dense dims over (data, model) to match Levanter's ZeRO-ish param sharding. This avoids + # TPU init OOMs for OLMoE-scale models (7B total params) where optimizer state dominates memory. 
+ w1 = with_sharding_constraint( _init_weight(k_w1, (num_experts, hidden_dim, intermediate_dim), cfg.initializer_std), - P(None, None, "model"), + P(None, ResourceAxis.DATA, ResourceAxis.MODEL), ) - w3 = reshard( + w3 = with_sharding_constraint( _init_weight(k_w3, (num_experts, hidden_dim, intermediate_dim), cfg.initializer_std), - P(None, None, "model"), + P(None, ResourceAxis.DATA, ResourceAxis.MODEL), ) - w2 = reshard( + w2 = with_sharding_constraint( _init_weight(k_w2, (num_experts, intermediate_dim, hidden_dim), cfg.initializer_std), - P(None, "model", None), + P(None, ResourceAxis.MODEL, ResourceAxis.DATA), ) # keep rms replicated - rms_attn = jnp.ones((hidden_dim,), dtype=jnp.float32) - rms_mlp = jnp.ones((hidden_dim,), dtype=jnp.float32) + rms_attn = with_sharding_constraint(jnp.ones((hidden_dim,), dtype=jnp.float32), P(None)) + rms_mlp = with_sharding_constraint(jnp.ones((hidden_dim,), dtype=jnp.float32), P(None)) blocks.append( GrugMoeBlockParams( @@ -272,7 +316,6 @@ def init_parameters(cfg: GrugMoeModelConfig, *, key: PRNGKeyArray) -> GrugMoePar def rms_norm(x: Float[Array, "... D"], weight: Float[Array, "D"], eps: float) -> Float[Array, "... D"]: - weight = unshard(weight) # Levanter runs with mixed precision (bf16 compute, fp32 params) + strict dtype promotion. # Do RMSNorm math in fp32, then cast back to the input dtype. out_dtype = x.dtype @@ -357,11 +400,15 @@ def _gmm_moe_linear( """ mesh = _get_mesh() if mesh is not None and not getattr(mesh, "empty", False): + # When `ar=True` the contracting dimension is sharded over the `model` axis and `gmm_sharded` + # will `psum` partial results. shard_map must see *local* shards on that contracting dimension, + # otherwise the Megablox GMM custom_vjp backward rule will observe mismatched primal/grad shapes. 
+ x_spec = _pbatch() if not ar else P(_pbatch()[0], ResourceAxis.MODEL) out_specs = P(_pbatch()[0], out_axis) gmm_fn = shard_map( lambda lhs, rhs, gs: gmm_sharded(lhs, rhs, gs, ar=ar), mesh=mesh, - in_specs=(_pbatch(), w_spec, P(None)), + in_specs=(x_spec, w_spec, P(None)), out_specs=out_specs, check_rep=False, ) @@ -370,7 +417,6 @@ def _gmm_moe_linear( return gmm_sharded(x, w, group_sizes, ar=ar) - def _route( selection_logits: jax.Array, router_logits: jax.Array, @@ -602,6 +648,7 @@ def _transformer_hidden( seq_len = token_ids.shape[1] if _DEBUG_FINITE and jax.process_index() == 0: bad = jnp.any((token_ids < 0) | (token_ids >= cfg.vocab_size)) + def _print_tok(_: None) -> jax.Array: jax.debug.print( "BAD TOKENS: min={minv} max={maxv} vocab={vocab}", @@ -610,12 +657,17 @@ def _print_tok(_: None) -> jax.Array: vocab=cfg.vocab_size, ) return token_ids + token_ids = jax.lax.cond(bad, _print_tok, lambda _: token_ids, operand=None) if mask is None: mask = AttentionMask.causal() - hidden = params.token_embed.at[token_ids].get(out_sharding=_pbatch()) # [B, S, D] + token_embed_out_sharding = _pbatch_for_get() + if token_embed_out_sharding is None: + hidden = params.token_embed.at[token_ids].get() # [B, S, D] + else: + hidden = params.token_embed.at[token_ids].get(out_sharding=token_embed_out_sharding) # [B, S, D] hidden = _maybe_log_nonfinite(hidden, name="hidden/embed") aux_total = jnp.array(0.0, dtype=jnp.float32) @@ -629,7 +681,7 @@ def _print_tok(_: None) -> jax.Array: attn_out = attention(q, k, v, mask) attn_out = _maybe_log_nonfinite(attn_out, name="attn_out") attn_out = rearrange(attn_out, "... n d -> ... 
(n d)") - attn_out = jnp.einsum("bsh,hd->bsd", attn_out, block.attn.w_o, out_sharding=_pbatch()) + attn_out = jnp.einsum("bsh,hd->bsd", attn_out, block.attn.w_o) hidden = hidden + attn_out mlp_in = rms_norm(hidden, block.rms_mlp, cfg.layer_norm_eps) @@ -871,7 +923,8 @@ def build(self, Vocab: Axis, *, key: PRNGKeyArray) -> GrugMoeWrapper: return GrugMoeWrapper.init(Vocab, cfg, key=key) def flops_per_token(self, vocab_size: int, context_length: int) -> float | None: - # Rough FLOP estimate: attention + (MoE MLP per-token uses K experts). + # Match Levanter's Mixtral/OLMoE FLOP accounting: MoE MLP scales with K experts/token and + # includes the router projection when num_experts > 1. return lm_flops_per_token( hidden_dim=self.hidden_dim, intermediate_dim=self.intermediate_dim, @@ -881,6 +934,9 @@ def flops_per_token(self, vocab_size: int, context_length: int) -> float | None: seq_len=context_length, vocab_size=vocab_size, glu=True, + num_experts=self.n_routed_experts, + num_shared_experts=0, + num_experts_per_tok=self.num_experts_per_tok, ) def total_trainable_params(self, vocab_size: int) -> int: @@ -919,7 +975,7 @@ def build_run(size: str, *, use_tpu: bool = False) -> tuple[str, SpeedrunConfig] weight_decay=0.1, steps_per_eval=500, steps_per_hf_export=-1, - explicit_mesh_axes=True, + explicit_mesh_axes=False, ) run_name = f"grugformer_moe_{size}" From acae1a6d19467430143541f990ff29410ac7fc9c Mon Sep 17 00:00:00 2001 From: Pranshu Chaturvedi Date: Thu, 26 Feb 2026 21:45:40 -0800 Subject: [PATCH 3/7] speedrun: add grugformer-moe archive entry + profiling flags - Document grugformer MoE entrypoints in docs/reports/grug-archive.md - Add CLI switches for profiling, jaxpr/HLO artifact logging, and perfetto link generation - Default to legacy axis resources + non-explicit mesh axes for higher MFU parity with levanter MoE runs - Use cached Nemotron Llama3 tokenized components in olmoe_1b7b speedrun and allow CE block-size override --- docs/reports/grug-archive.md | 7 ++ 
...rugformer_moe_nemotron_dclm_fineweb_10b.py | 68 +++++++++++++++++-- .../speedrun/olmoe_1b7b_nemotron_40b.py | 52 ++++++++++++-- 3 files changed, 116 insertions(+), 11 deletions(-) diff --git a/docs/reports/grug-archive.md b/docs/reports/grug-archive.md index 4f1750e26a..ae94dfe1f7 100644 --- a/docs/reports/grug-archive.md +++ b/docs/reports/grug-archive.md @@ -59,3 +59,10 @@ Copy/paste this block for new experiments: - Status: active - Purpose: Head-to-head comparison between Hackable Transformer and Grugformer (no sinks). +### grugformer-moe +- Path: `experiments/speedrun/grugformer_moe/` +- Introduced: TBD +- Last known-good: TBD +- Status: active +- Purpose: Grugformer MoE entrypoints (Mixtral/OLMoE-style router + expert MLP). +- Notes: Intended for throughput/MFU work; supports GMM (Megablox) and JAX profiling flags. diff --git a/experiments/speedrun/grugformer_moe/grugformer_moe_nemotron_dclm_fineweb_10b.py b/experiments/speedrun/grugformer_moe/grugformer_moe_nemotron_dclm_fineweb_10b.py index 558be3fa7a..103d7fa799 100644 --- a/experiments/speedrun/grugformer_moe/grugformer_moe_nemotron_dclm_fineweb_10b.py +++ b/experiments/speedrun/grugformer_moe/grugformer_moe_nemotron_dclm_fineweb_10b.py @@ -54,6 +54,8 @@ Z_LOSS_WEIGHT = 1e-4 STEPS_PER_EVAL = 5000 STEPS_PER_EXPORT = 20_000 +DEFAULT_PROFILER_START_STEP = 5 +DEFAULT_PROFILER_NUM_STEPS = 10 _EVAL_SUITES: dict[str, tuple] = { "none": (), @@ -182,21 +184,59 @@ def _parse_args() -> argparse.Namespace: "Cross-entropy backend. 'auto' tries Pallas on TPU v5+ and falls back to XLA when unsupported (e.g. TPU v4)." 
), ) - parser.set_defaults(explicit_mesh_axes=True) + parser.add_argument( + "--profile", + action="store_true", + help="Enable JAX profiling and upload `jax_profile` artifacts for Perfetto analysis.", + ) + parser.add_argument( + "--profile-start-step", + type=int, + default=DEFAULT_PROFILER_START_STEP, + help="Step to start profiling.", + ) + parser.add_argument( + "--profile-num-steps", + type=int, + default=DEFAULT_PROFILER_NUM_STEPS, + help="Number of steps to capture after profiling starts.", + ) + parser.add_argument( + "--profile-perfetto-link", + action="store_true", + help="Generate a Perfetto link when the profiler trace is finalized.", + ) + parser.add_argument( + "--log-jaxprs", + action="store_true", + help="Log the training step jaxpr to W&B artifacts (slow; off by default).", + ) + parser.add_argument( + "--log-xla-hlo", + action="store_true", + help="Log the training step StableHLO text to W&B artifacts (very slow; off by default).", + ) + # Default to non-explicit mesh axes for higher MFU on v5p (matches Levanter's MoE runs). + parser.set_defaults(explicit_mesh_axes=False) parser.add_argument( "--explicit-mesh-axes", dest="explicit_mesh_axes", action="store_true", - help="Use explicit mesh axes in TrainerConfig (default).", + help="Use explicit mesh axes in TrainerConfig.", ) parser.add_argument( "--no-explicit-mesh-axes", dest="explicit_mesh_axes", action="store_false", - help="Disable explicit mesh axes in TrainerConfig.", + help="Disable explicit mesh axes in TrainerConfig (default).", ) - parser.set_defaults(legacy_axis_resources=False) + # Default to the "legacy" DP sharding used by high-MFU Levanter MoE runs: + # token/token_repeat/batch -> (replica, data) and params sharded over embed -> data. + # + # The newer default MeshConfig maps batch over (replica_dcn, replica, data). If Grugformer hardcodes + # (replica, data) in shard_map specs, XLA will insert thousands of tiny reshard collectives. 
+ parser.set_defaults(legacy_axis_resources=True) parser.set_defaults(use_gmm=True) parser.add_argument( "--use-gmm", @@ -227,8 +267,12 @@ def _parse_args() -> argparse.Namespace: def _patch_trainer_sharding_ablations( train_step: ExecutorStep, *, + tpu_type: str, explicit_mesh_axes: bool, legacy_axis_resources: bool, + profiler_perfetto_link: bool, + log_jaxprs: bool, + log_xla_hlo: bool, ) -> ExecutorStep: config = train_step.config inner = config.train_config @@ -246,7 +290,14 @@ def _patch_trainer_sharding_ablations( param_mapping={"embed": "data"}, ) - trainer = dataclasses.replace(trainer, mesh=mesh, use_explicit_mesh_axes=explicit_mesh_axes) + trainer = dataclasses.replace( + trainer, + mesh=mesh, + use_explicit_mesh_axes=explicit_mesh_axes, + profiler_perfetto_link=profiler_perfetto_link, + log_jaxprs=log_jaxprs, + log_xla_hlo=log_xla_hlo, + ) inner = dataclasses.replace(inner, trainer=trainer) config = dataclasses.replace(config, train_config=inner) return dataclasses.replace(train_step, config=config) @@ -377,6 +428,9 @@ def main() -> None: steps_per_hf_export=-1, per_device_parallelism=int(args.per_device_parallelism), explicit_mesh_axes=bool(args.explicit_mesh_axes), + profiler=bool(args.profile), + profiler_start_step=int(args.profile_start_step), + profiler_num_steps=int(args.profile_num_steps), ) default_suffix = f"grugformer_moe_olmoe1b7b_{tpu_type}_bs{global_batch_size}_{args.dataset}_seq{seq_len}" @@ -409,8 +463,12 @@ def main() -> None: ) train_step = _patch_trainer_sharding_ablations( train_step, + tpu_type=tpu_type, explicit_mesh_axes=bool(args.explicit_mesh_axes), legacy_axis_resources=bool(args.legacy_axis_resources), + profiler_perfetto_link=bool(args.profile_perfetto_link), + log_jaxprs=bool(args.log_jaxprs), + log_xla_hlo=bool(args.log_xla_hlo), ) steps: list[ExecutorStep] = [train_step] diff --git a/experiments/speedrun/olmoe_1b7b_nemotron_40b.py b/experiments/speedrun/olmoe_1b7b_nemotron_40b.py index 3f705a7cf8..1783a85377 100644 --- 
a/experiments/speedrun/olmoe_1b7b_nemotron_40b.py +++ b/experiments/speedrun/olmoe_1b7b_nemotron_40b.py @@ -16,7 +16,7 @@ from experiments.defaults import default_train from experiments.evals.task_configs import convert_to_levanter_task_config -from experiments.pretraining_datasets import NEMOTRON_WEIGHTS, tokenize_nemotron +from experiments.pretraining_datasets import NEMOTRON_LLAMA3_OVERRIDES, NEMOTRON_WEIGHTS from experiments.pretraining_datasets.dclm import ( DCLM_MIXTURE_WEIGHTS, dclm_components_llama3, @@ -35,7 +35,7 @@ from levanter.tracker.wandb import WandbConfig from levanter.trainer import TrainerConfig from marin.execution.executor import ExecutorStep, InputName, executor_main, output_path_of -from marin.processing.tokenize import lm_data_config, lm_mixture_data_config +from marin.processing.tokenize import TokenizeConfig, lm_data_config, lm_mixture_data_config from marin.speedrun.speedrun import Author, SpeedrunConfig, SpeedrunResultsConfig, speedrun_results from marin.utilities.wandb_utils import WANDB_ENTITY, WANDB_PROJECT @@ -180,10 +180,37 @@ def build_model_config(*, model: str, seq_len: int) -> MixtralConfig: raise ValueError(f"Unknown model preset {model!r}. 
Options: {MODEL_OPTIONS}.") -nemotron_cc_steps = tokenize_nemotron(tokenizer=llama3_tokenizer) -nemotron_cc_mixture = lm_mixture_data_config( - components=nemotron_cc_steps, - weights=NEMOTRON_WEIGHTS, +def _resolve_dataset_path(path: str) -> str: + if "://" in path: + return path + prefix = os.environ.get("MARIN_PREFIX") + if not prefix: + return path + return f"{prefix.rstrip('/')}/{path.lstrip('/')}" + + +def _cached_nemotron_components(tokenizer: str) -> dict[str, TokenizeConfig]: + components: dict[str, TokenizeConfig] = {} + for split, relative_cache_path in NEMOTRON_LLAMA3_OVERRIDES.items(): + cache_path = _resolve_dataset_path(relative_cache_path) + components[f"nemotron_cc/{split}"] = TokenizeConfig( + train_paths=[cache_path], + validation_paths=[], + cache_path=cache_path, + tokenizer=tokenizer, + ) + return components + + +nemotron_cc_steps = _cached_nemotron_components(llama3_tokenizer) +assert nemotron_cc_steps.keys() == NEMOTRON_WEIGHTS.keys() +nemotron_cc_mixture = dataclasses.replace( + lm_mixture_data_config( + components=nemotron_cc_steps, + weights=NEMOTRON_WEIGHTS, + include_raw_paths=False, + ), + auto_build_caches=False, ) @@ -316,9 +343,12 @@ def make_speedrun_config( dataset_name: str, seq_len: int, tpu_type: str, + cross_entropy_block_size: int | None = None, ) -> SpeedrunConfig: tokenized_dataset = DATASET_OPTIONS[dataset_name] model_config = build_model_config(model=model, seq_len=seq_len) + if cross_entropy_block_size is not None: + model_config = dataclasses.replace(model_config, cross_entropy_block_size=cross_entropy_block_size) return SpeedrunConfig( author=Author( name="Marin Team", @@ -361,6 +391,15 @@ def _parse_args(): default=DEFAULT_GLOBAL_BATCH_SIZE, help="Override the global batch size (default 64).", ) + parser.add_argument( + "--cross-entropy-block-size", + type=int, + default=None, + help=( + "Override CustomMixtralConfig.cross_entropy_block_size (default: model preset default). 
" + "Useful for working around TPU vmem limits in the fused (Pallas) CE kernel." + ), + ) parser.add_argument( "--profile", action="store_true", @@ -454,6 +493,7 @@ def _parse_args(): dataset_name=args.dataset, seq_len=args.seq_len, tpu_type=args.tpu_type, + cross_entropy_block_size=args.cross_entropy_block_size, ) logger.info("Launching MoE Nemotron speedrun.") logger.info( From 6ed84a4924dcbd896220c04158f2031cfe0e1ffe Mon Sep 17 00:00:00 2001 From: Pranshu Chaturvedi Date: Thu, 26 Feb 2026 22:30:14 -0800 Subject: [PATCH 4/7] mixtral: add equilibrium bias load balancing + router fp32 --- experiments/speedrun/custom_mixtral.py | 75 ++++++++++++++++--- lib/levanter/src/levanter/models/mixtral.py | 65 ++++++++++++++-- .../src/levanter/models/moe_load_balance.py | 74 ++++++++++++++++++ lib/levanter/tests/test_moe_load_balance.py | 50 +++++++++++++ 4 files changed, 247 insertions(+), 17 deletions(-) create mode 100644 lib/levanter/src/levanter/models/moe_load_balance.py create mode 100644 lib/levanter/tests/test_moe_load_balance.py diff --git a/experiments/speedrun/custom_mixtral.py b/experiments/speedrun/custom_mixtral.py index 441e179fa7..2f3a688cb8 100644 --- a/experiments/speedrun/custom_mixtral.py +++ b/experiments/speedrun/custom_mixtral.py @@ -32,6 +32,7 @@ from levanter.models.llama import LlamaEmbedding, LlamaMlp from levanter.models.loss import maybe_fused_next_token_loss from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel +from levanter.models.moe_load_balance import equilibrium_bias_delta_from_topk from levanter.models.mistral import MistralConfig from levanter.utils.activation import ActivationFunctionEnum from levanter.utils.flop_utils import lm_flops_per_token @@ -145,6 +146,15 @@ class CustomMixtralConfig(MistralConfig): alf_lb_use_sign: bool = True alf_lb_center_bias: bool = True + # Quantile-balancing (Optimal Allocation for Equilibrium) bias-only update proxy. 
+ # + # We reuse `router_bias` and build a stop-gradient bias loss from the quantile update direction: + # b <- b - Quantile(r_ij, q_j) + # where residuals/quantiles are computed from adjusted logits and current top-k routing. + equilibrium_lb_loss_scale: float = 0.0 + equilibrium_lb_center_bias: bool = True + equilibrium_lb_unit_capacity: float = 1.0 + # Use dense (single-expert) routing for the first N transformer layers. dense_first_n_layers: int = 0 @@ -180,6 +190,8 @@ def __post_init__(self): assert ( self.num_experts_per_tok <= self.n_routed_experts ), f"num_experts_per_tok={self.num_experts_per_tok} greater than by n_routed_experts={self.n_routed_experts}." + if self.alf_lb_loss_scale > 0 and self.equilibrium_lb_loss_scale > 0: + raise ValueError("alf_lb_loss_scale and equilibrium_lb_loss_scale are mutually exclusive.") def hf_checkpoint_converter(self, ref_checkpoint: str | None = None) -> HFCheckpointConverter["MixtralConfig"]: # type: ignore return HFCheckpointConverter( @@ -551,10 +563,15 @@ def __call__(self, x: NamedArray, *, key=None, force_dense: bool = False) -> Nam router_logits = router_logits.astype(jnp.float32) router_probs = hnn.softmax(router_logits, axis=Experts) + alf_scale = float(getattr(self.config, "alf_lb_loss_scale", 0.0)) + equilibrium_scale = float(getattr(self.config, "equilibrium_lb_loss_scale", 0.0)) selection_logits = router_logits - if getattr(self.config, "alf_lb_loss_scale", 0.0) > 0: + if alf_scale > 0 or equilibrium_scale > 0: bias = self.router_bias - if getattr(self.config, "alf_lb_center_bias", True): + center_bias = getattr(self.config, "alf_lb_center_bias", True) + if equilibrium_scale > 0: + center_bias = getattr(self.config, "equilibrium_lb_center_bias", center_bias) + if center_bias: bias = bias - hax.mean(bias, axis=Experts) if router_fp32 and bias.array.dtype != jnp.float32: bias = bias.astype(jnp.float32) @@ -567,11 +584,6 @@ def __call__(self, x: NamedArray, *, key=None, force_dense: bool = False) -> Nam 
TopExperts, topk_then_softmax=getattr(self.config, "router_topk_then_softmax", False), ) - if router_fp32 and topk_weights.array.dtype != x_flat.array.dtype: - # Keep routing decisions in fp32, but cast weights back to the model activation dtype to - # avoid upcasting the expert weighted-sum (bandwidth/memory). - topk_weights = topk_weights.astype(x_flat.array.dtype) - if force_dense: idx_arr = jnp.zeros_like(topk_idx.array) w_arr = jnp.zeros_like(topk_weights.array) @@ -579,6 +591,28 @@ def __call__(self, x: NamedArray, *, key=None, force_dense: bool = False) -> Nam topk_weights = hax.named(w_arr, (Token, TopExperts)) topk_idx = hax.named(idx_arr, (Token, TopExperts)) + equilibrium_delta = None + if equilibrium_scale > 0 and not force_dense: + delta_arr, weighted_load_arr, target_load_arr, quantile_prob_arr = equilibrium_bias_delta_from_topk( + selection_logits.array, + topk_idx.array, + topk_weights.array, + unit_capacity=float(getattr(self.config, "equilibrium_lb_unit_capacity", 1.0)), + ) + equilibrium_delta = hax.named(jax.lax.stop_gradient(delta_arr), Experts) + weighted_load = hax.named(weighted_load_arr, Experts) + quantile_prob = hax.named(quantile_prob_arr, Experts) + target_load = hax.named(jnp.full((Experts.size,), target_load_arr, dtype=weighted_load_arr.dtype), Experts) + else: + weighted_load = None + quantile_prob = None + target_load = None + + if router_fp32 and topk_weights.array.dtype != x_flat.array.dtype: + # Keep routing decisions in fp32, but cast weights back to the model activation dtype to + # avoid upcasting the expert weighted-sum (bandwidth/memory). 
+ topk_weights = topk_weights.astype(x_flat.array.dtype) + topk_idx_flat = hax.flatten_axes(topk_idx, old_axes=[Token, TopExperts], new_axis="token_repeat") TokenRepeat = topk_idx_flat.resolve_axis("token_repeat") x_repeat_sort, group_sizes, sort_idx = self._permute(x_flat, topk_idx_flat, TokenRepeat) @@ -633,7 +667,7 @@ def __call__(self, x: NamedArray, *, key=None, force_dense: bool = False) -> Nam extras = { "expert_loads": expert_loads, } - if self.config.lbl_coef is not None and getattr(self.config, "alf_lb_loss_scale", 0.0) <= 0: + if self.config.lbl_coef is not None and alf_scale <= 0 and equilibrium_scale <= 0: # Shapes: # - expert_loads: [Experts] where Experts.size == n_routed_experts # - router_probs: [Token, Experts] where Token is the flattened token axis (T = B*S) @@ -648,7 +682,6 @@ def __call__(self, x: NamedArray, *, key=None, force_dense: bool = False) -> Nam axis=Token, ) - alf_scale = float(getattr(self.config, "alf_lb_loss_scale", 0.0)) if alf_scale > 0: target = TokenRepeat.size / Experts.size delta_arr = group_sizes.array - target @@ -657,6 +690,15 @@ def __call__(self, x: NamedArray, *, key=None, force_dense: bool = False) -> Nam delta_arr = jax.lax.stop_gradient(delta_arr) delta = hax.named(delta_arr, Experts) extras["alf_lb_bias_loss"] = alf_scale * hax.sum(self.router_bias * delta, axis=Experts) + if equilibrium_scale > 0 and equilibrium_delta is not None: + extras["equilibrium_lb_bias_loss"] = equilibrium_scale * hax.sum( + self.router_bias * equilibrium_delta, + axis=Experts, + ) + # Helpful diagnostics for smoke tests. 
+ extras["equilibrium_lb_weighted_load"] = weighted_load + extras["equilibrium_lb_target_load"] = target_load + extras["equilibrium_lb_quantile_prob"] = quantile_prob return hax.unflatten_axis(out, axis=Token, new_axes=squash_axes), extras # [Batch, Pos, Embed] @@ -755,6 +797,8 @@ def __call__( extras["router_z_loss"] = hax.sum(extras["router_z_loss"], axis=self.config.Layers) if "alf_lb_bias_loss" in extras: extras["alf_lb_bias_loss"] = hax.sum(extras["alf_lb_bias_loss"], axis=self.config.Layers) + if "equilibrium_lb_bias_loss" in extras: + extras["equilibrium_lb_bias_loss"] = hax.sum(extras["equilibrium_lb_bias_loss"], axis=self.config.Layers) stats: dict[str, Array] = {} if "load_balancing_loss" in extras: stats["train/load_balancing_loss"] = jax.lax.stop_gradient(extras["load_balancing_loss"].array) @@ -762,6 +806,8 @@ def __call__( stats["train/router_z_loss"] = jax.lax.stop_gradient(extras["router_z_loss"].array) if "alf_lb_bias_loss" in extras: stats["train/alf_lb_bias_loss"] = jax.lax.stop_gradient(extras["alf_lb_bias_loss"].array) + if "equilibrium_lb_bias_loss" in extras: + stats["train/equilibrium_lb_bias_loss"] = jax.lax.stop_gradient(extras["equilibrium_lb_bias_loss"].array) if self.config.log_moe_metrics: expert_loads = extras["expert_loads"] @@ -776,6 +822,15 @@ def __call__( for j in range(self.config.n_routed_experts): stats[f"moe/layer{i}/expert{j}_load"] = jax.lax.stop_gradient(expert_loads.array[i, j]) stats["moe/load_violation_max"] = jax.lax.stop_gradient(global_load_violation_max.array) + if "equilibrium_lb_weighted_load" in extras and "equilibrium_lb_target_load" in extras: + weighted = extras["equilibrium_lb_weighted_load"].array + target = extras["equilibrium_lb_target_load"].array + rel_violation = jnp.abs((weighted - target) / jnp.maximum(target, 1e-6)) + stats["moe/equilibrium_rel_load_violation_max"] = jax.lax.stop_gradient(jnp.max(rel_violation)) + if "equilibrium_lb_quantile_prob" in extras: + 
stats["moe/equilibrium_quantile_prob_mean"] = jax.lax.stop_gradient( + jnp.mean(extras["equilibrium_lb_quantile_prob"].array) + ) dense_first_n_layers = int(getattr(self.config, "dense_first_n_layers", 0) or 0) if 0 < dense_first_n_layers < self.config.num_layers: sparse_layers = jnp.arange(self.config.num_layers) >= dense_first_n_layers @@ -870,6 +925,8 @@ def activations( aux_loss += extras["router_z_loss"] if "alf_lb_bias_loss" in extras: aux_loss += extras["alf_lb_bias_loss"] + if "equilibrium_lb_bias_loss" in extras: + aux_loss += extras["equilibrium_lb_bias_loss"] return x, aux_loss def get_lm_head(self) -> hax.NamedArray: diff --git a/lib/levanter/src/levanter/models/mixtral.py b/lib/levanter/src/levanter/models/mixtral.py index 73700b3625..cb71d9d150 100644 --- a/lib/levanter/src/levanter/models/mixtral.py +++ b/lib/levanter/src/levanter/models/mixtral.py @@ -28,6 +28,7 @@ from levanter.layers.rotary import DefaultRotaryEmbeddingsConfig, RotaryEmbeddingsConfig from levanter.models.llama import LlamaEmbedding, LlamaMlp from levanter.models.lm_model import LmConfig, LmHeadModel +from levanter.models.moe_load_balance import equilibrium_bias_delta_from_topk from levanter.models.mistral import MistralConfig from levanter.utils.activation import ActivationFunctionEnum from levanter.utils.flop_utils import lm_flops_per_token @@ -83,6 +84,10 @@ class MixtralConfig(MistralConfig): lbl_coef: Optional[float] = 0.01 rzl_coef: Optional[float] = 0.001 + router_fp32: bool = False + equilibrium_lb_loss_scale: float = 0.0 + equilibrium_lb_center_bias: bool = True + equilibrium_lb_unit_capacity: float = 1.0 # Attention-related config upcast_attn: bool = False @@ -318,6 +323,7 @@ class MixtralSparseMoeBlock(eqx.Module): config: MistralConfig = eqx.field(static=True) gate: hnn.Linear # projection from Embed to Experts experts: MixtralMoEMlp + router_bias: NamedArray @staticmethod def init(config: MistralConfig, *, key) -> "MixtralSparseMoeBlock": @@ -332,8 +338,9 @@ def 
init(config: MistralConfig, *, key) -> "MixtralSparseMoeBlock": key=k_experts, use_bias=config.use_bias, ) + router_bias = hax.zeros(config.Experts) - return MixtralSparseMoeBlock(config, gate, experts) + return MixtralSparseMoeBlock(config, gate, experts, router_bias) def _route(self, router_probs: NamedArray, Token: Axis, TopExperts: Axis): @partial( @@ -440,9 +447,25 @@ def __call__(self, x: NamedArray, *, key=None) -> tuple[NamedArray, Dict[str, Na x_flat = hax.flatten_axes(x, old_axes=squash_axes, new_axis="token") # [Batch, Pos, Embed] -> [Token, Embed] Token = x_flat.resolve_axis("token") - router_logits = self.gate(x_flat, key=k_gate) + router_fp32 = bool(getattr(self.config, "router_fp32", False)) + x_for_gate = x_flat.astype(jnp.float32) if router_fp32 else x_flat + router_logits = self.gate(x_for_gate, key=k_gate) + if router_fp32 and router_logits.array.dtype != jnp.float32: + router_logits = router_logits.astype(jnp.float32) + + equilibrium_scale = float(getattr(self.config, "equilibrium_lb_loss_scale", 0.0)) + selection_logits = router_logits + if equilibrium_scale > 0: + bias = self.router_bias + if bool(getattr(self.config, "equilibrium_lb_center_bias", True)): + bias = bias - hax.mean(bias, axis=Experts) + if router_fp32 and bias.array.dtype != jnp.float32: + bias = bias.astype(jnp.float32) + selection_logits = selection_logits + bias + router_probs = hnn.softmax(router_logits, axis=Experts) - topk_weights, topk_idx = self._route(router_probs, Token, TopExperts) + selection_probs = hnn.softmax(selection_logits, axis=Experts) if equilibrium_scale > 0 else router_probs + topk_weights, topk_idx = self._route(selection_probs, Token, TopExperts) topk_idx_flat = hax.flatten_axes(topk_idx, old_axes=[Token, TopExperts], new_axis="token_repeat") TokenRepeat = topk_idx_flat.resolve_axis("token_repeat") @@ -462,7 +485,7 @@ def __call__(self, x: NamedArray, *, key=None) -> tuple[NamedArray, Dict[str, Na extras = { "expert_loads": expert_loads, } - if 
self.config.lbl_coef is not None: + if self.config.lbl_coef is not None and equilibrium_scale <= 0: f = expert_loads * self.config.n_routed_experts / self.config.num_experts_per_tok p = hax.mean(router_probs, axis=Token) extras["load_balancing_loss"] = self.config.lbl_coef * hax.sum(f * p, axis=Experts) @@ -470,6 +493,21 @@ def __call__(self, x: NamedArray, *, key=None) -> tuple[NamedArray, Dict[str, Na extras["router_z_loss"] = self.config.rzl_coef * hax.mean( hnn.logsumexp(router_logits, axis=Experts) ** 2, axis=Token ) + if equilibrium_scale > 0: + delta_arr, weighted_load_arr, target_load_arr, quantile_prob_arr = equilibrium_bias_delta_from_topk( + selection_logits.array, + topk_idx.array, + topk_weights.array, + unit_capacity=float(getattr(self.config, "equilibrium_lb_unit_capacity", 1.0)), + ) + delta = hax.named(jax.lax.stop_gradient(delta_arr), Experts) + extras["equilibrium_lb_bias_loss"] = equilibrium_scale * hax.sum(self.router_bias * delta, axis=Experts) + extras["equilibrium_lb_weighted_load"] = hax.named(weighted_load_arr, Experts) + extras["equilibrium_lb_target_load"] = hax.named( + jnp.full((Experts.size,), target_load_arr, dtype=weighted_load_arr.dtype), + Experts, + ) + extras["equilibrium_lb_quantile_prob"] = hax.named(quantile_prob_arr, Experts) return hax.unflatten_axis(out, axis=Token, new_axes=squash_axes), extras # [Batch, Pos, Embed] @@ -561,12 +599,21 @@ def __call__( for j in range(self.config.n_routed_experts): stats[f"moe/layer{i}/expert{j}_load"] = expert_loads.array[i, j] - if self.config.lbl_coef is not None: + if "load_balancing_loss" in extras: extras["load_balancing_loss"] = hax.sum(extras["load_balancing_loss"], axis=self.config.Layers) stats["train/load_balancing_loss"] = extras["load_balancing_loss"].array - if self.config.rzl_coef is not None: + if "router_z_loss" in extras: extras["router_z_loss"] = hax.sum(extras["router_z_loss"], axis=self.config.Layers) stats["train/router_z_loss"] = extras["router_z_loss"].array + if 
"equilibrium_lb_bias_loss" in extras: + extras["equilibrium_lb_bias_loss"] = hax.sum(extras["equilibrium_lb_bias_loss"], axis=self.config.Layers) + stats["train/equilibrium_lb_bias_loss"] = extras["equilibrium_lb_bias_loss"].array + if "equilibrium_lb_weighted_load" in extras and "equilibrium_lb_target_load" in extras: + rel_violation = jnp.abs( + (extras["equilibrium_lb_weighted_load"].array - extras["equilibrium_lb_target_load"].array) + / jnp.maximum(extras["equilibrium_lb_target_load"].array, 1e-6) + ) + stats["moe/equilibrium_rel_load_violation_max"] = jnp.max(rel_violation) levanter.tracker.jit_log(stats) @@ -650,10 +697,12 @@ def activations( x, extras = self.transformer(x, attn_mask=attn_mask, key=key, pos_ids=pos_ids) aux_loss: NamedArray | float = 0 - if self.config.lbl_coef is not None: + if "load_balancing_loss" in extras: aux_loss += extras["load_balancing_loss"] - if self.config.rzl_coef is not None: + if "router_z_loss" in extras: aux_loss += extras["router_z_loss"] + if "equilibrium_lb_bias_loss" in extras: + aux_loss += extras["equilibrium_lb_bias_loss"] return x, aux_loss def get_lm_head(self) -> hax.NamedArray: diff --git a/lib/levanter/src/levanter/models/moe_load_balance.py b/lib/levanter/src/levanter/models/moe_load_balance.py new file mode 100644 index 0000000000..bc73f1cd5d --- /dev/null +++ b/lib/levanter/src/levanter/models/moe_load_balance.py @@ -0,0 +1,74 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import jax +import jax.numpy as jnp +from jax import Array + + +def equilibrium_bias_delta_from_topk( + adjusted_router_logits: Array, + topk_idx: Array, + topk_weights: Array, + *, + unit_capacity: float = 1.0, +) -> tuple[Array, Array, Array, Array]: + """Compute one quantile-balancing bias update from top-k routing. 
+ + This implements the single-step column update from Quantile Balancing: + - build per-token threshold from the K-th largest adjusted logit, + - compute per-expert overload residuals, + - compute target-aware quantiles from current expert loads, + - return per-expert bias deltas. + + Args: + adjusted_router_logits: `[tokens, experts]` logits after adding selection bias. + topk_idx: `[tokens, k]` selected expert indices. + topk_weights: `[tokens, k]` routing weights for selected experts. + unit_capacity: Upper-bound offset in the clipped allocation residual term. + + Returns: + A tuple `(delta, expert_weighted_load, target_load, quantile_prob)` where: + - `delta`: `[experts]` bias update to subtract (`b <- b - delta`), + - `expert_weighted_load`: `[experts]` current weighted load, + - `target_load`: scalar target load per expert (`tokens / experts`), + - `quantile_prob`: `[experts]` quantile level used per expert. + """ + if adjusted_router_logits.ndim != 2: + raise ValueError( + f"adjusted_router_logits must be rank-2 [tokens, experts], got {adjusted_router_logits.shape}" + ) + if topk_idx.ndim != 2 or topk_weights.ndim != 2: + raise ValueError("topk_idx and topk_weights must be rank-2 [tokens, k]") + if topk_idx.shape != topk_weights.shape: + raise ValueError(f"topk_idx shape {topk_idx.shape} must match topk_weights shape {topk_weights.shape}") + + tokens, experts = adjusted_router_logits.shape + if topk_idx.shape[0] != tokens: + raise ValueError("topk tensors must have same token dimension as adjusted_router_logits") + + k = topk_idx.shape[1] + logits = adjusted_router_logits.astype(jnp.float32) + idx = topk_idx.astype(jnp.int32) + weights = topk_weights.astype(jnp.float32) + + kth = jax.lax.top_k(logits, k)[0][:, -1] + residual = jnp.maximum(logits - kth[:, None] - unit_capacity, 0.0) + + flat_idx = idx.reshape(-1) + flat_weights = weights.reshape(-1) + expert_weighted_load = jnp.bincount(flat_idx, weights=flat_weights, length=experts) + + tokens_f = 
jnp.asarray(tokens, dtype=jnp.float32) + target_load = tokens_f / float(experts) + quantile_prob = jnp.clip(1.0 - (expert_weighted_load - target_load) / tokens_f, 0.0, 1.0) + + sorted_residual = jnp.sort(residual, axis=0) + quantile_pos = jnp.floor(quantile_prob * float(tokens - 1)).astype(jnp.int32) + quantile_pos = jnp.clip(quantile_pos, 0, tokens - 1) + expert_ids = jnp.arange(experts, dtype=jnp.int32) + delta = sorted_residual[quantile_pos, expert_ids] + + return delta.astype(adjusted_router_logits.dtype), expert_weighted_load, target_load, quantile_prob diff --git a/lib/levanter/tests/test_moe_load_balance.py b/lib/levanter/tests/test_moe_load_balance.py new file mode 100644 index 0000000000..e16cd1ded9 --- /dev/null +++ b/lib/levanter/tests/test_moe_load_balance.py @@ -0,0 +1,50 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +import jax.numpy as jnp + +from levanter.models.moe_load_balance import equilibrium_bias_delta_from_topk + + +def test_equilibrium_bias_delta_shapes_and_ranges(): + logits = jnp.array( + [ + [5.0, 3.0, 0.0], + [4.8, 3.2, 0.1], + [5.2, 3.1, -0.1], + ], + dtype=jnp.float32, + ) + topk_idx = jnp.array([[0, 1], [0, 1], [0, 1]], dtype=jnp.int32) + topk_weights = jnp.array([[0.9, 0.1], [0.8, 0.2], [0.85, 0.15]], dtype=jnp.float32) + + delta, weighted_load, target_load, quantile_prob = equilibrium_bias_delta_from_topk(logits, topk_idx, topk_weights) + + assert delta.shape == (3,) + assert weighted_load.shape == (3,) + assert quantile_prob.shape == (3,) + assert target_load.shape == () + assert float(target_load) == 1.0 + assert jnp.all(quantile_prob >= 0.0) + assert jnp.all(quantile_prob <= 1.0) + + +def test_equilibrium_bias_delta_penalizes_overloaded_expert(): + logits = jnp.array( + [ + [5.0, 3.0, 0.0], + [5.1, 3.1, 0.0], + [5.2, 3.0, 0.0], + [5.3, 3.1, 0.0], + ], + dtype=jnp.float32, + ) + topk_idx = jnp.array([[0, 1], [0, 1], [0, 1], [0, 1]], dtype=jnp.int32) + topk_weights = jnp.array([[0.95, 0.05], 
[0.9, 0.1], [0.95, 0.05], [0.9, 0.1]], dtype=jnp.float32) + + delta, weighted_load, _, _ = equilibrium_bias_delta_from_topk(logits, topk_idx, topk_weights) + + # Expert 0 is overloaded and should receive the strongest downward update. + assert weighted_load[0] > weighted_load[1] + assert delta[0] > 0.0 + assert delta[0] >= delta[1] From 21de48542ba7f4990e2dcd9cc67f521b9560651e Mon Sep 17 00:00:00 2001 From: Pranshu Chaturvedi Date: Thu, 26 Feb 2026 22:33:17 -0800 Subject: [PATCH 5/7] speedrun: add olmoe_s preset for v4 smoke test --- .../speedrun/olmoe_1b7b_nemotron_40b.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/experiments/speedrun/olmoe_1b7b_nemotron_40b.py b/experiments/speedrun/olmoe_1b7b_nemotron_40b.py index 1783a85377..cb4c980058 100644 --- a/experiments/speedrun/olmoe_1b7b_nemotron_40b.py +++ b/experiments/speedrun/olmoe_1b7b_nemotron_40b.py @@ -123,8 +123,9 @@ def run_levanter_checkpoint_eval_harness(config: LevanterEvalHarnessStepConfig) DEFAULT_PROFILER_NUM_STEPS = 20 MODEL_OLMOE_1B7B = "olmoe_1b7b" +MODEL_OLMOE_S = "olmoe_s" MODEL_MIXTRAL_8X7B = "mixtral_8x7b" -MODEL_OPTIONS = (MODEL_OLMOE_1B7B, MODEL_MIXTRAL_8X7B) +MODEL_OPTIONS = (MODEL_OLMOE_1B7B, MODEL_OLMOE_S, MODEL_MIXTRAL_8X7B) DEFAULT_MODEL = MODEL_OLMOE_1B7B OLMOE_1B7B_REFERENCE_CHECKPOINT = "allenai/OLMoE-1B-7B-0125" @@ -150,6 +151,28 @@ def _build_olmoe_1b7b_config(seq_len: int) -> MixtralConfig: ) +def _build_olmoe_s_config(seq_len: int) -> MixtralConfig: + """Small OLMoE-style config (~125M active params) for smoke testing (e.g. v4-8).""" + return MixtralConfig( + seq_len=seq_len, + hidden_dim=384, + intermediate_dim=768, + num_layers=8, + num_heads=6, + num_kv_heads=3, + n_routed_experts=8, + num_experts_per_tok=1, + layer_norm_epsilon=1e-5, + gradient_checkpointing=True, + scan_layers=True, + use_gmm=True, + # v4 TPUs have a small scoped-vmem budget; keep the fused CE kernel's tiling conservative by default. 
+ cross_entropy_block_size=512, + reference_checkpoint=OLMOE_1B7B_REFERENCE_CHECKPOINT, + tokenizer=OLMOE_1B7B_REFERENCE_CHECKPOINT, + ) + + def _build_mixtral_8x7b_config(seq_len: int) -> MixtralConfig: """Mixtral 8x7B config (8 experts, 2 routed/token), aligned with MaxText's model geometry.""" return MixtralConfig( @@ -175,6 +198,8 @@ def _build_mixtral_8x7b_config(seq_len: int) -> MixtralConfig: def build_model_config(*, model: str, seq_len: int) -> MixtralConfig: if model == MODEL_OLMOE_1B7B: return _build_olmoe_1b7b_config(seq_len) + if model == MODEL_OLMOE_S: + return _build_olmoe_s_config(seq_len) if model == MODEL_MIXTRAL_8X7B: return _build_mixtral_8x7b_config(seq_len) raise ValueError(f"Unknown model preset {model!r}. Options: {MODEL_OPTIONS}.") From f5982ed6f862b3f86fdfaf391b98f902479a195f Mon Sep 17 00:00:00 2001 From: Pranshu Chaturvedi Date: Thu, 26 Feb 2026 22:41:31 -0800 Subject: [PATCH 6/7] mixtral: omit router_bias from HF state dict --- experiments/speedrun/custom_mixtral.py | 20 ++++++++++++++++++++ lib/levanter/src/levanter/models/mixtral.py | 20 ++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/experiments/speedrun/custom_mixtral.py b/experiments/speedrun/custom_mixtral.py index 2f3a688cb8..1e8b5563ad 100644 --- a/experiments/speedrun/custom_mixtral.py +++ b/experiments/speedrun/custom_mixtral.py @@ -18,6 +18,7 @@ import haliax as hax import haliax.nn as hnn +import haliax.state_dict as hstate from haliax import Axis, NamedArray from haliax._src.scan import ScanCheckpointSpec from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split @@ -434,6 +435,25 @@ def init(config: MistralConfig, *, key) -> "MixtralSparseMoeBlock": return MixtralSparseMoeBlock(config, gate, experts, router_bias) + def to_state_dict(self, prefix: str | None = None) -> StateDict: + """Torch-compatible serialization. + + `router_bias` is a Levanter-only parameter, so it is intentionally omitted from the exported state dict. 
+ """ + state_dict: StateDict = {} + state_dict.update(hstate.to_state_dict(self.gate, prefix=hstate.with_prefix(prefix, "gate"))) + state_dict.update(hstate.to_state_dict(self.experts, prefix=hstate.with_prefix(prefix, "experts"))) + return state_dict + + def from_state_dict(self, state_dict: StateDict, prefix: str | None = None) -> "MixtralSparseMoeBlock": + """Torch-compatible deserialization. + + `router_bias` is not present in HF checkpoints; keep the initialized value. + """ + gate = hstate.from_state_dict(self.gate, state_dict, prefix=hstate.with_prefix(prefix, "gate")) + experts = hstate.from_state_dict(self.experts, state_dict, prefix=hstate.with_prefix(prefix, "experts")) + return eqx.tree_at(lambda m: (m.gate, m.experts), self, (gate, experts)) + def _route( self, selection_logits: NamedArray, diff --git a/lib/levanter/src/levanter/models/mixtral.py b/lib/levanter/src/levanter/models/mixtral.py index cb71d9d150..f343363404 100644 --- a/lib/levanter/src/levanter/models/mixtral.py +++ b/lib/levanter/src/levanter/models/mixtral.py @@ -16,6 +16,7 @@ import haliax as hax import haliax.nn as hnn +import haliax.state_dict as hstate from haliax import Axis, AxisSpec, NamedArray from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split from haliax.nn.normalization import LayerNormBase @@ -342,6 +343,25 @@ def init(config: MistralConfig, *, key) -> "MixtralSparseMoeBlock": return MixtralSparseMoeBlock(config, gate, experts, router_bias) + def to_state_dict(self, prefix: str | None = None) -> StateDict: + """Torch-compatible serialization. + + `router_bias` is a Levanter-only parameter, so it is intentionally omitted from the exported state dict. 
+ """ + state_dict: StateDict = {} + state_dict.update(hstate.to_state_dict(self.gate, prefix=hstate.with_prefix(prefix, "gate"))) + state_dict.update(hstate.to_state_dict(self.experts, prefix=hstate.with_prefix(prefix, "experts"))) + return state_dict + + def from_state_dict(self, state_dict: StateDict, prefix: str | None = None) -> "MixtralSparseMoeBlock": + """Torch-compatible deserialization. + + `router_bias` is not present in HF checkpoints; keep the initialized value. + """ + gate = hstate.from_state_dict(self.gate, state_dict, prefix=hstate.with_prefix(prefix, "gate")) + experts = hstate.from_state_dict(self.experts, state_dict, prefix=hstate.with_prefix(prefix, "experts")) + return eqx.tree_at(lambda m: (m.gate, m.experts), self, (gate, experts)) + def _route(self, router_probs: NamedArray, Token: Axis, TopExperts: Axis): @partial( shard_map, From da2f3e36341c24f0af5b4ced9ac8acc1ca6e6916 Mon Sep 17 00:00:00 2001 From: Pranshu Chaturvedi Date: Fri, 27 Feb 2026 00:00:08 -0800 Subject: [PATCH 7/7] speedrun: honor --seq-len in training --- experiments/speedrun/olmoe_1b7b_nemotron_40b.py | 1 + 1 file changed, 1 insertion(+) diff --git a/experiments/speedrun/olmoe_1b7b_nemotron_40b.py b/experiments/speedrun/olmoe_1b7b_nemotron_40b.py index cb4c980058..6c6c7069c3 100644 --- a/experiments/speedrun/olmoe_1b7b_nemotron_40b.py +++ b/experiments/speedrun/olmoe_1b7b_nemotron_40b.py @@ -385,6 +385,7 @@ def make_speedrun_config( train_config=SimpleTrainConfig( resources=ResourceConfig.with_tpu(tpu_type=tpu_type), train_batch_size=global_batch_size, + train_seq_len=seq_len, num_train_steps=num_train_steps, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY,