Conversation
Pull request overview
This PR implements Direct Preference Optimization (DPO) training support for Levanter/Marin. The implementation was generated by Codex and reviewed/simplified by Claude, with comprehensive documentation provided in DPO_claude.md explaining the rationale for all changes.
Changes:
- Added complete DPO training implementation with policy and reference models
- Extended data processing to handle preference chat datasets
- Added minimal but necessary Haliax changes to handle NamedArray as leaf nodes in tree operations
- Added comprehensive test coverage for all new functionality
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 4 comments.
Summary per file:
| File | Description |
|---|---|
| lib/levanter/src/levanter/main/train_dpo.py | Core DPO training loop with loss computation and model management (515 new lines) |
| lib/levanter/src/levanter/data/text.py | PreferenceChatProcessor and PreferencePairDataset for handling preference data |
| lib/levanter/src/levanter/data/packing.py | Added "drop" slice_strategy for sequences exceeding max length |
| lib/levanter/src/levanter/trainer_state.py | Uses is_leaf pattern with NamedArray for correct partition/combine operations |
| lib/haliax/src/haliax/quantization.py | Treats NamedArray as leaf in partition/apply_updates operations |
| lib/haliax/src/haliax/partitioning.py | Handles NamedArray with array=None and batch_dim from vmap |
| lib/haliax/src/haliax/nn/scan.py | Adds auto_sharded after vmap for memory efficiency with stacked layers |
| lib/levanter/tests/test_dpo.py | Comprehensive test suite (372 lines) covering all DPO functionality |
| lib/marin/src/marin/training/training.py | Marin integration for DPO training via TrainDpoOnPodConfig |
| lib/marin/src/marin/transform/conversation/transform_preference_data.py | Preference dataset transformation with fsspec support |
| experiments/exp2101_dpo_ultrafeedback.py | Example experiment using DPO on Ultrafeedback dataset |
| lib/levanter/config/dpo_ultrafeedback_llama3_8b.yaml | Production-ready DPO configuration (knobs sketched after this table) |
| lib/levanter/config/dpo_tiny_gpt2.yaml | Minimal DPO test configuration |
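For orientation, here is a hedged sketch of the knobs those DPO configs expose, pieced together from the diff hunks and commit messages in this thread. The dataclass itself is illustrative, not the real Levanter class, and any default not visible in the diff is an assumption.

```python
# Illustrative only: approximates the DPO training knobs visible in this PR's hunks and
# commit messages (beta, warmup, cooldown, lr_schedule, weight_decay,
# reference_model_path); the class name and grouping are invented for the sketch.
from dataclasses import dataclass


@dataclass
class DpoConfigSketch:
    beta: float = 0.01                       # DPO temperature; 0.01 in the 8B YAML
    learning_rate: float = 5e-7              # 8B YAML value; the real default may differ
    lr_schedule: str = "linear"              # default in the diff; "cosine" in the 8B YAML
    warmup: float = 0.03
    cooldown: float | None = None
    weight_decay: float = 0.0
    reference_model_path: str | None = None  # frozen reference policy checkpoint
```

The hunks and review comments below reference several of these fields.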
```python
tokenized=tokenized_preferences,
model_config=llama_8b,
dpo_config=dpo_config,
tags=["ultrafeedback", "llama3", "simpo"],
```

The tags list includes "simpo", but this appears to be a DPO (Direct Preference Optimization) implementation, not SimPO (Simple Preference Optimization). If this is intentional (perhaps planning to implement SimPO later), consider adding a comment to clarify. Otherwise, remove the "simpo" tag.

```suggestion
tags=["ultrafeedback", "llama3"],
```
experiments/defaults.py (outdated)

```python
pretraining_data = dataclasses.replace(pretraining_data, permutation_type="feistel")
vocab_size = _get_vocab_size(pretraining_data)

if len(name) > 64:
```
Can we extract a helper for this, since we use it for training too?
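A possible shape for that helper, purely as a sketch: the 64-character limit comes from the check above, while the hash-suffix scheme and the function name are assumptions.

```python
# Hypothetical helper (not in the PR): clamp run/experiment names to a length limit in
# one place so the DPO and LM training paths share the same logic.
import hashlib


def _truncate_run_name(name: str, max_len: int = 64) -> str:
    if len(name) <= max_len:
        return name
    # keep a readable prefix and append a short hash so truncated names stay distinct
    digest = hashlib.sha256(name.encode("utf-8")).hexdigest()[:8]
    return f"{name[: max_len - 9]}-{digest}"
```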
```python
# this happens when we filter out params for things like lora.
# could use eqx.partition to avoid this, but eh
return named
if getattr(named.array, "batch_dim", None) is not None:
```
I hate this. I need a minimal reproducer so I can make this go away.
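Not the Haliax code path itself, but a JAX-only sketch of what the guard is detecting: inside vmap the underlying value is a BatchTracer, which exposes a batch_dim attribute.

```python
# Minimal JAX-only illustration (no Haliax): inside jax.vmap, leaves are BatchTracer
# objects carrying batch_dim, which is what getattr(named.array, "batch_dim", None)
# sees when a NamedArray is partitioned under vmap.
import jax
import jax.numpy as jnp


def f(x):
    print(type(x).__name__, getattr(x, "batch_dim", None))  # prints: BatchTracer 0
    return x


jax.vmap(f)(jnp.ones((4, 8)))
```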
lib/marin/src/marin/transform/conversation/transform_preference_data.py: three review comments (outdated, resolved)
```python
weight_decay: float = 0.0
warmup: float = 0.03
cooldown: float | None = None
lr_schedule: str = "linear"
```

Linear, yes; as for the particular warmup value here, I'm not sure, but it's probably good to keep it close to train_lm.
```python
Note that trainer.id and the RUN_ID env variable take precedence, in that order.
"""
allow_out_of_region: tuple[str, ...] = ()
```

Let's not allow this until we really need it.

I was inheriting from TrainLMPod... do we want to get rid of that too? I can do that.
- Update train_dpo.py imports to use LmDataConfig instead of SingleDatasetLMConfig
- Migrate to components-based data config structure
- Replace text.py with text/ package structure (from simpo)
- Add preference.py with DPO-specific classes
- Update DPO YAML configs to use components: structure
- Merge validation split functions into single _build_validation_split
- Copy updated Levanter main scripts from simpo (train_lm.py, eval_lm.py, etc.)
- Copy updated marin tokenize files from simpo

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Copy all config files from simpo (updated to components: structure)
- Copy updated source files from simpo (trainer_state.py, optim/, etc.)
- Add EpochDataset class to dataset.py for DPO training
- Update text/__init__.py exports for preference functions
- Add SimPO config files from simpo branch

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Keep preference format handling in datasets.py and formats.py
- Keep DPO exports in text/__init__.py
- Accept main's partitioning.py changes (use axis_names)
- Restore EpochDataset class in dataset.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

… branch
- tokenizer: marin-community/marin-tokenizer
- train_batch_size: 128, num_train_steps: 2150
- learning_rate: 5e-7, lr_schedule: cosine, warmup: 0.1
- beta: 0.01
- Add both train and validation components with proper cache dirs
- Use GCS model paths for reference_model_path and initialize_from_hf
- validation_split_fraction: null (uses separate validation component)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
```
@@ -0,0 +1,96 @@
# Copyright 2025 The Marin Authors
```

In keeping with the new policy, I'm not going to review this unless you want me to.

That said, can you $AGENT up a doc that explains how to run DPO in the Levanter and Marin settings (probably two docs, one for Marin and one for Levanter)?
```python
)
else:
    # Check for preference format (imported lazily to avoid circular imports)
    from .preference import PreferenceChatLmDatasetFormat, dataset_for_preference_format
```

Preference datasets should not be part of text datasets: they are a different structure (two sequences instead of one) and so need to be a different type.
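One way to read that suggestion, as a sketch only: give preference pairs their own example type instead of routing them through the single-sequence text machinery. The class and field names here are illustrative, not Levanter APIs.

```python
# Hypothetical preference example type (not the PR's code): a preference pair carries
# two token sequences plus their loss masks, unlike a single-sequence LM example.
from dataclasses import dataclass

import haliax as hax


@dataclass(frozen=True)
class PreferencePairExample:
    chosen_tokens: hax.NamedArray       # preferred response, axes ("position",)
    rejected_tokens: hax.NamedArray     # dispreferred response, axes ("position",)
    chosen_loss_mask: hax.NamedArray    # 1.0 where the chosen completion is scored
    rejected_loss_mask: hax.NamedArray  # 1.0 where the rejected completion is scored
```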
```python
return model.policy if isinstance(model, DpoModel) else model


def _bool_tree_like(tree, value: bool):
```

In theory any tree prefix should work, so you should be able to return value directly, but maybe I'm wrong.

TIL, thank you! Looks a lot cleaner.
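A minimal sketch of the tree-prefix point, assuming Equinox's documented behavior that a filter spec may be any prefix of the pytree: a bare bool broadcasts over every leaf, so no _bool_tree_like-style helper is needed.

```python
# Equinox filter specs are pytree prefixes, so a single bool already applies to the
# whole model; this sketch just demonstrates the behavior the comment relies on.
import equinox as eqx
import jax

model = eqx.nn.Linear(2, 3, key=jax.random.PRNGKey(0))
params, static = eqx.partition(model, True)  # `True` is a prefix matching every leaf
```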
```python
    *,
    beta: float,
) -> tuple[jnp.ndarray, dict[str, Metric]]:
    if isinstance(delta_pi, hax.NamedArray) or isinstance(delta_ref, hax.NamedArray):
```

Do we need to be this defensive?

I'm just gonna make everything a NamedArray.
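For context on what delta_pi and delta_ref feed into, here is a minimal unnamed-array sketch of the DPO objective (kept in plain jnp for brevity even though the PR standardizes on NamedArray; the names mirror the discussion, not the exact Levanter code).

```python
# Sketch of the DPO loss: -log sigmoid(beta * (delta_pi - delta_ref)), where each delta
# is the chosen-minus-rejected sequence log-probability under the policy or the frozen
# reference model. Illustrative only.
import jax.nn as jnn
import jax.numpy as jnp


def dpo_loss_sketch(logp_pi_chosen, logp_pi_rejected,
                    logp_ref_chosen, logp_ref_rejected, *, beta: float):
    delta_pi = logp_pi_chosen - logp_pi_rejected     # policy margin
    delta_ref = logp_ref_chosen - logp_ref_rejected  # reference margin
    return jnp.mean(-jnn.log_sigmoid(beta * (delta_pi - delta_ref)))
```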
```python
nll = model.compute_next_token_loss(example, reduction=None, reduction_axis=(), key=key)
Pos = example.tokens.resolve_axis("position")
return -hax.sum(nll, axis=Pos)
```

Actually, why aren't we just doing:

```suggestion
nll = model.compute_next_token_loss(example, reduction=hax.sum, reduction_axis="position", key=key)
return -nll
```
```python
if cache is None:
    raise ValueError(f"No training cache available for component {name}.")


if not isinstance(component.format, PreferenceChatLmDatasetFormat):
```

We should change the code until this isn't possible (i.e., not a text dataset).
```python
loss, metrics = dpo_loss_from_logps(delta_pi, delta_ref, beta=config.beta)
chosen_reward = (logp_pi_chosen - logp_ref_chosen) * config.beta
rejected_reward = (logp_pi_rejected - logp_ref_rejected) * config.beta
if isinstance(chosen_reward, hax.NamedArray):
```

Can we standardize on named or unnamed?

Standardizing on named.
```python
state = dataclasses.replace(state, model=None)
gc.collect()
```

This is wrong and not preemption/resume safe. If you load a model from trainer.initial_state, you need to stick with it unless step == 0.

Sorry, not sure why this was added.
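A hedged sketch of the resume-safe pattern the comment describes: only overwrite the model when the run is genuinely at step 0; on resume, keep whatever trainer.initial_state restored. The initial_state call shape, state.step, and load_reference_policy are assumptions, not Levanter's actual signatures.

```python
# Illustrative only: keep the checkpointed model on resume; re-initializing the policy
# from the reference/HF path is only safe for a fresh run at step 0.
import dataclasses


def restore_policy_safely(trainer, training_key, model_init, load_reference_policy):
    # trainer.initial_state, state.step, and load_reference_policy are assumed shapes
    state = trainer.initial_state(training_key, model_init=model_init)
    if int(state.step) == 0:
        # fresh run: safe to replace the freshly initialized policy
        state = dataclasses.replace(state, model=load_reference_policy())
    # on resume (step > 0), keep state.model; dropping it would discard training progress
    return state
```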
MONITOR_SIMPO.md (outdated)

```python
# Check for preference format (imported lazily to avoid circular imports)
from .preference import PreferenceChatLmDatasetFormat, preprocessor_for_preference_format


if isinstance(format, PreferenceChatLmDatasetFormat):
    return preprocessor_for_preference_format(format, tokenizer)  # type: ignore
```
dlwh left a comment:

Codex review (by Codex): merged latest main into this DPO branch and resolved the defaults.py conflict by keeping the DPO defaults wiring while aligning tokenizer vocab-size resolution with current main utilities. I recommend re-running the full CI matrix due to the large main-sync delta.

I had Codex write a DPO implementation and then had Claude double-check / simplify it. From looking at the DPO_CLAUDE.md file, nothing seems obviously wrong to me; this reference model + policy model business is quite strange.