Skip to content
Draft

Dpo #2462

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions experiments/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from levanter.compat.hf_checkpoints import load_tokenizer
from levanter.data.text import LmDatasetFormatBase, LMMixtureDatasetConfig, TextLmDatasetFormat
from levanter.eval_harness import LmEvalHarnessConfig
from levanter.main.train_dpo import TrainDpoConfig
from levanter.main.train_lm import TrainLmConfig
from levanter.models.llama import LlamaConfig
from levanter.models.lm_model import LmConfig
Expand All @@ -47,6 +48,7 @@
)
from experiments.llama import compute_num_parameters
from experiments.paloma import paloma_tokenized
from experiments.simple_dpo_config import SimpleDPOConfig
from experiments.simple_sft_config import SimpleSFTConfig
from experiments.simple_train_config import SimpleTrainConfig
from marin.download.huggingface.download_hf import DownloadConfig, download_hf
Expand All @@ -69,7 +71,9 @@
)
from marin.processing.tokenize.tokenize import HfTokenizeConfig, TokenizeConfigBase
from marin.training.training import (
TrainDpoOnPodConfig,
TrainLmOnPodConfig,
run_levanter_train_dpo,
run_levanter_train_lm,
)

Expand Down Expand Up @@ -513,6 +517,148 @@ def default_sft(
)


def default_dpo(
    name: str,
    tokenized: InputName | ExecutorStep | LMMixtureDatasetConfig,
    model_config: LlamaConfig,
    dpo_config: SimpleDPOConfig,
    tags: Sequence[str] = (),
    override_output_path: str | None = None,
) -> ExecutorStep:
    """
    Creates an ExecutorStep for DPO fine-tuning.

    Args:
        name: The name of the training run, forms the basis of the output path.
            Names longer than 64 characters are truncated to satisfy WandB limits.
        tokenized: The tokenized preference data to train on.
        model_config: Levanter LlamaConfig for the model architecture to train.
        dpo_config: Configuration for the DPO training process.
        tags: Additional tags for WandB logging. Default: (). A "dpo" tag is
            always appended if not already present.
        override_output_path: Optional override for executor output path.

    Returns:
        An ExecutorStep that runs `run_levanter_train_dpo` on a pod.

    Raises:
        ValueError: if the initialize_from_hf flag is inconsistent with the
            provided model/checkpoint paths, if train_seq_len exceeds the
            model's max_seq_len, or if no reference model path can be derived.
    """
    # Ensure every DPO run is discoverable under the "dpo" tag in WandB.
    if "dpo" not in tags:
        tags = [*tags, "dpo"]

    initialize_from_hf = dpo_config.initialize_from_hf

    # When not explicitly set, initialize from an HF model iff an HF model path
    # is given and no Levanter checkpoint path is given. When explicitly set,
    # validate that the corresponding path is actually provided.
    if initialize_from_hf is None:
        initialize_from_hf = (
            dpo_config.model_name_or_path is not None and dpo_config.initialize_from_checkpoint_path is None
        )
    elif initialize_from_hf is True and dpo_config.model_name_or_path is None:
        raise ValueError("initialize_from_hf is True but model_name_or_path is not set")
    elif initialize_from_hf is False and dpo_config.initialize_from_checkpoint_path is None:
        raise ValueError("initialize_from_hf is False but initialize_from_checkpoint_path is not set")

    # NOTE(review): the variable is named "pretraining_data" but it holds the
    # tokenized *preference* data — consider renaming for clarity.
    pretraining_data = _prepare_data_config(tokenized, use_default_validation=False)
    # Use Feistel-based permutation for deterministic, seekable shuffling.
    pretraining_data = dataclasses.replace(pretraining_data, permutation_type="feistel")
    vocab_size = _get_vocab_size(pretraining_data)

    # WandB run names are limited to 64 characters. Truncate while trying to
    # preserve the (usually distinguishing) suffix after the last "-".
    if len(name) > 64:
        old_name = name
        if "-" not in name:
            name = name[:64]
        else:
            prefix, suffix = name.rsplit("-", 1)
            if len(suffix) >= 64:
                # Suffix alone already fills the budget; drop the prefix.
                suffix = suffix[:64]
                name = suffix
            else:
                # Keep the full suffix; trim the prefix so total length is 64.
                name = prefix[: 63 - len(suffix)] + "-" + suffix
        logger.warning(f"Truncated name from {old_name} to {name} to fit within WANDB limits.")

    # HF export cadence: None => mirror the checkpoint cadence; -1 => disable
    # HF exports entirely; any other value is used as-is.
    steps_per_export = dpo_config.steps_per_checkpoint
    if dpo_config.steps_per_hf_export is None:
        steps_per_export_hf = steps_per_export
    elif dpo_config.steps_per_hf_export == -1:
        steps_per_export_hf = None
    else:
        steps_per_export_hf = dpo_config.steps_per_hf_export

    # Resolve the training sequence length, defaulting to the model's maximum.
    actual_model_config = unwrap_versioned_value(model_config)
    train_length = dpo_config.train_seq_len or actual_model_config.max_seq_len
    if train_length > actual_model_config.max_seq_len:
        raise ValueError(f"train_length {train_length} exceeds model max_seq_len {actual_model_config.max_seq_len}.")

    # Total number of examples consumed over the whole run (handles batch-size
    # schedules, not just a constant batch size).
    schedule = BatchSchedule(unwrap_versioned_value(dpo_config.train_batch_size))
    total_examples = schedule.global_data_offset_by_step(dpo_config.num_train_steps)

    # The frozen DPO reference model defaults to the initial policy model.
    reference_model_path = dpo_config.reference_model_path or dpo_config.model_name_or_path
    if reference_model_path is None:
        raise ValueError("reference_model_path must be set for DPO training.")

    inner_config = TrainDpoConfig(
        data=pretraining_data,
        trainer=TrainerConfig(
            tracker=WandbConfig(
                project=dpo_config.wandb_project or "marin",
                tags=[*tags],
            ),
            # Params in f32, compute in bf16 (mixed precision).
            mp=jmp.get_policy("p=f32,c=bfloat16"),
            train_batch_size=dpo_config.train_batch_size,
            num_train_steps=dpo_config.num_train_steps,
            steps_per_eval=dpo_config.steps_per_eval,
            checkpointer=CheckpointerConfig(
                # Save temporary checkpoints every 10 minutes, but only keep
                # one per `steps_per_export` steps permanently.
                save_interval=timedelta(minutes=10),
                keep=[dict(every=steps_per_export)],
            ),
            model_averaging=None,
            mesh=MeshConfig(
                # Map both token axes across all data-parallel mesh axes.
                compute_mapping={
                    "token": (ResourceAxis.REPLICA_DCN, ResourceAxis.REPLICA, ResourceAxis.DATA),
                    "token_repeat": (ResourceAxis.REPLICA_DCN, ResourceAxis.REPLICA, ResourceAxis.DATA),
                }
            ),
            allow_partial_checkpoint=dpo_config.allow_partial_checkpoint,
            allow_nondivisible_batch_size=True,
            quantization=QuantizationConfig(int8=dpo_config.int8) if dpo_config.int8 else None,
            initialize_from=None,
        ),
        initialize_from_checkpoint_path=dpo_config.initialize_from_checkpoint_path,
        # Levanter expects either the HF model path or False here.
        initialize_from_hf=dpo_config.model_name_or_path if initialize_from_hf else False,
        train_seq_len=train_length,
        model=model_config,
        optimizer=AdamConfig(
            learning_rate=dpo_config.learning_rate,
            weight_decay=dpo_config.weight_decay,
            warmup=dpo_config.warmup,
            decay=dpo_config.cooldown,
            lr_schedule=dpo_config.lr_schedule,
            min_lr_ratio=dpo_config.min_lr_ratio,
            max_grad_norm=dpo_config.max_grad_norm,
        ),
        reference_model_path=reference_model_path,
        reference_is_hf=dpo_config.reference_is_hf,
        beta=dpo_config.beta,
        validation_split_fraction=dpo_config.validation_split_fraction,
        hf_save_steps=steps_per_export_hf,
        hf_save_dtype=dpo_config.hf_save_dtype,
        data_seed=dpo_config.seed,
    )

    config = TrainDpoOnPodConfig(
        train_config=inner_config,
        resources=dpo_config.resources,
        output_path=this_output_path(),
    )

    # Unwrap for parameter counting in the description below.
    model_config = unwrap_versioned_value(model_config)

    return ExecutorStep(
        name=os.path.join("checkpoints", name),
        description=(
            # NOTE(review): this token estimate assumes one train_length
            # sequence per example; DPO examples are chosen/rejected pairs,
            # which may double the real token count — confirm.
            f"Train a {compute_num_parameters(model_config, vocab_size):,} parameter model for "
            f"{dpo_config.num_train_steps} (steps) * "
            f"{dpo_config.train_batch_size} (batch_size) * "
            f"{train_length} (train_seq_len) "
            f"= {total_examples * train_length} tokens."
        ),
        fn=run_levanter_train_dpo,
        config=config,
        override_output_path=override_output_path,
    )


@lru_cache
def _cached_load_tokenizer(tokenizer_name: str):
    """Load a tokenizer by name, memoizing so repeated lookups are free."""
    return load_tokenizer(tokenizer_name)
Expand Down
95 changes: 95 additions & 0 deletions experiments/exp2101_dpo_ultrafeedback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2025 The Marin Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Run DPO on the Ultrafeedback preference dataset using Marin's executor framework.
"""

from levanter.data.text import PreferenceChatLmDatasetFormat

from experiments.defaults import default_dpo, default_tokenize
from experiments.llama import llama_8b
from experiments.marin_models import marin_tokenizer
from experiments.posttrain.preference_datasets import get_preference_dataset
from experiments.simple_dpo_config import SimpleDPOConfig
from fray.cluster import ResourceConfig
from marin.execution.executor import executor_main
from marin.processing.tokenize import lm_data_config

# Preference dataset and the HF-format Llama 3.1 8B base model used both to
# initialize the policy and as the frozen DPO reference model.
DATASET_NAME = "HuggingFaceH4/ultrafeedback_binarized"
LLAMA3_8B_HF_PATH = "gs://marin-us-central1/gcsfuse_mount/models/meta-llama--Llama-3-1-8B--main"

# Download both preference splits; they are kept separate downstream so that
# test_prefs is only ever used for validation.
preference_dataset = get_preference_dataset(DATASET_NAME, splits=["train_prefs", "test_prefs"])

tokenized_train_preferences = default_tokenize(
    name="ultrafeedback_binarized_train_prefs_marin_tokenizer",
    dataset=preference_dataset / "train_prefs/*.jsonl.gz",
    tokenizer=marin_tokenizer,
    format=PreferenceChatLmDatasetFormat(),
)

tokenized_test_preferences = default_tokenize(
    name="ultrafeedback_binarized_test_prefs_marin_tokenizer",
    dataset=preference_dataset / "test_prefs/*.jsonl.gz",
    tokenizer=marin_tokenizer,
    format=PreferenceChatLmDatasetFormat(),
    is_validation=True,
)

# Combine the train split with a named held-out validation set.
tokenized_preferences = lm_data_config(
    training_set=tokenized_train_preferences,
    validation_sets={"ultrafeedback_test_prefs": tokenized_test_preferences},
)

dpo_config = SimpleDPOConfig(
    resources=ResourceConfig.with_tpu("v5p-16"),
    train_batch_size=128,
    num_train_steps=2150,
    learning_rate=5e-7,
    lr_schedule="cosine",
    warmup=0.1,
    cooldown=None,
    wandb_project="dpo",
    tokenizer=marin_tokenizer,
    model_name_or_path=LLAMA3_8B_HF_PATH,
    # Reference model is the same frozen base model the policy starts from.
    reference_model_path=LLAMA3_8B_HF_PATH,
    reference_is_hf=True,
    train_seq_len=4096,
    max_seq_len=4096,
    beta=0.01,
    # No automatic split: validation comes from the explicit test_prefs set.
    validation_split_fraction=None,
    steps_per_eval=200,
    steps_per_checkpoint=1000,
    steps_per_hf_export=1000,
    seed=0,
)

training_step = default_dpo(
    name="dpo/ultrafeedback_llama3_8b",
    tokenized=tokenized_preferences,
    model_config=llama_8b,
    dpo_config=dpo_config,
    # NOTE(review): the "simpo" tag looks copied from a SimPO experiment —
    # confirm it is intended on this DPO run.
    tags=["ultrafeedback", "llama3", "simpo"],
)


if __name__ == "__main__":
    executor_main(
        steps=[
            preference_dataset,
            tokenized_train_preferences,
            tokenized_test_preferences,
            training_step,
        ]
    )
4 changes: 2 additions & 2 deletions experiments/posttrain/preference_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

Current datasets:
1. HuggingFaceH4/ultrafeedback_binarized
(only train_prefs split included to avoid accidentally training on test_prefs)
(train_prefs and test_prefs splits included; keep them separate in downstream training)
2. allenai/olmo-2-1124-7b-preference-mix
"""

Expand Down Expand Up @@ -80,7 +80,7 @@ class PreferenceDatasetConfig:
wait_for_completion=True,
metadata_columns=["prompt", "score_chosen", "score_rejected"],
filetype="parquet",
splits=["train_prefs"],
splits=["train_prefs", "test_prefs"],
),
"allenai/olmo-2-1124-7b-preference-mix": PreferenceDatasetConfig(
hf_dataset_id="allenai/olmo-2-1124-7b-preference-mix",
Expand Down
62 changes: 62 additions & 0 deletions experiments/simple_dpo_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2025 The Marin Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from fray.cluster import ResourceConfig
from levanter.schedule import IntSchedule


@dataclass(frozen=True)
class SimpleDPOConfig:
    """
    A simplified configuration for Direct Preference Optimization (DPO).

    Consumed by `default_dpo`, which expands it into Levanter's full
    TrainDpoConfig. Field defaults here are the experiment-facing defaults.
    """

    # Compute resources (e.g. TPU slice) to run the training job on.
    resources: ResourceConfig

    # Global batch size, or a schedule of batch sizes over training.
    train_batch_size: int | IntSchedule = 128
    num_train_steps: int = 10000
    learning_rate: float = 1e-6
    # WandB project to log to; falls back to a default when None.
    wandb_project: str | None = None

    tokenizer: str | None = None
    # HF model to initialize from (and the default reference model).
    model_name_or_path: str | None = None
    # Levanter checkpoint to resume/initialize from, if not starting from HF.
    initialize_from_checkpoint_path: str | None = None

    # Frozen reference model for the DPO loss; defaults to model_name_or_path.
    reference_model_path: str | None = None
    reference_is_hf: bool = True
    # DPO temperature: larger beta penalizes divergence from the reference more.
    beta: float = 0.1
    # Fraction of training data held out for validation; None disables the split.
    validation_split_fraction: float | None = 0.1

    # Sequence length used for training; None means use the model's max_seq_len.
    train_seq_len: int | None = None
    max_seq_len: int = 4096

    # Optimizer / LR schedule settings (Adam).
    weight_decay: float = 0.0
    warmup: float = 0.03
    cooldown: float | None = None
    lr_schedule: str = "linear"
    min_lr_ratio: float = 0.0
    max_grad_norm: float | None = None

    steps_per_eval: int = 1000
    steps_per_checkpoint: int = 1000
    # HF export cadence; -1 disables HF exports, None mirrors checkpointing.
    steps_per_hf_export: int = 500
    hf_save_dtype: str | None = None

    seed: int = 0
    # Tri-state: None => infer from which paths are set; True/False => explicit.
    initialize_from_hf: bool | None = None

    allow_partial_checkpoint: bool = False
    # Enable int8 quantized training when True.
    int8: bool = False
1 change: 1 addition & 0 deletions lib/haliax/src/haliax/nn/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ def init(
@functools.wraps(module)
def fn(*args, **kwargs):
    # Initialize all Block layers at once by vmapping the module's init,
    # then shard the stacked parameters before wrapping them so the full
    # replicated stack never has to live on a single device.
    stacked = haliax.vmap(module.init, Block)(*args, **kwargs)
    stacked = haliax.auto_sharded(stacked)
    return Stacked(stacked, Block, gradient_checkpointing)

return fn
Expand Down
7 changes: 7 additions & 0 deletions lib/haliax/src/haliax/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@ def _do_device_put(named):
# this happens when we filter out params for things like lora.
# could use eqx.partition to avoid this, but eh
return named
if getattr(named.array, "batch_dim", None) is not None:
# Batched tracers from vmap can't be safely device_put with axis-mapped sharding
# because the leading batch axis isn't represented in the NamedArray axes.
# We'll shard after vmap adds the axis.
return named

pspec = pspec_for(named, mapping)
assert isinstance(pspec, PartitionSpec)
Expand Down Expand Up @@ -290,6 +295,8 @@ def pspec_for(

def partition_spec(node: typing.Any):
if isinstance(node, NamedArray):
if not is_jax_array_like(node.array):
return None
return pspec_for_axis(node.axes, resource_mapping)
elif isinstance(node, eqx.Module):
# handle eqx.Module explicitly so that we can look at axis_names metadata
Expand Down
Loading
Loading