
Commit 8d752a7

dlwh and moojink authored
Grug-native template cleanup and legacy path retirement (#3054)
## Scope

This PR is the full grug-native/template transition slice on this branch (not just a small cleanup).

## What changed

- Promote `experiments/grug/base/` as the canonical grug edit surface:
  - add/expand `experiments/grug/base/model.py`, `experiments/grug/base/train.py`, `experiments/grug/base/launch.py`
  - add `experiments/grug/README.md` and package init files
- Retire legacy Grugformer-era library paths and the old comparison script:
  - remove `lib/levanter/src/levanter/grug/main.py`
  - remove `lib/levanter/src/levanter/grug/data.py`
  - remove `lib/levanter/src/levanter/models/grug_wrapper.py`
  - remove `experiments/speedrun/grugformer_vs_hackable_125m/grugformer_vs_hackable_125m.py`
- Remove the obsolete grugformer-focused test suite and replace it with template-focused coverage:
  - remove `lib/levanter/tests/grug/test_grugformer*.py`
  - add/update `tests/test_grug_base_template.py`
- Add the callback state-adapter path needed by grug template training:
  - add `lib/levanter/src/levanter/callbacks/state_adapter.py`
  - update callbacks/tensorstore callback wiring and tests
- Include supporting runtime/parity adjustments used by the template flow:
  - `lib/levanter/src/levanter/eval.py`
  - `lib/levanter/src/levanter/utils/jax_utils.py` (+ tests)
  - minor compatibility updates in `lib/levanter/src/levanter/compat/hf_checkpoints.py`
  - minor fused CE API touch in `lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/api.py`
- Update project/docs guidance to the template-first grug workflow:
  - `.agents/projects/grugformer.md`
  - `docs/recipes/change_grug.md`
  - `docs/reports/grug-archive.md`

## Validation run on this branch

- `uv run pytest tests/test_eval.py` (from `lib/levanter`)
- `uv run pytest tests/test_grug_base_template.py`
- `uv run python infra/pre-commit.py --all-files`

## Notes

- This PR intentionally contains the accumulated grug-native transition work on `codex/grug-native-template-cleanup`.
- Local scratch/monitoring files were not included.
---------

Co-authored-by: Moo Jin Kim <moojink@stanford.edu>
1 parent cf99ce5 commit 8d752a7

29 files changed: +1514 −1711 lines

.agents/projects/grugformer.md

Lines changed: 32 additions & 211 deletions
Large diffs are not rendered by default.

docs/recipes/change_grug.md

Lines changed: 51 additions & 82 deletions
````diff
@@ -1,112 +1,81 @@
-# Recipe: Changing Grug (Experiment → Canonical)
+# Recipe: Changing Grug (Template-First)
 
-Grug is meant to be “grug-simple” and easy to hack, but we still want a single, trustworthy “best guess” implementation in `levanter.grug`.
+Grug is intentionally template-first: the canonical edit surface lives in `experiments/grug/base/`, not in a shared `levanter.grug` trainer stack.
 
 This recipe describes the workflow for:
 
-1) trying changes safely in a speedrun experiment, and
-2) upstreaming successful ideas into the canonical core (and cleaning up old experiments).
+1. trying a change in an experiment copy, and
+2. upstreaming it into the base template when it proves out.
 
-## Source Of Truth vs Experiments
+## Source Of Truth
 
-- **Source of truth:** `lib/levanter/src/levanter/grug/`
-  - This is the “best guess” model. It should stay small, readable, and testable.
-- **Evolving experiment:** `experiments/speedrun/nano_arch_ablations/00_baseline/main.py`
-  - This is the *living* entrypoint that is expected to evolve as we learn.
-- **One-off experiments:** under `experiments/speedrun/…`
-  - These are snapshots / specialized edit surfaces (e.g. attention sinks).
+- **Canonical template:** `experiments/grug/base/`
+  - `model.py`
+  - `train.py`
+  - `launch.py`
+- **Variants:** `experiments/grug/<variant>/`
+  - copy from `base` and modify locally (for example MoE).
+- **One-off speedruns:** `experiments/speedrun/...`
+  - useful for exploration, not canonical.
 
-We try not to let one-off scripts become the canonical implementation.
+## Workflow
 
-## When You Want To Try Something
+### 1) Pick one change bucket
 
-### 1) Decide what you’re changing
+Keep each pass scoped to one bucket:
 
-Most changes fall into one bucket:
+- attention/masking
+- block wiring/norm ordering
+- MLP/activation
+- loss kernel behavior
+- optimizer/training loop behavior
 
-- **Attention** (masking semantics, kernels, sinks/aux, layout/sharding)
-- **Block** (residual wiring, normalization order, pre/post-norm)
-- **MLP** (activation, GLU variants, gating, dimension choices)
-- **Loss** (large-vocab CE, z-loss, label smoothing, logit soft-cap)
-- **Optimizer** (Adam, Muon, etc.)
+### 2) Experiment in a copy
 
-Try to change **one bucket at a time**. Optimizer isn't really (currently) addressed by Grug, but we'll get there.
+- Copy `experiments/grug/base` to a new variant directory.
+- Keep edits local and explicit (copy/paste over abstraction).
+- Avoid introducing reusable framework surface unless there's clear repeated use.
 
-### 2) Create an experiment entrypoint
+### 3) Record the experiment
 
-Start from:
+Update `docs/reports/grug-archive.md` with:
 
-- `experiments/speedrun/nano_arch_ablations/00_baseline/main.py`
+- path
+- commit SHA (when known)
+- purpose
+- status (`active`, `superseded`, `deleted`)
 
-Recommended workflow:
+### 4) Upstream to base if it wins
 
-1. Copy the file to a new experiment (or branch the baseline if the change is small):
-   - Example: `experiments/speedrun/<idea>/main.py`
-2. Keep the edit surface explicit:
-   - If you’re changing attention, keep the change in one local `attention()` or `attn_fn` block.
-   - If you’re changing the MLP, keep it local to an `mlp()` helper.
-3. Avoid introducing new abstractions (this is a speedrun file; copy/paste is fine).
+Port the successful change back into:
 
-### 3) Register the experiment in the archive
+- `experiments/grug/base/model.py`
+- `experiments/grug/base/train.py`
+- `experiments/grug/base/launch.py`
 
-Add an entry to:
+Keep it grug-style:
 
-- `docs/reports/grug-archive.md`
+- plain JAX arrays and explicit sharding
+- Equinox modules with `init` + `__call__`
+- minimal config knobs
+- keep legibility first; if a block gets hard to read, introduce a small local helper instead of adding framework indirection
 
-Record:
-- the experiment path,
-- the commit SHA (once known),
-- what you changed and why,
-- the intended “status” (`active`, `superseded`, `deleted`).
+### 5) Delete stale paths
 
-## When You Want To Adopt Something As Canonical
+After upstreaming:
 
-### 1) Port to `levanter.grug`
+- delete superseded experiment code,
+- keep only the archive trail in `docs/reports/grug-archive.md`.
 
-Move the change into one of the core files:
+### 6) Validate
 
-- `lib/levanter/src/levanter/grug/attention.py`
-- `lib/levanter/src/levanter/grug/model.py`
-- `lib/levanter/src/levanter/grug/loss.py`
+Run the relevant checks:
 
-Keep the “grug” style:
-- Equinox modules for model components (`Transformer`, `Block`, `MLP`, `RMSNorm`, attention),
-- explicit module pattern per component: `@staticmethod init(...)` + `__call__(...)`,
-- model API on module methods (`Transformer.init`, `Transformer.__call__`, `Transformer.logits`, `Transformer.next_token_loss`),
-- explicit sharding when needed (and loud failures otherwise).
-
-### 2) Update/extend tests
-
-Add or adjust tests to lock the intended surface:
-
-- `lib/levanter/tests/grug/test_grugformer_core.py`
-- `lib/levanter/tests/grug/test_grugformer_model_loss.py`
-- `lib/levanter/tests/grug/test_grugformer_compilation.py`
-
-The goal is:
-- shapes don’t regress,
-- `jit` still works,
-- sharding doesn’t explode,
-- mask semantics remain correct.
-
-### 3) Clean up old experiments
-
-After merging a canonical improvement:
-
-- If an experiment is now redundant and not referenced, **delete it** and mark it `deleted` in `docs/reports/grug-archive.md`.
-- If an experiment represents a meaningful historical run, keep it but mark it `superseded`, and point to the canonical change (or the new experiment).
-Do this only if it's not going to be a maintenance burden.
-
-Prefer “archive entry + deletion” over keeping lots of old code in-tree.
-
-### 4) Run repo checks
-
-Before sending the PR:
-
-```sh
+```bash
 uv run python infra/pre-commit.py --all-files
+uv run pytest tests/test_grug_base_template.py
 ```
 
-## Notes / Inspiration
+Add any additional focused tests needed for behavior changes.
 
-This workflow is inspired by projects like `modded-nanogpt`: keep a small, readable core, iterate quickly via “hackable” entrypoints, and regularly upstream what works.
+This workflow is inspired by modded-nanogpt: iterate quickly in copy-paste experiments, then upstream only what stays simple and useful.
````
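The recipe's "grug style" bullet asks for Equinox modules exposing `init` + `__call__`. As a hedged, dependency-free illustration of that pattern (this `RMSNorm` is a hypothetical stand-in, not the repo's implementation, and plain Python lists stand in for arrays):

```python
# Hypothetical sketch of the "init + __call__" module pattern the recipe
# describes; plain Python, no Equinox/JAX, all names illustrative only.
import math
from dataclasses import dataclass


@dataclass
class RMSNorm:
    weight: list  # per-dimension scale, initialized to ones

    @staticmethod
    def init(dim: int) -> "RMSNorm":
        # Construction lives in a static init, keeping the dataclass trivial.
        return RMSNorm(weight=[1.0] * dim)

    def __call__(self, x: list, eps: float = 1e-6) -> list:
        # y_i = w_i * x_i / rms(x), the usual RMSNorm formula.
        rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
        return [w * v / rms for w, v in zip(self.weight, x)]


norm = RMSNorm.init(4)
out = norm([1.0, 2.0, 3.0, 4.0])
```

The point of the pattern is that each component is constructed via `Component.init(...)` and applied via `component(x)`, with no hidden framework state in between.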

docs/reports/grug-archive.md

Lines changed: 13 additions & 35 deletions
````diff
@@ -1,61 +1,39 @@
 # Grug Archive: Experiments and Snapshots
 
-This file is a lightweight “paper trail” for Grug-related experiments, inspired by the idea of keeping a runnable history without letting a pile of one-off scripts become the de facto source of truth.
+This file is the paper trail for grug experiments.
 
 ## Principles
 
-- **`levanter.grug` is the source of truth.** Speedrun files are snapshots/entrypoints, not the canonical implementation.
-- **Every experiment should be attributable to a commit.** If an experiment is removed or superseded, it should be clear what replaced it and why.
-- **Prefer deletion over permanent snapshots.** If a script is dead, delete it and record the last known-good commit here.
-- **Keep diffs small.** When an experiment is kept “alive”, update it to track the current core rather than forking the entire model.
-
-## When Grug Core Changes
-
-When a change in `levanter.grug` is likely to affect results, performance, or semantics:
-
-1. Update the experiment(s) that should track “best guess”.
-2. For experiments that no longer make sense:
-   - delete them, or
-   - mark them superseded and point to the replacement.
-3. Update the corresponding entry in this archive (and any linked issue).
+- `experiments/grug/base/` is the canonical template.
+- Speedrun files are exploratory and may be deleted after upstreaming.
+- Prefer deletion over long-term maintenance of stale experiment code.
 
 ## Entry Template
 
-Copy/paste this block for new experiments:
-
 ```text
 ### <experiment-id>
-- Path: `<repo-relative-path>`
+- Path: <repo-relative-path>
 - Introduced: <commit-sha>
 - Last known-good: <commit-sha>
 - Status: active | superseded | deleted
 - Purpose: <one line>
-- Notes: <optional; what changed, how to reproduce, caveats>
-- Superseded by: <experiment-id or commit-sha; optional>
-- Issue: <url or issue id; optional>
+- Superseded by: <path or commit; optional>
+- Issue: <url/id; optional>
 ```
 
 ## Experiments
 
-### grugformer-attnsink
-- Path: `experiments/speedrun/grugformer_attnsink/grugformer_attn_sink.py`
-- Introduced: TBD
-- Last known-good: TBD
-- Status: active
-- Purpose: “Hackable” Grug attention-sink variant; intended edit surface for sinks/aux.
-- Notes: Keep this file short; copy/paste local modifications rather than growing new abstractions.
-
-### grugformer-starter-speedrun
-- Path: `experiments/speedrun/grugformer_starter/grugformer_speedrun.py`
+### grug-base-template
+- Path: `experiments/grug/base/`
 - Introduced: TBD
 - Last known-good: TBD
 - Status: active
-- Purpose: Minimal starter speedrun for Grug; convenient baseline for quick iteration.
+- Purpose: canonical grug template (model/train/launch).
 
 ### grugformer-vs-hackable-125m
 - Path: `experiments/speedrun/grugformer_vs_hackable_125m/grugformer_vs_hackable_125m.py`
 - Introduced: TBD
 - Last known-good: TBD
-- Status: active
-- Purpose: Head-to-head comparison between Hackable Transformer and Grugformer (no sinks).
-
+- Status: deleted
+- Purpose: historical head-to-head comparison.
+- Superseded by: template-first workflow centered on `experiments/grug/base/`.
````
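For concreteness, a filled-in entry following the template above might look like this (everything here is hypothetical, including the experiment id; the SHA placeholders are left unfilled on purpose):

```text
### grug-moe (hypothetical example)
- Path: `experiments/grug/moe/`
- Introduced: <commit-sha>
- Last known-good: <commit-sha>
- Status: active
- Purpose: MoE variant copied from `experiments/grug/base/`.
```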

experiments/grug/README.md

Lines changed: 116 additions & 0 deletions
````diff
@@ -0,0 +1,116 @@
+# Grug Layout and Usage
+
+`experiments/grug/` is template-first. You edit experiment code here directly.
+
+## Directory layout
+
+- `base/model.py`: model config and model implementation (`init` + `__call__` + loss method).
+- `base/train.py`: train loop, optimizer step, callbacks, eval/checkpoint wiring.
+- `base/launch.py`: experiment config and execution entrypoint (`ExecutorStep` + resources).
+
+## Entry-point guide
+
+- Start in `base/launch.py` for normal run edits.
+- `GrugBaseLaunchConfig` is the user-facing knob surface (model/data/optimizer/trainer/eval/run metadata).
+- `versioned(...)` marks config values that should affect executor step version/hash.
+- `this_output_path()` resolves to the current step's output root.
+- `run_grug(...)` in `base/train.py` is the runtime entry point used by the `ExecutorStep`.
+- `P` in train/model code is the usual JAX alias for `PartitionSpec`; see the JAX explicit sharding tutorial: [Explicit Sharding (JAX)](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html).
+
+## How to use it
+
+1. Copy `experiments/grug/base` to a new variant directory (for example `experiments/grug/moe`).
+2. Make model/training changes in that variant, not in shared trainer libraries.
+3. Set run knobs in `<variant>/launch.py` (run id, data mix, optimizer, TPU type).
+4. Launch from the variant's `launch.py` entrypoint.
+
+## Quickstart launch
+
+Local executor run:
+
+```bash
+uv run python experiments/grug/base/launch.py
+```
+
+Ray cluster run:
+
+```bash
+uv run lib/marin/src/marin/run/ray_run.py \
+  --env_vars WANDB_API_KEY=${WANDB_API_KEY} \
+  -- python experiments/grug/base/launch.py
+```
+
+## Common edit points
+
+- Architecture changes: `experiments/grug/base/model.py`
+- Train-loop and callback behavior: `experiments/grug/base/train.py`
+- Run config, resources, and launch wiring: `experiments/grug/base/launch.py`
+- Copy-paste variant workflow: duplicate `experiments/grug/base/` into `experiments/grug/<variant>/` and edit there.
+
+## Trainer knobs people ask about
+
+- `z_loss_weight` in `GrugTrainerConfig`: weight on the logsumexp stabilization term in LM loss.
+- `ema_beta` in `GrugTrainerConfig`: exponential moving average (EMA) coefficient for eval/checkpoint model; `None` disables EMA.
+
+## Checkpoints and resume
+
+- Checkpoints are written to `<output_path>/checkpoints` by default in `base/launch.py`.
+- `run_grug` restores from `trainer.load_checkpoint_path` when set, otherwise tries the run checkpoint path.
+- If `trainer.load_checkpoint=True` and no checkpoint is found, startup fails; otherwise it starts from scratch.
+
+## Environment variables you will likely use
+
+- `WANDB_API_KEY`: required for W&B logging in the default launch config.
+- `GRUG_RUN_ID`: overrides the default run id.
+- `FERRY_DATE`: appended to run id for ferry-style launches.
+
+## Where outputs show up
+
+- Training/eval metrics: tracker backend (default W&B).
+- Checkpoints: `<output_path>/checkpoints`.
+- Profiler traces (if enabled): `<trainer.log_dir>/<run_id>/profiler`.
+- Executor step outputs: `this_output_path()` root for the step.
+
+## Logged metrics
+
+- `train/loss`: training loss for the just-completed step.
+- `global_step`: completed optimizer step index.
+- `run_progress`: completed fraction of the configured run (`step / total_steps`).
+- `optim/*`: optimizer hyperparameters from Optax state (for example `optim/learning_rate`).
+- `throughput/duration`: step wall-clock duration (after loss is materialized).
+- `throughput/examples_per_second`: examples processed per second for the current batch size.
+- `throughput/tokens_per_second`: tokens processed per second.
+- `throughput/total_tokens`: cumulative tokens processed so far (schedule-aware).
+- `throughput/gflops_per_second`: model FLOP throughput from analytic FLOPs-per-example.
+- `throughput/total_gflops`: cumulative model FLOPs (analytic).
+- `throughput/mfu`: model FLOP utilization as percent of theoretical hardware FLOPs.
+- `throughput/hook_time`: callback/logging overhead time after each step.
+- `throughput/loading_time`: dataloader wait time for the current step.
+- `throughput/flops_per_token_analytic`: analytic FLOPs per token summary value.
+- `throughput/flops_per_example_analytic`: analytic FLOPs per example summary value.
+- `throughput/flops_per_example`: FLOPs-per-example value used by throughput callback.
+- `throughput/device_kind`: accelerator type string from JAX device info.
+- `throughput/theoretical_flops_per_device`: theoretical peak FLOPs per device.
+- `throughput/theoretical_flops`: theoretical peak FLOPs across all devices.
+- `mixture/stage`: current data-mixture stage index.
+- `mixture/weight/<dataset_name>`: effective sampling weight per dataset in the active stage.
+- `eval/loss`, `eval/loading_time`, `eval/total_time`: tagged-eval loss and timing for current model.
+- `eval/ema/*`: same eval metrics for EMA weights when EMA is enabled.
+- `eval/macro_loss`: macro average loss across tags when multiple tags exist.
+- `eval/<tag>/loss`, `eval/<tag>/micro_loss`, `eval/<tag>/macro_loss`: per-tag loss views.
+- `eval/bpb`, `eval/macro_bpb`, `eval/<tag>/bpb`, `eval/<tag>/macro_bpb`: bits-per-byte metrics when tokenizer/BPB logging is enabled.
+- `grad/*`, `params/*`, `updates/*`, `opt_state/*`: optional watch metrics (norms/histograms) when watch is enabled.
+
+## What should stay consistent
+
+- Keep core training/eval metrics aligned with classic Levanter (`train/loss`, `throughput/*`, `eval/*`).
+- Prefer shared helpers only for generic infrastructure; keep variant behavior local to the template.
+
+## Further guidance
+
+- Grug principles: [`/.agents/projects/grugformer.md`](../../.agents/projects/grugformer.md)
+- Change workflow: [`/docs/recipes/change_grug.md`](../../docs/recipes/change_grug.md)
+- Executor mechanics: [`/docs/explanations/executor.md`](../../docs/explanations/executor.md)
+- Executor tutorial: [`/docs/tutorials/executor-101.md`](../../docs/tutorials/executor-101.md)
+- TPU debug workflow: [`/docs/dev-guide/dev_tpu.md`](../../docs/dev-guide/dev_tpu.md)
+- Cluster launch details: [`/docs/tutorials/tpu-cluster-setup.md`](../../docs/tutorials/tpu-cluster-setup.md)
````
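The trainer knobs and throughput metrics described in this README boil down to simple arithmetic. The sketch below uses the standard formulations for an EMA update, a z-loss term, and MFU; these are assumptions about what the knobs compute, not code read from the grug source, and all function names are illustrative:

```python
# Assumed formulas behind `ema_beta`, `z_loss_weight`, and `throughput/mfu`;
# standard textbook formulations, not the repo's actual implementation.
import math


def ema_update(ema_params, params, beta):
    # EMA of weights: ema <- beta * ema + (1 - beta) * current.
    return [beta * e + (1.0 - beta) * p for e, p in zip(ema_params, params)]


def z_loss(logits, weight):
    # Logsumexp stabilization term: weight * logsumexp(logits)^2.
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return weight * lse * lse


def mfu_percent(flops_per_example, examples_per_second, theoretical_flops):
    # Model FLOP utilization as a percent of theoretical hardware FLOPs.
    return 100.0 * flops_per_example * examples_per_second / theoretical_flops


ema = ema_update([0.0, 0.0], [1.0, 2.0], beta=0.9)  # approx [0.1, 0.2]
zl = z_loss([0.0, 0.0], weight=1e-4)                # weight * ln(2)^2
mfu = mfu_percent(1e12, 100.0, 4 * 275e12)          # percent, 4-device example
```

Setting `ema_beta=None` in the config disables the EMA branch entirely; a larger `beta` makes the EMA weights trail the live weights more slowly.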

experiments/grug/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -0,0 +1,2 @@
+# Copyright 2025 The Marin Authors
+# SPDX-License-Identifier: Apache-2.0
```

experiments/grug/base/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -0,0 +1,2 @@
+# Copyright 2025 The Marin Authors
+# SPDX-License-Identifier: Apache-2.0
```
