First Attempt at DPO #2460

Open
ahmeda14960 wants to merge 42 commits into main from dpo_claude_opus

Conversation

@ahmeda14960
Contributor

I had Codex write a DPO implementation and then had Claude double-check / simplify it. From looking at the DPO_CLAUDE.md file, nothing seems obviously wrong to me, though this reference model + policy model business is quite strange
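(For context on the policy/reference pairing: the reference model is a frozen copy of the starting policy, only the policy is trained, and the reference log-probs anchor DPO's implicit reward. A minimal plain-JAX sketch of the standard DPO loss, for orientation only and not the code in this PR:)

import jax
import jax.numpy as jnp

# logp_* are per-example sequence log-probabilities, shape [batch]
def dpo_loss(logp_pi_chosen, logp_pi_rejected,      # from the trainable policy model
             logp_ref_chosen, logp_ref_rejected,    # from the frozen reference model
             beta: float = 0.1):
    delta_pi = logp_pi_chosen - logp_pi_rejected    # policy's preference margin
    delta_ref = logp_ref_chosen - logp_ref_rejected # reference's preference margin
    # Bradley-Terry objective: push the policy margin above the reference margin
    return -jnp.mean(jax.nn.log_sigmoid(beta * (delta_pi - delta_ref)))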

@ahmeda14960 ahmeda14960 requested review from Copilot and dlwh and removed request for Copilot January 24, 2026 20:09
Copilot AI review requested due to automatic review settings January 24, 2026 20:15
Copilot AI left a comment

Pull request overview

This PR implements Direct Preference Optimization (DPO) training support for Levanter/Marin. The implementation was generated by Codex and reviewed/simplified by Claude, with comprehensive documentation provided in DPO_claude.md explaining the rationale for all changes.

Changes:

  • Added complete DPO training implementation with policy and reference models
  • Extended data processing to handle preference chat datasets
  • Added minimal but necessary Haliax changes to handle NamedArray as leaf nodes in tree operations
  • Added comprehensive test coverage for all new functionality

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 4 comments.

Summary per file:

  • lib/levanter/src/levanter/main/train_dpo.py: Core DPO training loop with loss computation and model management (515 new lines)
  • lib/levanter/src/levanter/data/text.py: PreferenceChatProcessor and PreferencePairDataset for handling preference data
  • lib/levanter/src/levanter/data/packing.py: Added "drop" slice_strategy for sequences exceeding max length
  • lib/levanter/src/levanter/trainer_state.py: Uses is_leaf pattern with NamedArray for correct partition/combine operations
  • lib/haliax/src/haliax/quantization.py: Treats NamedArray as leaf in partition/apply_updates operations
  • lib/haliax/src/haliax/partitioning.py: Handles NamedArray with array=None and batch_dim from vmap
  • lib/haliax/src/haliax/nn/scan.py: Adds auto_sharded after vmap for memory efficiency with stacked layers
  • lib/levanter/tests/test_dpo.py: Comprehensive test suite (372 lines) covering all DPO functionality
  • lib/marin/src/marin/training/training.py: Marin integration for DPO training via TrainDpoOnPodConfig
  • lib/marin/src/marin/transform/conversation/transform_preference_data.py: Preference dataset transformation with fsspec support
  • experiments/exp2101_dpo_ultrafeedback.py: Example experiment using DPO on the Ultrafeedback dataset
  • lib/levanter/config/dpo_ultrafeedback_llama3_8b.yaml: Production-ready DPO configuration
  • lib/levanter/config/dpo_tiny_gpt2.yaml: Minimal DPO test configuration

tokenized=tokenized_preferences,
model_config=llama_8b,
dpo_config=dpo_config,
tags=["ultrafeedback", "llama3", "simpo"],
Copilot AI Jan 24, 2026

The tags list includes "simpo" but this appears to be a DPO (Direct Preference Optimization) implementation, not SimPO (Simple Preference Optimization). If this is intentional (perhaps planning to implement SimPO later), consider adding a comment to clarify. Otherwise, remove the "simpo" tag.

Suggested change
tags=["ultrafeedback", "llama3", "simpo"],
tags=["ultrafeedback", "llama3"],

pretraining_data = dataclasses.replace(pretraining_data, permutation_type="feistel")
vocab_size = _get_vocab_size(pretraining_data)

if len(name) > 64:
Member

can we extract a helper for this since we use it for training too

Contributor Author

done
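(A hypothetical shape for such a shared helper; the name and truncation scheme here are illustrative, not the repo's actual code:)

import hashlib

def clamp_run_name(name: str, max_len: int = 64) -> str:
    # keep a recognizable prefix plus a short hash so truncated names stay unique
    if len(name) <= max_len:
        return name
    digest = hashlib.sha256(name.encode()).hexdigest()[:8]
    return f"{name[: max_len - 9]}-{digest}"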

# this happens when we filter out params for things like lora.
# could use eqx.partition to avoid this, but eh
return named
if getattr(named.array, "batch_dim", None) is not None:
Member

i hate this. i need a minimum reproducer so i can make this go away

Member

can't we revert this

Contributor Author

yes done

weight_decay: float = 0.0
warmup: float = 0.03
cooldown: float | None = None
lr_schedule: str = "linear"
Member

is this what people do for dpo

Contributor Author

linear yes; I'm not sure about this particular warmup value, but it's probably good to keep it close to train_lm

Contributor Author

removed warmup


Note that trainer.id and the RUN_ID env variable take precedence, in that order.
"""
allow_out_of_region: tuple[str, ...] = ()
Member

let's not allow this until we really need it

Contributor Author

I was inheriting from TrainLMPod... do we want to get rid of that too? I can do that

Member

sure

dlwh added a commit that referenced this pull request Jan 28, 2026
together with #2463 should avoid a lot of the noisy changes in
#2460/#2462
ahmeda14960 and others added 8 commits January 30, 2026 23:23
- Update train_dpo.py imports to use LmDataConfig instead of SingleDatasetLMConfig
- Migrate to components-based data config structure
- Replace text.py with text/ package structure (from simpo)
- Add preference.py with DPO-specific classes
- Update DPO YAML configs to use components: structure
- Merge validation split functions into single _build_validation_split
- Copy updated Levanter main scripts from simpo (train_lm.py, eval_lm.py, etc.)
- Copy updated marin tokenize files from simpo

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Copy all config files from simpo (updated to components: structure)
- Copy updated source files from simpo (trainer_state.py, optim/, etc.)
- Add EpochDataset class to dataset.py for DPO training
- Update text/__init__.py exports for preference functions
- Add SimPO config files from simpo branch

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Keep preference format handling in datasets.py and formats.py
- Keep DPO exports in text/__init__.py
- Accept main's partitioning.py changes (use axis_names)
- Restore EpochDataset class in dataset.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
… branch

- tokenizer: marin-community/marin-tokenizer
- train_batch_size: 128, num_train_steps: 2150
- learning_rate: 5e-7, lr_schedule: cosine, warmup: 0.1
- beta: 0.01
- Add both train and validation components with proper cache dirs
- Use GCS model paths for reference_model_path and initialize_from_hf
- validation_split_fraction: null (uses separate validation component)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@ahmeda14960 ahmeda14960 requested a review from dlwh February 2, 2026 07:13
@@ -0,0 +1,96 @@
# Copyright 2025 The Marin Authors
Member

in keeping with new policy i'm not gonna review this unless you want me to

Member

that said, can you $AGENT up a doc that explains how to run dpo in the Levanter and Marin settings (probably two docs, one for marin and one for levanter)


)
else:
# Check for preference format (imported lazily to avoid circular imports)
from .preference import PreferenceChatLmDatasetFormat, dataset_for_preference_format
Member

preference datasets should not be part of text datasets. they are a different structure (two sequences instead of one) and so need to be a different type.
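(A minimal sketch of what such a separate preference type could look like; field names and array types are illustrative, not the PR's actual class:)

from dataclasses import dataclass
import jax.numpy as jnp

@dataclass(frozen=True)
class PreferenceExample:
    # a preference example is inherently a pair of sequences, so it gets its own
    # container rather than reusing the single-sequence text example type
    chosen_tokens: jnp.ndarray     # [position] token ids for the preferred completion
    chosen_loss_mask: jnp.ndarray  # [position] 1.0 on completion tokens, 0.0 on prompt
    rejected_tokens: jnp.ndarray
    rejected_loss_mask: jnp.ndarray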

return model.policy if isinstance(model, DpoModel) else model


def _bool_tree_like(tree, value: bool):
Member

in theory any tree prefix should work so you should be able to return value directly but maybe i'm wrong

Contributor Author

TIL thank you! looks a lot cleaner
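(Sketch of the tree-prefix point, assuming Equinox-style filtering, which the repo already uses elsewhere: filter_spec may be a prefix of the pytree, so a bare bool stands in for a whole tree of bools. Illustrative, not the PR's code:)

import equinox as eqx
import jax.numpy as jnp

params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}

# instead of eqx.partition(params, _bool_tree_like(params, True)) ...
trainable, rest = eqx.partition(params, True)  # a bare True already matches every leaf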

*,
beta: float,
) -> tuple[jnp.ndarray, dict[str, Metric]]:
if isinstance(delta_pi, hax.NamedArray) or isinstance(delta_ref, hax.NamedArray):
Member

do we need to make this this defensive

Contributor Author

im just gonna make everything named array

Comment on lines 87 to 89
nll = model.compute_next_token_loss(example, reduction=None, reduction_axis=(), key=key)
Pos = example.tokens.resolve_axis("position")
return -hax.sum(nll, axis=Pos)
Member

actually why aren't we just doing

Suggested change
nll = model.compute_next_token_loss(example, reduction=None, reduction_axis=(), key=key)
Pos = example.tokens.resolve_axis("position")
return -hax.sum(nll, axis=Pos)
nll = model.compute_next_token_loss(example, reduction=hax.sum, reduction_axis="position", key=key)
return -nll

Contributor Author

Done

if cache is None:
raise ValueError(f"No training cache available for component {name}.")

if not isinstance(component.format, PreferenceChatLmDatasetFormat):
Member

we should change the code until this isn't possible (i.e. not a textdataset)

Contributor Author

done

Member

did you?

loss, metrics = dpo_loss_from_logps(delta_pi, delta_ref, beta=config.beta)
chosen_reward = (logp_pi_chosen - logp_ref_chosen) * config.beta
rejected_reward = (logp_pi_rejected - logp_ref_rejected) * config.beta
if isinstance(chosen_reward, hax.NamedArray):
Member

can we standardize on named or unnamed

Contributor Author

standardizing on named
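(A minimal sketch of keeping the reward metrics named end to end, assuming the log-probs are haliax NamedArrays over a "batch" axis and that NamedArray supports elementwise comparison and astype like a jax array. Illustrative only, not the PR's code:)

import haliax as hax
import jax.numpy as jnp

def dpo_reward_metrics(logp_pi_chosen, logp_pi_rejected,
                       logp_ref_chosen, logp_ref_rejected, *, beta: float):
    # implicit DPO rewards, kept as NamedArrays over the batch axis throughout
    chosen_reward = beta * (logp_pi_chosen - logp_ref_chosen)
    rejected_reward = beta * (logp_pi_rejected - logp_ref_rejected)
    margin = chosen_reward - rejected_reward
    return {
        "reward/chosen": hax.mean(chosen_reward, axis="batch"),
        "reward/rejected": hax.mean(rejected_reward, axis="batch"),
        "reward/margin": hax.mean(margin, axis="batch"),
        # fraction of pairs the policy ranks correctly
        "reward/accuracy": hax.mean((margin > 0).astype(jnp.float32), axis="batch"),
    }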

Comment on lines 407 to 408
state = dataclasses.replace(state, model=None)
gc.collect()
Member

this is wrong and not preemption/resume safe. if you load a model from trainer.initial_state you need to stick with it unless step == 0

Contributor Author

sorry not sure why this was added
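(A hypothetical helper capturing the reviewer's point; the step and model field names follow the shape of a levanter-style trainer state and are assumptions here:)

import dataclasses

def maybe_install_fresh_model(state, freshly_loaded_model):
    """Swap in a freshly loaded policy/HF model only on a brand-new run.

    If state.step > 0 the run resumed from a checkpoint, so the checkpointed model
    must be kept; dropping it (model=None) and reloading would discard progress and
    break preemption/resume.
    """
    if int(state.step) == 0:
        return dataclasses.replace(state, model=freshly_loaded_model)
    return state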

MONITOR_SIMPO.md Outdated
Member

put in .agents/ or docs/ please

Contributor Author

deleted

@ahmeda14960 ahmeda14960 requested a review from dlwh February 8, 2026 17:57
Comment on lines +278 to +282
# Check for preference format (imported lazily to avoid circular imports)
from .preference import PreferenceChatLmDatasetFormat, preprocessor_for_preference_format

if isinstance(format, PreferenceChatLmDatasetFormat):
return preprocessor_for_preference_format(format, tokenizer) # type: ignore
Member

remove


@dlwh dlwh left a comment

Codex review (by Codex): merged latest main into this DPO branch and resolved the defaults.py conflict by keeping the DPO defaults wiring while aligning tokenizer vocab-size resolution with current main utilities. I recommend re-running the full CI matrix due to the large main-sync delta.
