- Entry point: `train_switch_bank.py` orchestrates distributed setup, logging, and environment flags. It captures source for reproducibility via the top-level `code` string and patches Torch Inductor's `trace_structured` to be metadata-tolerant while logging compiled filenames. Logs and checkpoints go under `log_dir` (defaults to `records/track_2_medium/2025-12-26_SwitchBank` in `Hyperparameters`). `train_gpt_medium_w_grad_accum.py` logs training loss to W&B under `train/loss_main` for parity with switch-bank runs.
- Optuna wrapper: `optuna_router_tune.py` runs single-GPU tuning over router temp/logit-cap schedules by calling `train_switch_bank.run_training` with overrides; it reuses the compiled model across trials and early-stops after the logit-cap ramp to compute a validation+logit score (loss normalized to a reference step), logging per-run metrics to its own wandb project plus an overview project.
- Backfill helper: `optuna_overview_backfill.py` can repopulate the overview wandb run from `optuna_results/trial_*.json` without duplicating logged trials.
- Loss calibration: `loss_adjustment.py` builds `optuna_results/loss_adjustment.json` from W&B val/loss history, fitting a global floor and bucketized (A, p) curves; other Optuna helpers read this file to adjust losses.
- Focus of the repo: a sideways MoE GPT variant where each transformer layer routes to a shared bank of FFN experts. Non-standard optimizations include Muon (hybrid AdamW + Muon/TurboMuon spectral optimizer) for 2D+ bf16 matrices, FlexAttention, router feature EMAs (with clamped/strided alphas), router temperature/logit-cap schedules, optional router Gumbel noise, adapter support on routers, configurable router boost shapes, router/FFN freeze controls, and mid-training checkpoint/resume.
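The shared-bank layout can be pictured with a toy sketch (every name and the stand-in hash router below are invented for illustration; the real implementation is `SharedFFNBank` in `model/components.py` with learned routing logits):

```python
# Toy picture of the "sideways" layout: one bank of expert FFNs is
# shared by every transformer layer, but each layer owns its router.
class ToySharedBank:
    def __init__(self, num_experts: int, num_layers: int):
        # shared experts: expert i just scales its input by (i + 1)
        self.experts = [
            (lambda x, s=i + 1: [s * v for v in x])
            for i in range(num_experts)
        ]
        # one (trivial) router per layer: hash the input to an expert id
        self.routers = [
            (lambda x, n=num_experts: sum(x) % n)
            for _ in range(num_layers)
        ]

    def forward(self, layer: int, x: list[int]) -> list[int]:
        expert_id = self.routers[layer](x)
        return self.experts[expert_id](x)
```

For example, `ToySharedBank(4, 2).forward(0, [1, 2, 3])` hashes the tokens to expert 2 and scales by 3, returning `[3, 6, 9]`; layers differ only in which router they consult, not in which experts they can reach.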
- Ignore/never touch: any `*_original*.py` or `train_gpt_*.py` files unless explicitly requested.
- `utils.py`: numeric safety helpers (`_sanitize`/`_safe_softmax`) plus scheduling utilities (`next_multiple_of_n`, `rampdown_multiplier`, and `compute_train_micro_len`, which enforces 128-token blocks). Includes a placeholder `summarize` helper. Keep behavior identical for numeric parity.
- `optim/muon.py`: Muon optimizer. Runs AdamW for non-spectral params and Muon (Turbo off) or TurboMuon-style spectral updates (Turbo on) for 2D+ bf16 matrices. The spectral-branch LR (`lr_spec`) defaults to the group's `lr` unless explicitly overridden.
- `model/components.py`: `Rotary`, `CausalSelfAttention` (FlexAttention + RMSNorm + rotary). `SharedFFNBank`: shared expert W1/W2 across layers and per-layer routers with optional adapters; forward/reverse EMA features (blockwise, doc-aware reverse) cached per `ema_layer_stride` with clampable/freezeable alphas (the reverse-EMA head window is zeroed to avoid non-causal leakage). Extra wandb logging records reverse-EMA head/tail fractions per batch. Top-k hard routing with optional Gumbel noise, active/pruned expert masks, and deterministic top-1 when only one expert is live. Aux loss mixes load/importance CV² with an entropy-gap term. The metrics buffer stores load vectors, entropies, and feature-weight means; `compile_warm_all_experts` warms kernels; adapters lazily initialize; `prune_inactive_experts` zeros weights/adapters when invoked. `Block`: per-layer wrapper combining skip/SA lambdas, optional attention, temperature/load shaping by layer position, and the bank call; supports boost shapes (`peak`, `valley`, `linear_start`, `linear_end`) and a `decay_scale` multiplier.
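A minimal sketch of the two `utils.py` scheduling helpers named above (hypothetical re-implementations: `next_multiple_of_n` is the standard round-up, and the block rule in `compute_train_micro_len` is one plausible reading, rounding down to whole 128-token blocks):

```python
def next_multiple_of_n(x: int, n: int) -> int:
    # smallest multiple of n that is >= x
    return -(-x // n) * n

def compute_train_micro_len(requested: int, block: int = 128) -> int:
    # one plausible reading of the 128-token-block rule: round the
    # requested micro length down to a whole number of blocks,
    # never below a single block
    return max(block, (requested // block) * block)
```

Under this reading, a requested micro length of 300 would be adjusted to 256 (two full blocks), which is the kind of adjustment the orchestrator logs at startup.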
- `model/gpt.py`: `GPT` model wiring embeds (token table padded to 128) with optional value embeds spread across layers; blocks + shared bank; scalar lambdas; router schedules (temperature/logit caps, expert activation masks, optional Gumbel, and an optional router temp schedule end to hold temperature flat past a step); router/adapter/FFN freeze fractions and EMA clamping; flag builder (EOD/after-EOD/first-docN/block-pos bins); document-causal blockmask construction mixing long/short windows; LM head tie/untie logic with runtime retie.
- Helpers: `_compute_router_temp`, `_second_expert_step`.
- Optional layer tying: `layer_tie_groups` ties attention weights and router adapters across specified layer groups (default pairs short-only layers such as (9,10), (13,14), (17,18), (21,22), (25,26)).
- `data.py`: binary shard loader/generator, router metric summarizers, and the `router_summary_str` formatter. The generator does not wrap when shards are exhausted and supports resume via `skip_batches`.
- `trainer.py`: training/validation loop and schedules (LR, cubic window size with optional schedule end step, router temp/logit-cap, Gumbel). Handles grad accumulation, all-reduce, router-only grad clipping with optional AutoClip (10th percentile over a 250-step window; can warm up unclipped when the base clip is 0), rampdowns (router/adapter/FFN), Muon momentum warmup, wandb/CSV logging gated by flags, abort checks, mid-training checkpoints (model + optimizers + meta + approx step timing), LM head untie logging, and dataloader fast-forward on resume. Validation supports separate token budgets via `val_tokens_intermediate` (non-final steps) and `val_tokens_final` (final step), plus an initial step-0 validation using the intermediate budget; an optional early stop can be treated as final for validation/checkpointing. Final checkpoint saving can be gated on a loss threshold via `save_final_checkpoint_if_loss_below` and `save_final_checkpoint_max_loss`. Router logging includes feature-weight percentages, normalized CV², entropy gaps, usage gap, per-layer stats, and a composite router health metric.
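The AutoClip behavior described above can be sketched as follows (a hypothetical reconstruction, not the trainer's code: the percentile, window size, and warm-up rule follow the description; everything else is assumed):

```python
from collections import deque

class AutoClip:
    """Track recent gradient norms and clip at a low percentile.

    Sketch of percentile-based clipping: the threshold is the 10th
    percentile of the last 250 observed norms. With a base clip of 0
    and an empty history, gradients pass through unclipped (warm-up).
    """
    def __init__(self, percentile: float = 10.0, window: int = 250,
                 base_clip: float = 0.0):
        self.history = deque(maxlen=window)
        self.percentile = percentile
        self.base_clip = base_clip

    def threshold(self) -> float:
        if not self.history:
            # warm up unclipped when base clip is 0
            return self.base_clip if self.base_clip > 0 else float("inf")
        xs = sorted(self.history)
        k = min(int(len(xs) * self.percentile / 100), len(xs) - 1)
        return xs[k]

    def observe_and_clip(self, grad_norm: float) -> float:
        # compute the clip from history first, then record the new norm
        clip = self.threshold()
        self.history.append(grad_norm)
        return min(grad_norm, clip)
```

A first norm of 5.0 passes through unclipped and seeds the window; a later spike to 100.0 is then clipped back toward the window's low percentile.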
- `train_switch_bank.py` sets up distributed env/logging, disables donated buffers/compiled autograd, patches Inductor tracing, captures module source via `code`, and inits wandb/CSV when enabled.
- Builds `GPT` with hyperparams (shared bank size/stride/window config, EMA settings, router/adapters, value embeds, router Gumbel/boost shape, expert activation schedule), broadcasts params, and logs parameter counts.
- Partitions params into Muon groups: non-spectral params (embeds/scalars/router/adapters/head) use the AdamW branch; spectral params (attention QKV/out and shared FFN bank matrices) use the Muon/TurboMuon spectral branch. Stores `initial_lr` per group.
- Optional checkpoint resume validates meta (dims/experts/vocab), restores approx step timing, sanitizes Muon state dtypes, and recompiles the model (`torch.compile`, `dynamic=False`). Computes block-aligned `train_micro_len` (logging any adjustments).
- Optional warmup (synthetic steps + `compile_warm_all_experts`) while preserving optimizer/model state, then training/validation via `trainer.run_training` (window schedule, router temp/logit-cap schedules, Gumbel gating, freeze milestones, optional mid-training checkpoint via `checkpoint_save_step`, wandb/CSV logging, resume-aware dataloader position).
- Shutdown: report peak memory, destroy the process group, finish wandb/CSV.
- Each layer routes tokens to a shared expert bank via per-layer routers. Features include the token norm, optional forward/reverse EMA contexts (blockwise, doc-aware reverse) shared across layer groups via `ema_layer_stride`, and flags (EOD/after-EOD/first-docN/block-pos bins).
- Routing: top-k hard switch with optional Gumbel noise; outputs are scaled by gate probs. The aux loss combines load/importance CV² plus an entropy-gap penalty. Routing falls back to deterministic top-1 when only one expert is active after masking/pruning; router/adapters/EMA alphas can be frozen late in training.
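The routing and aux-loss pieces above can be sketched in isolation (illustrative scalar versions, not the batched tensor code in `SharedFFNBank`; the Gumbel perturbation is the standard `-log(-log(u))` form):

```python
import math
import random

def softmax(logits, temp=1.0):
    m = max(logits)
    exps = [math.exp((l - m) / temp) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route_top1(logits, temp=1.0, gumbel=False, rng=random):
    # hard top-1 switch: optionally perturb logits with Gumbel noise,
    # pick the argmax, and return (expert id, gate prob) so the
    # expert's output can be scaled by its gate probability
    noisy = [
        l + (-math.log(-math.log(rng.random())) if gumbel else 0.0)
        for l in logits
    ]
    probs = softmax(noisy, temp)
    expert = max(range(len(probs)), key=probs.__getitem__)
    return expert, probs[expert]

def cv_squared(loads):
    # squared coefficient of variation: 0 when expert loads are
    # perfectly balanced, growing as they skew
    mean = sum(loads) / len(loads)
    if mean == 0:
        return 0.0
    var = sum((x - mean) ** 2 for x in loads) / len(loads)
    return var / (mean * mean)
```

With `gumbel=False` and one dominant logit, `route_top1` is deterministic, which is the degenerate case the bank handles explicitly when only one expert remains live.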
- Adapters: optional per-layer/expert scale/bias applied pre-FFN; lazy init using means of active adapters; zeroed when pruned.
- Pruning: `prune_inactive_experts` is available to zero weights/adapters based on activity but is not invoked by the trainer.
- Schedules: router temperature/logit-cap curves anchored to second-expert activation; router/adapter/FFN LR rampdowns and freeze fractions; expert activation masks; the router boost shape controls per-layer temp/lb scaling (peak/valley/linear_start/linear_end).
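One way such an anchored schedule might look (hypothetical shape and parameter names; the real curves live in `_compute_router_temp` and the trainer's schedule code):

```python
def router_temp(step: int, second_expert_step: int,
                temp_start: float = 2.0, temp_end: float = 1.0,
                ramp_steps: int = 1000) -> float:
    # hold the starting temperature until the second expert activates,
    # then ramp linearly down to temp_end over ramp_steps
    if step <= second_expert_step:
        return temp_start
    frac = min(1.0, (step - second_expert_step) / ramp_steps)
    return temp_start + frac * (temp_end - temp_start)
```

Anchoring to the second-expert step means retuning the activation schedule shifts the temperature ramp with it, rather than leaving the two schedules to drift apart.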
- Preserve numeric parity: keep helper call order, defaults, and routing logic unchanged when modifying modules.
- When logging or reproducing runs, ensure the `code` aggregation and the Inductor trace patch in `train_switch_bank.py` stay aligned with module contents.
- Hyperparameters live in `Hyperparameters` (`train_switch_bank.py`); adjust them there, and keep `compute_train_micro_len` block alignment intact. `train_switch_bank.run_training` supports single-GPU execution, env JSON overrides (`SWB_OVERRIDES_JSON`/`SWITCH_BANK_OVERRIDES_JSON`), early-stop validation, and reuse-state caching for compile/warmup in wrappers.
- Use `switch_bank/trainer.py` for any training-loop changes; keep `train_switch_bank.py` as orchestration only and respect the logging gates (`enable_extra_logging`, `enable_extra_wandb_logging`).
- When making or reverting changes that affect behavior, logging, checkpointing, or instructions, update `AGENTS.md` accordingly.
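The env-override mechanism can be sketched like this (the payload format, a flat JSON object of field overrides, is an assumption; only the env var names come from the notes above, and the `Hyperparameters` fields shown are invented stand-ins):

```python
import json
import os
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Hyperparameters:
    # illustrative subset; the real class lives in train_switch_bank.py
    num_experts: int = 8
    router_temp_start: float = 2.0

def apply_env_overrides(hp: Hyperparameters, env=os.environ) -> Hyperparameters:
    # hypothetical reading of SWB_OVERRIDES_JSON / SWITCH_BANK_OVERRIDES_JSON:
    # a JSON object mapping hyperparameter names to values, applied
    # over the defaults; unknown keys are ignored
    raw = env.get("SWB_OVERRIDES_JSON") or env.get("SWITCH_BANK_OVERRIDES_JSON")
    if not raw:
        return hp
    overrides = {k: v for k, v in json.loads(raw).items() if hasattr(hp, k)}
    return replace(hp, **overrides)
```

Keeping overrides as plain JSON in the environment is what lets the Optuna wrapper drive `run_training` in-process without editing `Hyperparameters` between trials.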