7 changes: 7 additions & 0 deletions AGENTS.md
@@ -117,3 +117,10 @@ DO NOT:
## Environment

- Prefer to use `uv` when possible. If you can't (for instance, due to sandbox restrictions) you can use `.venv/bin/python`

## Ray Run Notes

- In shared clusters, `OwnerDiedError`, raylet "missed too many heartbeats" messages, and autoscaler resize logs can be noise from other workloads. Prefer judging health by your *job's* step logs/status and whether `run_on_pod_ray` is retrying preemptions.
- If a job appears stuck before logging metrics, it is often waiting on TPU slice scheduling (e.g. `SliceActor`/`TPUHostActor` pending creation). Use `uv run python scripts/ray/cluster.py --config infra/marin-us-central1.yaml list-jobs` and `... job-logs <submission_id>` to confirm.
- Grugformer MoE smoke runs: prefer `--smoke --dataset nemotron_cc --tpu-type v5p-16 --seq-len 1024 --global-batch-size 32 --num-train-steps 20 --dataset-tokenizer meta-llama/Meta-Llama-3.1-8B --legacy-axis-resources`. The launcher defaults to fused (Pallas) CE; `xla` CE will materialize full logits and can OOM at realistic token counts.
- Grugformer MoE experts: default to the Megablox GMM pathway (`--use-gmm`). The ragged-dot pathway can trigger huge HBM temporaries during compile (e.g. expert-linear shapes like `bf16[64,262144,1024]`) and crash TPU workers; use `--no-use-gmm` only for debugging/ablations.
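The two `cluster.py` commands referenced in the notes above can be wrapped into a quick health check. A minimal sketch, assuming the config path from the notes; `cluster_cmd` and `check_job` are hypothetical helper names, not part of the repo:

```python
import subprocess

CONFIG = "infra/marin-us-central1.yaml"  # config used in the notes above

def cluster_cmd(*args: str) -> list[str]:
    """Build the cluster.py invocation from the notes (hypothetical helper)."""
    return ["uv", "run", "python", "scripts/ray/cluster.py", "--config", CONFIG, *args]

def check_job(submission_id: str) -> None:
    # List jobs to confirm the submission is pending/running rather than lost.
    subprocess.run(cluster_cmd("list-jobs"), check=True)
    # Tail logs to see whether SliceActor/TPUHostActor creation is still pending.
    subprocess.run(cluster_cmd("job-logs", submission_id), check=True)
```

Judging health this way (your job's own status and logs) avoids being misled by `OwnerDiedError` or heartbeat noise from other tenants on the shared cluster.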
7 changes: 7 additions & 0 deletions docs/reports/grug-archive.md
@@ -59,3 +59,10 @@ Copy/paste this block for new experiments:
- Status: active
- Purpose: Head-to-head comparison between Hackable Transformer and Grugformer (no sinks).

### grugformer-moe
- Path: `experiments/speedrun/grugformer_moe/`
- Introduced: TBD
- Last known-good: TBD
- Status: active
- Purpose: Grugformer MoE entrypoints (Mixtral/OLMoE-style router + expert MLP).
- Notes: Intended for throughput/MFU work; supports GMM (Megablox) and JAX profiling flags.
13 changes: 10 additions & 3 deletions experiments/defaults.py
@@ -285,7 +285,10 @@ def default_train(
eval_harness_tasks: Sequence[EvalTaskConfig] = CORE_TASKS,
wandb_name: str | None = None,
wandb_group: str | None = None,
wandb_project: str | None = None,
override_output_path: str | None = None,
checkpointer_save_interval: timedelta | None = None,
checkpointer_keep: list[dict] | None = None,
) -> ExecutorStep:
"""
Train a language model using the default configuration.
@@ -300,6 +303,10 @@
eval_harness_tasks: List of evaluation harness tasks. Defaults to the CORE set of tasks. Use () or [] to disable
wandb_name: Optional W&B display name for this run. Defaults to W&B's auto-generated name.
wandb_group: Optional W&B group to organize related runs (e.g., a sweep). If unset, defaults to $WANDB_GROUP.
wandb_project: Optional W&B project name. Defaults to "marin" when unset.
checkpointer_save_interval: Optional override for the checkpointer time-based save interval.
checkpointer_keep: Optional override for the checkpointer step-based keep policies. Passing an empty list keeps
only time-based (temporary) checkpoints.
"""

pretraining_data = _prepare_data_config(tokenized, use_default_validation)
@@ -346,7 +353,7 @@
data=pretraining_data,
trainer=TrainerConfig(
tracker=WandbConfig(
project="marin",
project=wandb_project or "marin",
name=wandb_name,
tags=[*tags],
group=wandb_group,
@@ -358,8 +365,8 @@
num_train_steps=train_config.num_train_steps,
steps_per_eval=train_config.steps_per_eval if train_config.steps_per_eval is not None else 1000,
checkpointer=CheckpointerConfig(
save_interval=timedelta(minutes=10),
keep=[dict(every=steps_per_export)],
save_interval=checkpointer_save_interval or timedelta(minutes=10),
keep=checkpointer_keep if checkpointer_keep is not None else [dict(every=steps_per_export)],
),
model_averaging=model_averaging,
mesh=MeshConfig(