From 3257426950bf10997e0188c39e694519dcc5b9ff Mon Sep 17 00:00:00 2001 From: ellenxtan Date: Sat, 8 Nov 2025 01:34:44 +0000 Subject: [PATCH 01/12] finish config.py --- recipes/lm/poft/README.md | 0 recipes/lm/poft/__init__.py | 0 recipes/lm/poft/__main__.py | 15 ++ recipes/lm/poft/config.py | 141 ++++++++++++++ recipes/lm/poft/dataset.py | 368 ++++++++++++++++++++++++++++++++++++ recipes/lm/poft/recipe.py | 0 recipes/lm/poft/utils.py | 0 7 files changed, 524 insertions(+) create mode 100644 recipes/lm/poft/README.md create mode 100644 recipes/lm/poft/__init__.py create mode 100644 recipes/lm/poft/__main__.py create mode 100644 recipes/lm/poft/config.py create mode 100644 recipes/lm/poft/dataset.py create mode 100644 recipes/lm/poft/recipe.py create mode 100644 recipes/lm/poft/utils.py diff --git a/recipes/lm/poft/README.md b/recipes/lm/poft/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/recipes/lm/poft/__init__.py b/recipes/lm/poft/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/recipes/lm/poft/__main__.py b/recipes/lm/poft/__main__.py new file mode 100644 index 000000000..18c651a87 --- /dev/null +++ b/recipes/lm/poft/__main__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.recipe.cli import train_main + +from .recipe import LMDPORecipe + +recipe = LMDPORecipe() + +train_main(recipe) diff --git a/recipes/lm/poft/config.py b/recipes/lm/poft/config.py new file mode 100644 index 000000000..74ce4c407 --- /dev/null +++ b/recipes/lm/poft/config.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass, field + +from fairseq2.recipe.config import ( + ADAMW_OPTIMIZER, + COSINE_ANNEALING_LR, + AdamWConfig, + CommonSection, + CosineAnnealingLRConfig, + DatasetSection, + GangSection, + LRSchedulerSection, + ModelSection, + OptimizerSection, + RegimeSection, + TokenizerSection, + TorchConfig, + TrainerSection, + ActivationCheckpointingConfig +) + +from .dataset import LM_POFT_DATASET, LMPOFTDatasetConfig + + +@dataclass(kw_only=True) +class LMPOFTDatasetSection(DatasetSection): + path: str | None = None + + source_encode_mode: str = "prompt" + """The encode mode for the prompt, determines what special tokens to add.""" + + target_encode_mode: str = "prompt_response" + """The encode mode for the target, determines what special tokens to add.""" + + mask_source_tokens: bool = True + """If ``False``, calculates loss on the `src` tokens as well as the `tgt` tokens.""" + + min_seq_len: int = 1 + """The minimum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``. + Shorter sequences will be dropped.""" + + max_seq_len: int = 8192 + """The maximum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``. + Longer sequences will be dropped.""" + + max_num_tokens: int = 8192 * 2 + """The maximum number of total `src`, `tgt_chosen`, and `tgt_rejected` tokens per batch.""" + + batch_size: int | None = None + """If not ``None``, ignores `max_num_tokens` and each batch will have `batch_size` examples.""" + + example_shuffle_window: int = 10_000 + """The size of the sliding window for shuffling examples.""" + + batch_shuffle_window: int = 1_000 + """The size of the sliding window for shuffling batches.""" + + num_prefetch: int = 4 + """The number of batches to prefetch in background.""" + + extras: dict[str, object] = field(default_factory=dict) + """The dataset-specific extra options.""" + + chat_mode: bool = False + """If True, dataset jsonl must have 'chat' field with openai-like messages List[Dict] entries""" + + +@dataclass(kw_only=True) +class LMPOFTConfig: + model: ModelSection = field( + default_factory=lambda: ModelSection( + family="llama", + name="llama3_1_8b_instruct", + compile=False, + ) + ) + + dataset: LMPOFTDatasetSection = field( + default_factory=lambda: LMPOFTDatasetSection( + family=LM_POFT_DATASET, + batch_size=16, + config_overrides=LMPOFTDatasetConfig(path="hg://facebook/fairseq2-lm-gsm8k"), + ) + ) + + tokenizer: TokenizerSection = field( + default_factory=lambda: TokenizerSection( + family="llama", + name="llama3_instruct", + ) + ) + + gang: GangSection = field(default_factory=lambda: GangSection()) + + trainer: TrainerSection = field( + default_factory=lambda: TrainerSection( + data_parallelism="fsdp", max_grad_norm=1.0, + activation_checkpointing=ActivationCheckpointingConfig(mode="layerwise"), + ) + ) + + optimizer: OptimizerSection = field( + default_factory=lambda: OptimizerSection( + name=ADAMW_OPTIMIZER, + config=AdamWConfig( + lr=5.5e-06, betas=(0.9, 0.95), weight_decay=0.1, impl="fused" + ), + ) + ) + + lr_scheduler: LRSchedulerSection | None = field( + default_factory=lambda: LRSchedulerSection( + name=COSINE_ANNEALING_LR, config=CosineAnnealingLRConfig(final_lr_scale=0.2) + ) + ) + + regime: RegimeSection = field( + default_factory=lambda: RegimeSection( + num_steps=5000, + validate_every_n_steps=100, + checkpoint_every_n_steps=1000, + keep_last_n_checkpoints=1, + publish_metrics_every_n_steps=10, + export_hugging_face=True, + ) + ) + + # The memory efficient SDPA implementation in PyTorch is numerically not + # stable when used with padded inputs. + common: CommonSection = field( + default_factory=lambda: CommonSection( + torch=TorchConfig(default_sdpa="torch_math") + ) + ) diff --git a/recipes/lm/poft/dataset.py b/recipes/lm/poft/dataset.py new file mode 100644 index 000000000..3924f77c8 --- /dev/null +++ b/recipes/lm/poft/dataset.py @@ -0,0 +1,368 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import json +from collections.abc import MutableMapping, Sequence +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Final, cast, final + +import torch + +from fairseq2.assets import get_asset_download_manager +from fairseq2.data import ( + CollateOptionsOverride, + Collater, + DataPipelineBuilder, + SequenceData, + create_bucket_sizes, +) +from fairseq2.data.data_pipeline import DataPipeline, read_sequence +from fairseq2.data.text import read_text +from fairseq2.data.tokenizers import Tokenizer +from fairseq2.data.tokenizers.hg import HuggingFaceTokenEncoder +from fairseq2.datasets import DataPipelineReader, SequenceBatch, SyncMode +from fairseq2.error import NotSupportedError +from fairseq2.gang import Gangs +from fairseq2.utils.uri import Uri + +from .utils import ( + Batching, + DatasetLoadError, + LengthBatching, + StaticBatching, + load_files_and_weights, +) + +LM_POFT_DATASET: Final = "lm_poft" + + +@dataclass(kw_only=True) +class DataReadOptions: + batching: Batching = field(default_factory=lambda: StaticBatching(1)) + """The batching strategy for returned examples.""" + + example_shuffle_window: int = 0 + """ + The size of the sliding window for shuffling examples. If ``1``, no + shuffling is performed; if ``0``, true shuffling is performed by loading the + entire dataset. + """ + + batch_shuffle_window: int = 0 + """ + The size of the sliding window for shuffling batches. If ``1``, no + shuffling is performed; if ``0``, true shuffling is performed by loading the + entire dataset. + """ + + drop_remainder: bool = False + """ + If ``True``, drops the last set of batches if they have in total fewer + examples than requested. + """ + + sync_batches: bool = True + """ + If ``True``, ensures that each process in the gang reads the same number of + batches. Typically used when the amount of data to be read can vary per + process (e.g. due to unbalanced sharding or non-static batching) and it is + critical for each process to iterate over the same number of batches (e.g. + during training). + """ + + sync_mode: SyncMode = SyncMode.UNTIL_FIRST + """ + The data synchronization mode among processes in the gang. Only effective if + :attr:`sync_batches` is ``True``. + """ + + max_num_batches: int | None = None + """The maximum number of batches to return.""" + + num_accumulate: int = 1 + """ + The number of batches to accumulate in each iteration. Typically used with + gradient accumulation during training. + """ + + prefetch: int = 1 + """The number of batches to prefetch in background.""" + + npc: int = 10 + """The reference number of parallel calls that data reader can do.""" + + seed: int = 2 + """The seed to initialize the random number generators used internally.""" + + extras: MutableMapping[str, object] = field(default_factory=dict) + """The reader-specific extra options.""" + + sample: bool = False + """ + If ``True``, instruction sources (e.g. JSONL files) will be sampled in + proportion to their weights. + """ + + source_encode_mode: str = "prompt" + """The tokenizer mode to encode the source text.""" + + target_encode_mode: str = "prompt_response" + """The tokenizer mode to encode the target text.""" + + chat_mode: bool = False + + +@final +class LMSFTDataset: + _name: str + _splits: dict[str, tuple[Sequence[Path], Sequence[float]]] + + def __init__( + self, name: str, splits: dict[str, tuple[Sequence[Path], Sequence[float]]] + ) -> None: + """ + :param files: + The instruction files. + :param weights: + The weight of each file in ``files``. + """ + self._name = name + + for split, (files, weights) in splits.items(): + if len(files) != len(weights): + raise ValueError( + f"The lengths of the file and weight lists of the '{split}' split must match, but they are {len(files)} and {len(weights)} instead." + ) + + self._splits = splits + + def _read_jsonl(self, path: Path, tokenizer: Tokenizer) -> DataPipelineBuilder: + lines = [] + + # TODO(balioglu): Do in C++. + with path.open(encoding="utf-8") as fp: + for line in fp: + lines.append(line) + + return read_sequence(lines).map(json.loads) + + def create_reader( + self, + split: str, + tokenizer: Tokenizer, + gangs: Gangs, + min_seq_len: int, + max_seq_len: int, + options: DataReadOptions | None = None, + ) -> DataPipelineReader[SequenceBatch]: + + files_weights = self._splits.get(split) + if files_weights is None: + raise ValueError(f"files_weights for split '{split}' is None") + files, weights = files_weights + + if options is None: + options = DataReadOptions() + + seed = options.seed + + builder = read_sequence(files) + + def read_file(file: Path) -> DataPipeline: + return read_text(file).map(json.loads, num_parallel_calls=1).and_return() + + builder.yield_from(read_file) + + # Shuffle files. Must be consistent across all processes. + if options.example_shuffle_window != 1: + builder.shuffle(options.example_shuffle_window, seed=seed) + + seed += 1 + + # Shard. + builder.shard(gangs.dp.rank, gangs.dp.size, allow_uneven=True) + + seed += gangs.dp.rank + + if options.chat_mode is True: + # not passing any encoding modes here, because we use apply_chat_template here + encoder = tokenizer.create_encoder() + if not isinstance(encoder, HuggingFaceTokenEncoder): + raise RuntimeError( + "Huggingface tokenizer must be used when chat_mode is True" + ) + else: + + def encoding_chat(example: dict[str, Any]) -> dict[str, Any]: + id_ = example.get("id", None) + chat = example.get("chat", None) + + if not chat: + chat = [ + {"role": "user", "content": example.get("src")}, + {"role": "assistant", "content": example.get("tgt")}, + ] + + encoded_output = encoder.apply_chat_template( + chat, + return_dict=True, + return_assistant_tokens_mask=True, + return_tensors="pt", + ) + + indices = encoded_output["input_ids"][0] + target_mask = encoded_output["assistant_masks"][0].bool() + + return {"id": id_, "indices": indices, "target_mask": target_mask} + + builder.map(encoding_chat) + + else: + + # Encode source and target texts. + source_encoder = tokenizer.create_encoder(mode=options.source_encode_mode) + target_encoder = tokenizer.create_encoder(mode=options.target_encode_mode) + + builder.map(source_encoder, selector="src") + builder.map(target_encoder, selector="tgt") + + def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]: + id_ = example.get("id") + + source_indices = example["src"] + target_indices = example["tgt"] + + indices = torch.cat([source_indices, target_indices]) + + target_mask = torch.arange(len(indices)) >= len(source_indices) + + return {"id": id_, "indices": indices, "target_mask": target_mask} + + builder.map(cat_source_and_target) + + batching = options.batching + + if isinstance(batching, LengthBatching): + bucket_sizes = create_bucket_sizes( + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + max_num_elements=batching.max_num_elements, + ) + + # Bucket by the sequence length. + builder.bucket_by_length( + bucket_sizes, + selector="indices", + min_data_len=min_seq_len, + skip_above_max_examples=True, + drop_remainder=options.drop_remainder, + ) + elif isinstance(batching, StaticBatching): + # Filter out long examples. + def skip(example: dict[str, Any]) -> bool: + seq_len = len(example["indices"]) + + return seq_len >= min_seq_len and seq_len <= max_seq_len + + builder.filter(skip) + + # Bucket `batch_size` examples. + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) + else: + raise NotSupportedError(f"`{batching}` is not supported.") + + # Shuffle buckets. + if options.batch_shuffle_window != 1: + builder.shuffle(options.batch_shuffle_window, seed=seed) + + seed += 1 + + # Collate bucketed examples into a batch. + target_mask_collate_opts = CollateOptionsOverride( + "target_mask", pad_value=False + ) + + if tokenizer.vocab_info.pad_idx is None: + raise RuntimeError( + "LMSFTDataset requires pad token to work for batching purposes, check your tokenizer config." + ) + + collater = Collater( + pad_value=tokenizer.vocab_info.pad_idx, overrides=[target_mask_collate_opts] + ) + + builder.map(collater, num_parallel_calls=options.npc) + + # Return only the first `max_num_batches`. + if options.max_num_batches is not None: + builder.take(options.max_num_batches) + + # Prefetch `prefetch` batches in background. + builder.prefetch(options.prefetch) + + # Wrap examples with `SequenceBatch`. + def to_batch(example: dict[str, Any]) -> SequenceBatch: + indices = cast(SequenceData, example["indices"]) + + seqs, seq_lens = indices["seqs"], indices["seq_lens"] + target_mask = example["target_mask"]["seqs"] + + return SequenceBatch( + seqs, seq_lens, target_mask=target_mask, example=example + ) + + pipeline = builder.map(to_batch).and_return() + + return DataPipelineReader[SequenceBatch]( + pipeline, + gangs, + num_accumulate=options.num_accumulate, + drop_remainder=options.drop_remainder, + sync=options.sync_batches, + sync_mode=options.sync_mode, + ) + + +@dataclass +class LMSFTDatasetConfig: + path: str | None = None + + +def open_lm_sft_dataset(config: LMSFTDatasetConfig) -> LMSFTDataset: + name = "default" # FIXME + splits: dict[str, tuple[Sequence[Path], Sequence[float]]] = {} + + if config.path is None: + raise ValueError("config.path cannot be None") + + uri = Uri.maybe_parse(config.path) + if uri: + path = get_asset_download_manager().download_dataset(uri, config.path) + else: + path = Path(config.path) + + if path.is_dir(): + try: + child_dirs = [p for p in path.iterdir() if p.is_dir()] + except OSError as ex: + raise DatasetLoadError( + name, f"The files under the '{path}' directory of the '{name}' dataset cannot be retrieved. See the nested exception for details." # fmt: skip + ) from ex + + for child_dir in child_dirs: + files, weights = load_files_and_weights(name, child_dir) + + splits[child_dir.name] = (files, weights) + + if not splits: + files, weights = load_files_and_weights(name, path) + + splits["default"] = (files, weights) + + return LMSFTDataset(name, splits) diff --git a/recipes/lm/poft/recipe.py b/recipes/lm/poft/recipe.py new file mode 100644 index 000000000..e69de29bb diff --git a/recipes/lm/poft/utils.py b/recipes/lm/poft/utils.py new file mode 100644 index 000000000..e69de29bb From d73a2c392aa5c506d3e270e74a1c00d4f5db4e2b Mon Sep 17 00:00:00 2001 From: ellenxtan Date: Mon, 10 Nov 2025 23:30:43 +0000 Subject: [PATCH 02/12] update for latest api --- recipes/lm/{poft => dpo}/README.md | 0 recipes/lm/{poft => dpo}/__init__.py | 0 recipes/lm/{poft => dpo}/__main__.py | 0 recipes/lm/{poft => dpo}/config.py | 121 +++++++++++++-------------- recipes/lm/{poft => dpo}/dataset.py | 2 +- recipes/lm/{poft => dpo}/recipe.py | 0 recipes/lm/{poft => dpo}/utils.py | 0 7 files changed, 61 insertions(+), 62 deletions(-) rename recipes/lm/{poft => dpo}/README.md (100%) rename recipes/lm/{poft => dpo}/__init__.py (100%) rename recipes/lm/{poft => dpo}/__main__.py (100%) rename recipes/lm/{poft => dpo}/config.py (85%) rename recipes/lm/{poft => dpo}/dataset.py (99%) rename recipes/lm/{poft => dpo}/recipe.py (100%) rename recipes/lm/{poft => dpo}/utils.py (100%) diff --git a/recipes/lm/poft/README.md b/recipes/lm/dpo/README.md similarity index 100% rename from recipes/lm/poft/README.md rename to recipes/lm/dpo/README.md diff --git a/recipes/lm/poft/__init__.py b/recipes/lm/dpo/__init__.py similarity index 100% rename from recipes/lm/poft/__init__.py rename to recipes/lm/dpo/__init__.py diff --git a/recipes/lm/poft/__main__.py b/recipes/lm/dpo/__main__.py similarity index 100% rename from recipes/lm/poft/__main__.py rename to recipes/lm/dpo/__main__.py diff --git a/recipes/lm/poft/config.py b/recipes/lm/dpo/config.py similarity index 85% rename from recipes/lm/poft/config.py rename to recipes/lm/dpo/config.py index 74ce4c407..3328085a0 100644 --- a/recipes/lm/poft/config.py +++ b/recipes/lm/dpo/config.py @@ -26,74 +26,26 @@ ActivationCheckpointingConfig ) -from .dataset import LM_POFT_DATASET, LMPOFTDatasetConfig +from .dataset import LM_DPO_DATASET, LMDPODatasetConfig, LMDPODataSource @dataclass(kw_only=True) -class LMPOFTDatasetSection(DatasetSection): - path: str | None = None - - source_encode_mode: str = "prompt" - """The encode mode for the prompt, determines what special tokens to add.""" - - target_encode_mode: str = "prompt_response" - """The encode mode for the target, determines what special tokens to add.""" - - mask_source_tokens: bool = True - """If ``False``, calculates loss on the `src` tokens as well as the `tgt` tokens.""" - - min_seq_len: int = 1 - """The minimum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``. - Shorter sequences will be dropped.""" - - max_seq_len: int = 8192 - """The maximum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``. - Longer sequences will be dropped.""" - - max_num_tokens: int = 8192 * 2 - """The maximum number of total `src`, `tgt_chosen`, and `tgt_rejected` tokens per batch.""" - - batch_size: int | None = None - """If not ``None``, ignores `max_num_tokens` and each batch will have `batch_size` examples.""" - - example_shuffle_window: int = 10_000 - """The size of the sliding window for shuffling examples.""" - - batch_shuffle_window: int = 1_000 - """The size of the sliding window for shuffling batches.""" - - num_prefetch: int = 4 - """The number of batches to prefetch in background.""" - - extras: dict[str, object] = field(default_factory=dict) - """The dataset-specific extra options.""" - - chat_mode: bool = False - """If True, dataset jsonl must have 'chat' field with openai-like messages List[Dict] entries""" - - -@dataclass(kw_only=True) -class LMPOFTConfig: +class LMDPOConfig: model: ModelSection = field( - default_factory=lambda: ModelSection( - family="llama", - name="llama3_1_8b_instruct", - compile=False, - ) + default_factory=lambda: ModelSection(name="llama3_1_8b_instruct") ) - - dataset: LMPOFTDatasetSection = field( - default_factory=lambda: LMPOFTDatasetSection( - family=LM_POFT_DATASET, - batch_size=16, - config_overrides=LMPOFTDatasetConfig(path="hg://facebook/fairseq2-lm-gsm8k"), - ) + + tokenizer: TokenizerSection = field( + default_factory=lambda: TokenizerSection(name="llama3_instruct") ) + + - tokenizer: TokenizerSection = field( - default_factory=lambda: TokenizerSection( - family="llama", - name="llama3_instruct", + dataset: LMDPODatasetSection = field( + default_factory=lambda: LMDPODatasetSection( + family=LM_DPO_DATASET, + batch_size=16, + config_overrides=LMDPODatasetConfig(path="hg://facebook/fairseq2-lm-gsm8k"), ) ) @@ -139,3 +91,50 @@ class LMPOFTConfig: torch=TorchConfig(default_sdpa="torch_math") ) ) + + dataset: LMDPODatasetSection = field( + default_factory=lambda: LMDPODatasetSection(family=LM_DPO_DATASET), + ) + + +@dataclass(kw_only=True) +class LMDPODatasetSection(DatasetSection): + path: str | None = None + + source_encode_mode: str = "prompt" + """The encode mode for the prompt, determines what special tokens to add.""" + + target_encode_mode: str = "prompt_response" + """The encode mode for the target, determines what special tokens to add.""" + + mask_source_tokens: bool = True + """If ``False``, calculates loss on the `src` tokens as well as the `tgt` tokens.""" + + min_seq_len: int = 1 + """The minimum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``. + Shorter sequences will be dropped.""" + + max_seq_len: int = 8192 + """The maximum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``. + Longer sequences will be dropped.""" + + max_num_tokens: int = 8192 * 2 + """The maximum number of total `src`, `tgt_chosen`, and `tgt_rejected` tokens per batch.""" + + batch_size: int | None = None + """If not ``None``, ignores `max_num_tokens` and each batch will have `batch_size` examples.""" + + example_shuffle_window: int = 10_000 + """The size of the sliding window for shuffling examples.""" + + batch_shuffle_window: int = 1_000 + """The size of the sliding window for shuffling batches.""" + + num_prefetch: int = 4 + """The number of batches to prefetch in background.""" + + extras: dict[str, object] = field(default_factory=dict) + """The dataset-specific extra options.""" + + chat_mode: bool = False + """If True, dataset jsonl must have 'chat' field with openai-like messages List[Dict] entries""" diff --git a/recipes/lm/poft/dataset.py b/recipes/lm/dpo/dataset.py similarity index 99% rename from recipes/lm/poft/dataset.py rename to recipes/lm/dpo/dataset.py index 3924f77c8..73e648e0e 100644 --- a/recipes/lm/poft/dataset.py +++ b/recipes/lm/dpo/dataset.py @@ -39,7 +39,7 @@ load_files_and_weights, ) -LM_POFT_DATASET: Final = "lm_poft" +LM_DPO_DATASET: Final = "lm_dpo" @dataclass(kw_only=True) diff --git a/recipes/lm/poft/recipe.py b/recipes/lm/dpo/recipe.py similarity index 100% rename from recipes/lm/poft/recipe.py rename to recipes/lm/dpo/recipe.py diff --git a/recipes/lm/poft/utils.py b/recipes/lm/dpo/utils.py similarity index 100% rename from recipes/lm/poft/utils.py rename to recipes/lm/dpo/utils.py From 60840d620c21a651806c23c5f93c2bdf3f63eafa Mon Sep 17 00:00:00 2001 From: ellenxtan Date: Mon, 10 Nov 2025 23:38:25 +0000 Subject: [PATCH 03/12] black config --- recipes/lm/dpo/config.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/recipes/lm/dpo/config.py b/recipes/lm/dpo/config.py index 3328085a0..476ff0536 100644 --- a/recipes/lm/dpo/config.py +++ b/recipes/lm/dpo/config.py @@ -23,10 +23,10 @@ TokenizerSection, TorchConfig, TrainerSection, - ActivationCheckpointingConfig + ActivationCheckpointingConfig, ) -from .dataset import LM_DPO_DATASET, LMDPODatasetConfig, LMDPODataSource +from .dataset import LM_DPO_DATASET @dataclass(kw_only=True) @@ -34,26 +34,21 @@ class LMDPOConfig: model: ModelSection = field( default_factory=lambda: ModelSection(name="llama3_1_8b_instruct") ) - + tokenizer: TokenizerSection = field( default_factory=lambda: TokenizerSection(name="llama3_instruct") ) - - dataset: LMDPODatasetSection = field( - default_factory=lambda: LMDPODatasetSection( - family=LM_DPO_DATASET, - batch_size=16, - config_overrides=LMDPODatasetConfig(path="hg://facebook/fairseq2-lm-gsm8k"), - ) + default_factory=lambda: LMDPODatasetSection(family=LM_DPO_DATASET), ) gang: GangSection = field(default_factory=lambda: GangSection()) trainer: TrainerSection = field( default_factory=lambda: TrainerSection( - data_parallelism="fsdp", max_grad_norm=1.0, + data_parallelism="fsdp", + max_grad_norm=1.0, activation_checkpointing=ActivationCheckpointingConfig(mode="layerwise"), ) ) @@ -91,10 +86,6 @@ class LMDPOConfig: torch=TorchConfig(default_sdpa="torch_math") ) ) - - dataset: LMDPODatasetSection = field( - default_factory=lambda: LMDPODatasetSection(family=LM_DPO_DATASET), - ) @dataclass(kw_only=True) From 734d96cdbeda18afc94b5d255d4384fc9a90c9c9 Mon Sep 17 00:00:00 2001 From: ellenxtan Date: Tue, 11 Nov 2025 00:23:42 +0000 Subject: [PATCH 04/12] rebase DataReadOptions to common so dpo can also use --- recipes/lm/common.py | 82 ++++++++++++++++++++++++++++++++++++++ recipes/lm/sft/dataset.py | 83 +++------------------------------------ recipes/lm/sft/recipe.py | 6 +-- 3 files changed, 90 insertions(+), 81 deletions(-) diff --git a/recipes/lm/common.py b/recipes/lm/common.py index d0198ffc1..f48292e75 100644 --- a/recipes/lm/common.py +++ b/recipes/lm/common.py @@ -6,6 +6,10 @@ from __future__ import annotations +from dataclasses import dataclass, field +from typing import TypeAlias +from fairseq2.datasets import SyncMode + from fairseq2.logging import log from fairseq2.models.clm import CausalLM from fairseq2.nn import Embedding @@ -48,3 +52,81 @@ def _maybe_get_embed(model: CausalLM) -> Embedding | None: return None return embed + + +@dataclass +class StaticBatching: + """Specifies batching where each batch has the same number of examples.""" + + batch_size: int + """The number of examples in each batch.""" + + +@dataclass +class LengthBatching: + """Specifies batching where each batch has a maximum number of elements.""" + + max_num_elements: int + """The maximum number of elements (e.g. tokens) in each batch.""" + + +Batching: TypeAlias = StaticBatching | LengthBatching + + +@dataclass(kw_only=True) +class DataReadOptions: + batching: Batching = field(default_factory=lambda: StaticBatching(1)) + """The batching strategy for returned examples.""" + + example_shuffle_window: int = 0 + """ + The size of the sliding window for shuffling examples. If ``1``, no + shuffling is performed; if ``0``, true shuffling is performed by loading the + entire dataset. + """ + + batch_shuffle_window: int = 0 + """ + The size of the sliding window for shuffling batches. If ``1``, no + shuffling is performed; if ``0``, true shuffling is performed by loading the + entire dataset. + """ + + drop_remainder: bool = False + """ + If ``True``, drops the last set of batches if they have in total fewer + examples than requested. + """ + + sync_batches: bool = True + """ + If ``True``, ensures that each process in the gang reads the same number of + batches. Typically used when the amount of data to be read can vary per + process (e.g. due to unbalanced sharding or non-static batching) and it is + critical for each process to iterate over the same number of batches (e.g. + during training). + """ + + sync_mode: SyncMode = SyncMode.UNTIL_FIRST + """ + The data synchronization mode among processes in the gang. Only effective if + :attr:`sync_batches` is ``True``. + """ + + max_num_batches: int | None = None + """The maximum number of batches to return.""" + + num_accumulate: int = 1 + """ + The number of batches to accumulate in each iteration. Typically used with + gradient accumulation during training. + """ + + prefetch: int = 1 + """The number of batches to prefetch in background.""" + + npc: int = 10 + """The reference number of parallel calls that data reader can do.""" + + seed: int = 2 + """The seed to initialize the random number generators used internally.""" diff --git a/recipes/lm/sft/dataset.py b/recipes/lm/sft/dataset.py index 29e527289..b7ebc965c 100644 --- a/recipes/lm/sft/dataset.py +++ b/recipes/lm/sft/dataset.py @@ -29,86 +29,13 @@ from fairseq2.gang import Gangs from fairseq2.utils.uri import Uri -LM_SFT_DATASET: Final = "lm_sft" - - -@dataclass -class StaticBatching: - """Specifies batching where each batch has the same number of examples.""" - - batch_size: int - """The number of examples in each batch.""" - - -@dataclass -class LengthBatching: - """Specifies batching where each batch has a maximum number of elements.""" - - max_num_elements: int - """The maximum number of elements (e.g. tokens) in each batch.""" +from ..common import DataReadOptions - -Batching: TypeAlias = StaticBatching | LengthBatching +LM_SFT_DATASET: Final = "lm_sft" @dataclass(kw_only=True) -class DataReadOptions: - batching: Batching = field(default_factory=lambda: StaticBatching(1)) - """The batching strategy for returned examples.""" - - example_shuffle_window: int = 0 - """ - The size of the sliding window for shuffling examples. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by loading the - entire dataset. - """ - - batch_shuffle_window: int = 0 - """ - The size of the sliding window for shuffling batches. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by loading the - entire dataset. - """ - - drop_remainder: bool = False - """ - If ``True``, drops the last set of batches if they have in total fewer - examples than requested. - """ - - sync_batches: bool = True - """ - If ``True``, ensures that each process in the gang reads the same number of - batches. Typically used when the amount of data to be read can vary per - process (e.g. due to unbalanced sharding or non-static batching) and it is - critical for each process to iterate over the same number of batches (e.g. - during training). - """ - - sync_mode: SyncMode = SyncMode.UNTIL_FIRST - """ - The data synchronization mode among processes in the gang. Only effective if - :attr:`sync_batches` is ``True``. - """ - - max_num_batches: int | None = None - """The maximum number of batches to return.""" - - num_accumulate: int = 1 - """ - The number of batches to accumulate in each iteration. Typically used with - gradient accumulation during training. - """ - - prefetch: int = 1 - """The number of batches to prefetch in background.""" - - npc: int = 10 - """The reference number of parallel calls that data reader can do.""" - - seed: int = 2 - """The seed to initialize the random number generators used internally.""" - +class LMSFTDataReadOptions(DataReadOptions): sample: bool = False """ If ``True``, instruction sources (e.g. JSONL files) will be sampled in @@ -173,10 +100,10 @@ def create_reader( gangs: Gangs, min_seq_len: int, max_seq_len: int, - options: DataReadOptions | None = None, + options: LMSFTDataReadOptions | None = None, ) -> DataPipelineReader[SequenceBatch]: if options is None: - options = DataReadOptions() + options = LMSFTDataReadOptions() sources = self._sources[split] diff --git a/recipes/lm/sft/recipe.py b/recipes/lm/sft/recipe.py index 1e9781251..db92dc17c 100644 --- a/recipes/lm/sft/recipe.py +++ b/recipes/lm/sft/recipe.py @@ -33,7 +33,7 @@ from .dataset import ( LM_SFT_DATASET, Batching, - DataReadOptions, + LMSFTDataReadOptions, LengthBatching, LMSFTDataset, LMSFTDatasetConfig, @@ -73,7 +73,7 @@ def create_task(self, context: RecipeContext) -> Task: else: batching = LengthBatching(config.dataset.max_num_tokens) - read_options = DataReadOptions( + read_options = LMSFTDataReadOptions( batching=batching, example_shuffle_window=config.dataset.example_shuffle_window, batch_shuffle_window=config.dataset.batch_shuffle_window, @@ -107,7 +107,7 @@ def create_task(self, context: RecipeContext) -> Task: valid_batching = LengthBatching(max_num_tokens) - read_options = DataReadOptions( + read_options = LMSFTDataReadOptions( batching=valid_batching, prefetch=config.dataset.prefetch, source_encode_mode=config.dataset.source_encode_mode, From 17c2b80bece958d3cc06514a599780c57256abaa Mon Sep 17 00:00:00 2001 From: ellenxtan Date: Tue, 11 Nov 2025 00:31:59 +0000 Subject: [PATCH 05/12] rebase DataReadOptions cont' --- recipes/lm/sft/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/lm/sft/dataset.py b/recipes/lm/sft/dataset.py index b7ebc965c..065070d4b 100644 --- a/recipes/lm/sft/dataset.py +++ b/recipes/lm/sft/dataset.py @@ -29,7 +29,7 @@ from fairseq2.gang import Gangs from fairseq2.utils.uri import Uri -from ..common import DataReadOptions +from ..common import DataReadOptions, LengthBatching, StaticBatching LM_SFT_DATASET: Final = "lm_sft" From 87fffee2b528ec275c93595d58949124aba2b6a2 Mon Sep 17 00:00:00 2001 From: ellenxtan Date: Tue, 11 Nov 2025 01:24:36 +0000 Subject: [PATCH 06/12] dpo dataset draft --- recipes/lm/dpo/dataset.py | 439 +++++++++++++++++++++----------------- 1 file changed, 244 insertions(+), 195 deletions(-) diff --git a/recipes/lm/dpo/dataset.py b/recipes/lm/dpo/dataset.py index 73e648e0e..09353bc0c 100644 --- a/recipes/lm/dpo/dataset.py +++ b/recipes/lm/dpo/dataset.py @@ -7,18 +7,17 @@ from __future__ import annotations import json -from collections.abc import MutableMapping, Sequence from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Final, cast, final +from typing import Any, Final, TypeAlias, cast import torch +from typing_extensions import override from fairseq2.assets import get_asset_download_manager from fairseq2.data import ( CollateOptionsOverride, Collater, - DataPipelineBuilder, SequenceData, create_bucket_sizes, ) @@ -26,89 +25,63 @@ from fairseq2.data.text import read_text from fairseq2.data.tokenizers import Tokenizer from fairseq2.data.tokenizers.hg import HuggingFaceTokenEncoder -from fairseq2.datasets import DataPipelineReader, SequenceBatch, SyncMode -from fairseq2.error import NotSupportedError +from fairseq2.datasets import DataPipelineReader, SequenceBatch +from fairseq2.device import Device, SupportsDeviceTransfer +from fairseq2.error import NotSupportedError, raise_operational_system_error from fairseq2.gang import Gangs from fairseq2.utils.uri import Uri -from .utils import ( - Batching, - DatasetLoadError, - LengthBatching, - StaticBatching, - load_files_and_weights, -) - -LM_DPO_DATASET: Final = "lm_dpo" +from ..common import DataReadOptions, LengthBatching, StaticBatching -@dataclass(kw_only=True) -class DataReadOptions: - batching: Batching = field(default_factory=lambda: StaticBatching(1)) - """The batching strategy for returned examples.""" - - example_shuffle_window: int = 0 - """ - The size of the sliding window for shuffling examples. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by loading the - entire dataset. - """ +LM_DPO_DATASET: Final = "lm_dpo" - batch_shuffle_window: int = 0 - """ - The size of the sliding window for shuffling batches. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by loading the - entire dataset. - """ - drop_remainder: bool = False - """ - If ``True``, drops the last set of batches if they have in total fewer - examples than requested. - """ - - sync_batches: bool = True - """ - If ``True``, ensures that each process in the gang reads the same number of - batches. Typically used when the amount of data to be read can vary per - process (e.g. due to unbalanced sharding or non-static batching) and it is - critical for each process to iterate over the same number of batches (e.g. - during training). - """ +@dataclass +class PreferenceBatch(SupportsDeviceTransfer): + """Represents a preference optimization dataset batch.""" - sync_mode: SyncMode = SyncMode.UNTIL_FIRST - """ - The data synchronization mode among processes in the gang. Only effective if - :attr:`sync_batches` is ``True``. - """ + chosen: SequenceBatch + rejected: SequenceBatch + reference_score_chosen: torch.Tensor | None + reference_score_rejected: torch.Tensor | None - max_num_batches: int | None = None - """The maximum number of batches to return.""" + @property + def batch_size(self) -> int: + """The size of the batch dimension.""" + return self.chosen.batch_size - num_accumulate: int = 1 - """ - The number of batches to accumulate in each iteration. Typically used with - gradient accumulation during training. - """ + @override + def to(self, device: Device, *, non_blocking: bool = False) -> None: + self.chosen.to(device, non_blocking=non_blocking) - prefetch: int = 1 - """The number of batches to prefetch in background.""" + self.rejected.to(device, non_blocking=non_blocking) - npc: int = 10 - """The reference number of parallel calls that data reader can do.""" + if self.reference_score_chosen is not None: + self.reference_score_chosen = self.reference_score_chosen.to( + device, non_blocking=non_blocking + ) - seed: int = 2 - """The seed to initialize the random number generators used internally.""" + if self.reference_score_rejected is not None: + self.reference_score_rejected = self.reference_score_rejected.to( + device, non_blocking=non_blocking + ) - extras: MutableMapping[str, object] = field(default_factory=dict) - """The reader-specific extra options.""" +@dataclass(kw_only=True) +class LMDPODataReadOptions(DataReadOptions): sample: bool = False """ If ``True``, instruction sources (e.g. JSONL files) will be sampled in proportion to their weights. """ + mask_source_tokens: bool = True + """ + If ``False``, calculates loss on the source tokens (prompt) as well as the + target tokens. + """ + source_encode_mode: str = "prompt" """The tokenizer mode to encode the source text.""" @@ -118,39 +91,47 @@ class DataReadOptions: chat_mode: bool = False -@final -class LMSFTDataset: - _name: str - _splits: dict[str, tuple[Sequence[Path], Sequence[float]]] - - def __init__( - self, name: str, splits: dict[str, tuple[Sequence[Path], Sequence[float]]] - ) -> None: - """ - :param files: - The instruction files. - :param weights: - The weight of each file in ``files``. - """ - self._name = name - - for split, (files, weights) in splits.items(): - if len(files) != len(weights): - raise ValueError( - f"The lengths of the file and weight lists of the '{split}' split must match, but they are {len(files)} and {len(weights)} instead." - ) +class LMDPODataset: + def __init__(self, sources: dict[str, list[LMDPODataSource]]) -> None: + self._sources = sources - self._splits = splits + def _create_path_reader( + self, path: str, split: str | None, gangs: Gangs, shuffle_window: int, seed: int + ) -> DataPipeline: + download_manager = get_asset_download_manager() - def _read_jsonl(self, path: Path, tokenizer: Tokenizer) -> DataPipelineBuilder: - lines = [] + uri = Uri.maybe_parse(path) + if uri: + local_path = download_manager.download_dataset(uri) + else: + local_path = Path(path) + + if split: + local_path = local_path.joinpath(split) - # TODO(balioglu): Do in C++. - with path.open(encoding="utf-8") as fp: - for line in fp: - lines.append(line) + if not local_path.is_dir(): + files = [local_path] + else: + try: + files = [f for f in local_path.glob("**/*.jsonl") if not f.is_dir()] + except OSError as ex: + raise_operational_system_error(ex) - return read_sequence(lines).map(json.loads) + files.sort() + + builder = read_sequence(files) + + def read_file(file: Path) -> DataPipeline: + return read_text(file).map(json.loads).and_return() + + builder.yield_from(read_file) + + if shuffle_window != 1: + builder.shuffle(shuffle_window, seed=seed) + + builder.shard(gangs.dp.rank, gangs.dp.size, allow_uneven=True) + + return builder.and_return() def create_reader( self, @@ -159,37 +140,34 @@ def create_reader( gangs: Gangs, min_seq_len: int, max_seq_len: int, - options: DataReadOptions | None = None, - ) -> DataPipelineReader[SequenceBatch]: - - files_weights = self._splits.get(split) - if files_weights is None: - raise ValueError(f"files_weights for split '{split}' is None") - files, weights = files_weights - + options: LMDPODataReadOptions | None = None, + ) -> DataPipelineReader[PreferenceBatch]: if options is None: - options = DataReadOptions() + options = LMDPODataReadOptions() + + sources = self._sources[split] seed = options.seed - builder = read_sequence(files) + pipelines = [] - def read_file(file: Path) -> DataPipeline: - return read_text(file).map(json.loads, num_parallel_calls=1).and_return() + weights = [] - builder.yield_from(read_file) + for source in sources: + pipeline = self._create_path_reader( + source.path, source.split, gangs, options.example_shuffle_window, seed + ) - # Shuffle files. Must be consistent across all processes. - if options.example_shuffle_window != 1: - builder.shuffle(options.example_shuffle_window, seed=seed) + seed += 1 - seed += 1 + pipelines.append(pipeline) - # Shard. - builder.shard(gangs.dp.rank, gangs.dp.size, allow_uneven=True) + weights.append(source.weight) seed += gangs.dp.rank + builder = DataPipeline.sample(pipelines, weights, seed) + if options.chat_mode is True: # not passing any encoding modes here, because we use apply_chat_template here encoder = tokenizer.create_encoder() @@ -200,49 +178,107 @@ def read_file(file: Path) -> DataPipeline: else: def encoding_chat(example: dict[str, Any]) -> dict[str, Any]: - id_ = example.get("id", None) - chat = example.get("chat", None) - - if not chat: - chat = [ - {"role": "user", "content": example.get("src")}, - {"role": "assistant", "content": example.get("tgt")}, - ] + id_ = example.get("id") + chat_chosen = example.get("chat_chosen") + chat_rejected = example.get("chat_rejected") - encoded_output = encoder.apply_chat_template( - chat, + encoded_output_chosen = encoder.apply_chat_template( + chat_chosen, return_dict=True, return_assistant_tokens_mask=True, return_tensors="pt", ) - indices = encoded_output["input_ids"][0] - target_mask = encoded_output["assistant_masks"][0].bool() + indices_chosen = encoded_output_chosen["input_ids"][0] + target_mask_chosen = encoded_output_chosen["assistant_masks"][ + 0 + ].bool() - return {"id": id_, "indices": indices, "target_mask": target_mask} + encoded_output_rejected = encoder.apply_chat_template( + chat_rejected, + return_dict=True, + return_assistant_tokens_mask=True, + return_tensors="pt", + ) + + indices_rejected = encoded_output_rejected["input_ids"][0] + target_mask_rejected = encoded_output_rejected["assistant_masks"][0].bool() + + breakpoint() + if not options.mask_source_tokens: + # no source masking i.e. mask has all 1s + target_mask_chosen._fill(True) + target_mask_rejected._fill(True) + + total_tokens = len(indices_chosen) + len(indices_rejected) + + return { + "id": id_, + "indices_chosen": indices_chosen, + "indices_rejected": indices_rejected, + "reference_score_chosen": example.get( + "reference_score_chosen", None + ), + "reference_score_rejected": example.get( + "reference_score_rejected", None + ), + "target_mask_chosen": target_mask_chosen, + "target_mask_rejected": target_mask_rejected, + "total_tokens": total_tokens, + } builder.map(encoding_chat) else: - # Encode source and target texts. source_encoder = tokenizer.create_encoder(mode=options.source_encode_mode) target_encoder = tokenizer.create_encoder(mode=options.target_encode_mode) builder.map(source_encoder, selector="src") - builder.map(target_encoder, selector="tgt") - + builder.map(target_encoder, selector="tgt_chosen") + builder.map(target_encoder, selector="tgt_rejected") + def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]: - id_ = example.get("id") + id_ = example.get("id", None) source_indices = example["src"] - target_indices = example["tgt"] + target_indices_chosen = example["tgt_chosen"] + target_indices_rejected = example["tgt_rejected"] - indices = torch.cat([source_indices, target_indices]) + indices_chosen = torch.cat([source_indices, target_indices_chosen]) + indices_rejected = torch.cat([source_indices, target_indices_rejected]) - target_mask = torch.arange(len(indices)) >= len(source_indices) + if options.mask_source_tokens: + source_len = len(source_indices) + target_mask_chosen = torch.arange(len(indices_chosen)) >= source_len + target_mask_rejected = ( + torch.arange(len(indices_rejected)) >= source_len + ) + else: + target_mask_chosen = torch.full([len(indices_chosen)], True) + target_mask_rejected = torch.full([len(indices_rejected)], True) + + total_tokens = ( + 2 * len(source_indices) + + len(target_indices_chosen) + + len(target_indices_rejected) + ) - return {"id": id_, "indices": indices, "target_mask": target_mask} + return { + "id": id_, + "indices_prompt": source_indices, + "indices_chosen": indices_chosen, + "indices_rejected": indices_rejected, + "reference_score_chosen": example.get( + "reference_score_chosen", None + ), + "reference_score_rejected": example.get( + "reference_score_rejected", None + ), + "target_mask_chosen": target_mask_chosen, + "target_mask_rejected": target_mask_rejected, + "total_tokens": total_tokens, + } builder.map(cat_source_and_target) @@ -255,71 +291,107 @@ def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]: max_num_elements=batching.max_num_elements, ) - # Bucket by the sequence length. + # Bucket by the sequence length builder.bucket_by_length( bucket_sizes, - selector="indices", + selector="total_tokens", min_data_len=min_seq_len, skip_above_max_examples=True, drop_remainder=options.drop_remainder, ) elif isinstance(batching, StaticBatching): - # Filter out long examples. + # Filter out long examples def skip(example: dict[str, Any]) -> bool: - seq_len = len(example["indices"]) + chosen_len = len(example["indices_chosen"]) + rejected_len = len(example["indices_rejected"]) + + if chosen_len > max_seq_len or rejected_len > max_seq_len: + return False - return seq_len >= min_seq_len and seq_len <= max_seq_len + return chosen_len >= min_seq_len and rejected_len >= min_seq_len builder.filter(skip) - # Bucket `batch_size` examples. + # Bucket `batch_size` examples builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) else: raise NotSupportedError(f"`{batching}` is not supported.") - # Shuffle buckets. + # Shuffle buckets if options.batch_shuffle_window != 1: builder.shuffle(options.batch_shuffle_window, seed=seed) seed += 1 - + # Collate bucketed examples into a batch. - target_mask_collate_opts = CollateOptionsOverride( - "target_mask", pad_value=False - ) + target_mask_collate_opts = [ + CollateOptionsOverride("target_mask_chosen", pad_value=False), + CollateOptionsOverride("target_mask_rejected", pad_value=False), + ] - if tokenizer.vocab_info.pad_idx is None: - raise RuntimeError( - "LMSFTDataset requires pad token to work for batching purposes, check your tokenizer config." - ) - - collater = Collater( - pad_value=tokenizer.vocab_info.pad_idx, overrides=[target_mask_collate_opts] - ) + collater = Collater(pad_value=0, overrides=target_mask_collate_opts) builder.map(collater, num_parallel_calls=options.npc) - + # Return only the first `max_num_batches`. if options.max_num_batches is not None: builder.take(options.max_num_batches) # Prefetch `prefetch` batches in background. builder.prefetch(options.prefetch) + + # Wrap examples with `PreferenceBatch`. + def to_batch(example: dict[str, Any]) -> PreferenceBatch: + indices_chosen = cast(SequenceData, example["indices_chosen"]) + indices_rejected = cast(SequenceData, example["indices_rejected"]) + + seqs_chosen, seq_chosen_lens = ( + indices_chosen["seqs"], + indices_chosen["seq_lens"], + ) + seqs_rejected, seq_rejected_lens = ( + indices_rejected["seqs"], + indices_rejected["seq_lens"], + ) - # Wrap examples with `SequenceBatch`. - def to_batch(example: dict[str, Any]) -> SequenceBatch: - indices = cast(SequenceData, example["indices"]) + target_mask_chosen = example["target_mask_chosen"]["seqs"] + target_mask_rejected = example["target_mask_rejected"]["seqs"] - seqs, seq_lens = indices["seqs"], indices["seq_lens"] - target_mask = example["target_mask"]["seqs"] + batch_chosen = SequenceBatch( + seqs_chosen, + seq_chosen_lens, + target_mask=target_mask_chosen, + example=example, + ) - return SequenceBatch( - seqs, seq_lens, target_mask=target_mask, example=example + batch_rejected = SequenceBatch( + seqs_rejected, + seq_rejected_lens, + target_mask=target_mask_rejected, + example=example, ) - pipeline = builder.map(to_batch).and_return() + batch_reference_scores_chosen = None + if all(example["reference_score_chosen"]): + batch_reference_scores_chosen = torch.Tensor( + example["reference_score_chosen"] + ) + batch_reference_scores_rejected = None + if all(example["reference_score_rejected"]): + batch_reference_scores_rejected = torch.Tensor( + example["reference_score_rejected"] + ) - return DataPipelineReader[SequenceBatch]( + return PreferenceBatch( + batch_chosen, + batch_rejected, + batch_reference_scores_chosen, + batch_reference_scores_rejected, + ) + + pipeline = builder.map(to_batch).and_return() + + return DataPipelineReader[PreferenceBatch]( pipeline, gangs, num_accumulate=options.num_accumulate, @@ -330,39 +402,16 @@ def to_batch(example: dict[str, Any]) -> SequenceBatch: @dataclass -class LMSFTDatasetConfig: - path: str | None = None - - -def open_lm_sft_dataset(config: LMSFTDatasetConfig) -> LMSFTDataset: - name = "default" # FIXME - splits: dict[str, tuple[Sequence[Path], Sequence[float]]] = {} +class LMDPODataSource: #TODO: to confirm after recipe.py + path: str + split: str | None = None + weight: float = 1.0 - if config.path is None: - raise ValueError("config.path cannot be None") - uri = Uri.maybe_parse(config.path) - if uri: - path = get_asset_download_manager().download_dataset(uri, config.path) - else: - path = Path(config.path) - - if path.is_dir(): - try: - child_dirs = [p for p in path.iterdir() if p.is_dir()] - except OSError as ex: - raise DatasetLoadError( - name, f"The files under the '{path}' directory of the '{name}' dataset cannot be retrieved. See the nested exception for details." # fmt: skip - ) from ex - - for child_dir in child_dirs: - files, weights = load_files_and_weights(name, child_dir) - - splits[child_dir.name] = (files, weights) - - if not splits: - files, weights = load_files_and_weights(name, path) +@dataclass +class LMDPODatasetConfig: #TODO: to confirm after recipe.py + sources: dict[str, list[LMDPODataSource]] = field(default_factory=dict) - splits["default"] = (files, weights) - return LMSFTDataset(name, splits) +def open_lm_dpo_dataset(config: LMDPODatasetConfig) -> LMDPODataset: + return LMDPODataset(config.sources) From 0bb454a5824a96808d174b423c93f97d424e7214 Mon Sep 17 00:00:00 2001 From: ellenxtan Date: Tue, 11 Nov 2025 01:35:29 +0000 Subject: [PATCH 07/12] resolve error for now --- recipes/lm/dpo/dataset.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/recipes/lm/dpo/dataset.py b/recipes/lm/dpo/dataset.py index 09353bc0c..208531826 100644 --- a/recipes/lm/dpo/dataset.py +++ b/recipes/lm/dpo/dataset.py @@ -14,7 +14,9 @@ import torch from typing_extensions import override -from fairseq2.assets import get_asset_download_manager +# from fairseq2.assets.download_manager import get_asset_download_manager +from fairseq2.assets.download_manager import AssetDownloadManager +from fairseq2.runtime.dependency import get_dependency_resolver from fairseq2.data import ( CollateOptionsOverride, Collater, @@ -98,7 +100,8 @@ def __init__(self, sources: dict[str, list[LMDPODataSource]]) -> None: def _create_path_reader( self, path: str, split: str | None, gangs: Gangs, shuffle_window: int, seed: int ) -> DataPipeline: - download_manager = get_asset_download_manager() + # download_manager = get_asset_download_manager() + download_manager = get_dependency_resolver().resolve(AssetDownloadManager) uri = Uri.maybe_parse(path) if uri: From f7061760fe41713328713500d2aba62770c91c28 Mon Sep 17 00:00:00 2001 From: ellenxtan Date: Tue, 11 Nov 2025 01:37:05 +0000 Subject: [PATCH 08/12] dpo main draft --- recipes/lm/dpo/__main__.py | 4 ++-- recipes/lm/dpo/utils.py | 0 2 files changed, 2 insertions(+), 2 deletions(-) delete mode 100644 recipes/lm/dpo/utils.py diff --git a/recipes/lm/dpo/__main__.py b/recipes/lm/dpo/__main__.py index 18c651a87..b545194de 100644 --- a/recipes/lm/dpo/__main__.py +++ b/recipes/lm/dpo/__main__.py @@ -6,10 +6,10 @@ from __future__ import annotations -from fairseq2.recipe.cli import train_main +from fairseq2.recipe.cli import main from .recipe import LMDPORecipe recipe = LMDPORecipe() -train_main(recipe) +main(recipe) diff --git a/recipes/lm/dpo/utils.py b/recipes/lm/dpo/utils.py deleted file mode 100644 index e69de29bb..000000000 From a8525a9839876dfce76f82ea8ca0ae43778d58fb Mon Sep 17 00:00:00 2001 From: ellenxtan Date: Tue, 11 Nov 2025 06:17:05 +0000 Subject: [PATCH 09/12] fix minor bug in dataset --- recipes/lm/dpo/dataset.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/recipes/lm/dpo/dataset.py b/recipes/lm/dpo/dataset.py index 208531826..b904ba888 100644 --- a/recipes/lm/dpo/dataset.py +++ b/recipes/lm/dpo/dataset.py @@ -98,7 +98,7 @@ def __init__(self, sources: dict[str, list[LMDPODataSource]]) -> None: self._sources = sources def _create_path_reader( - self, path: str, split: str | None, gangs: Gangs, shuffle_window: int, seed: int + self, path: str, gangs: Gangs, shuffle_window: int, seed: int ) -> DataPipeline: # download_manager = get_asset_download_manager() download_manager = get_dependency_resolver().resolve(AssetDownloadManager) @@ -109,9 +109,6 @@ def _create_path_reader( else: local_path = Path(path) - if split: - local_path = local_path.joinpath(split) - if not local_path.is_dir(): files = [local_path] else: @@ -138,7 +135,6 @@ def read_file(file: Path) -> DataPipeline: def create_reader( self, - split: str, tokenizer: Tokenizer, gangs: Gangs, min_seq_len: int, @@ -148,17 +144,15 @@ def create_reader( if options is None: options = LMDPODataReadOptions() - sources = self._sources[split] - seed = options.seed pipelines = [] weights = [] - for source in sources: + for source in self._sources: pipeline = self._create_path_reader( - source.path, source.split, gangs, options.example_shuffle_window, seed + source.path, gangs, options.example_shuffle_window, seed ) seed += 1 From 8543f57922d53b9e711d080f26819246c2e3c717 Mon Sep 17 00:00:00 2001 From: ellenxtan Date: Thu, 13 Nov 2025 19:06:03 +0000 Subject: [PATCH 10/12] rebase PreferenceBatch --- recipes/lm/dpo/dataset.py | 34 +--------------------------------- src/fairseq2/datasets/batch.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/recipes/lm/dpo/dataset.py b/recipes/lm/dpo/dataset.py index b904ba888..2d2be83a0 100644 --- a/recipes/lm/dpo/dataset.py +++ b/recipes/lm/dpo/dataset.py @@ -27,8 +27,7 @@ from fairseq2.data.text import read_text from fairseq2.data.tokenizers import Tokenizer from fairseq2.data.tokenizers.hg import HuggingFaceTokenEncoder -from fairseq2.datasets import DataPipelineReader, SequenceBatch -from fairseq2.device import Device, SupportsDeviceTransfer +from fairseq2.datasets import DataPipelineReader, SequenceBatch, PreferenceBatch from fairseq2.error import NotSupportedError, raise_operational_system_error from fairseq2.gang import Gangs from fairseq2.utils.uri import Uri @@ -39,37 +38,6 @@ LM_DPO_DATASET: Final = "lm_dpo" -@dataclass -class PreferenceBatch(SupportsDeviceTransfer): - """Represents a preference optimization dataset batch.""" - - chosen: SequenceBatch - rejected: SequenceBatch - reference_score_chosen: torch.Tensor | None - reference_score_rejected: torch.Tensor | None - - @property - def batch_size(self) -> int: - """The size of the batch dimension.""" - return self.chosen.batch_size - - @override - def to(self, device: Device, *, non_blocking: bool = False) -> None: - self.chosen.to(device, non_blocking=non_blocking) - - self.rejected.to(device, non_blocking=non_blocking) - - if self.reference_score_chosen is not None: - self.reference_score_chosen = self.reference_score_chosen.to( - device, non_blocking=non_blocking - ) - - if self.reference_score_rejected is not None: - self.reference_score_rejected = self.reference_score_rejected.to( - device, non_blocking=non_blocking - ) - - @dataclass(kw_only=True) class LMDPODataReadOptions(DataReadOptions): sample: bool = False diff --git a/src/fairseq2/datasets/batch.py b/src/fairseq2/datasets/batch.py index 28618fe2c..057dd1369 100644 --- a/src/fairseq2/datasets/batch.py +++ b/src/fairseq2/datasets/batch.py @@ -9,6 +9,7 @@ from collections.abc import Sequence from typing import final +import torch from torch import Tensor from typing_extensions import override @@ -662,3 +663,34 @@ def __repr__(self) -> str: ) return f"Seq2SeqBatch({s})" + + +@final +class PreferenceBatch(SupportsDeviceTransfer): + """Represents a preference optimization dataset batch.""" + + chosen: SequenceBatch + rejected: SequenceBatch + reference_score_chosen: torch.Tensor | None + reference_score_rejected: torch.Tensor | None + + @property + def batch_size(self) -> int: + """The size of the batch dimension.""" + return self.chosen.batch_size + + @override + def to(self, device: Device, *, non_blocking: bool = False) -> None: + self.chosen.to(device, non_blocking=non_blocking) + + self.rejected.to(device, non_blocking=non_blocking) + + if self.reference_score_chosen is not None: + self.reference_score_chosen = self.reference_score_chosen.to( + device, non_blocking=non_blocking + ) + + if self.reference_score_rejected is not None: + self.reference_score_rejected = self.reference_score_rejected.to( + device, non_blocking=non_blocking + ) From 4ac5b31796f5381e5d3c51f923f44f7c8b891db9 Mon Sep 17 00:00:00 2001 From: ellenxtan Date: Thu, 13 Nov 2025 23:35:15 +0000 Subject: [PATCH 11/12] add _gather_lprobs_avg --- recipes/lm/common.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/recipes/lm/common.py b/recipes/lm/common.py index f48292e75..bff6a2f3c 100644 --- a/recipes/lm/common.py +++ b/recipes/lm/common.py @@ -8,7 +8,11 @@ from dataclasses import dataclass, field from typing import TypeAlias -from fairseq2.datasets import SyncMode +from fairseq2.datasets import SyncMode, SequenceBatch + + +import torch +from torch import Tensor from fairseq2.logging import log from fairseq2.models.clm import CausalLM @@ -54,6 +58,17 @@ def _maybe_get_embed(model: CausalLM) -> Embedding | None: return embed +def _gather_lprobs_avg(logits: Tensor, target: SequenceBatch) -> tuple[Tensor, Tensor]: + assert target.target_mask is not None + logprobs = torch.log_softmax(logits, dim=-1) + per_token_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze(-1) + total_logps = (per_token_logps * target.target_mask).sum(dim=-1) # [Batch, 1] + assert target.target_mask is not None + average_logps = total_logps / target.target_mask.sum(-1) + + return total_logps, average_logps + + @dataclass class StaticBatching: """Specifies batching where each batch has the same number of examples.""" From 1037d6eeecd5e12c496f15146c56a4f241e77057 Mon Sep 17 00:00:00 2001 From: ellenxtan Date: Thu, 13 Nov 2025 23:35:44 +0000 Subject: [PATCH 12/12] add dpo recipe draft --- recipes/lm/dpo/config.py | 21 +++ recipes/lm/dpo/recipe.py | 268 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 289 insertions(+) diff --git a/recipes/lm/dpo/config.py b/recipes/lm/dpo/config.py index 476ff0536..8c03a6f5b 100644 --- a/recipes/lm/dpo/config.py +++ b/recipes/lm/dpo/config.py @@ -20,6 +20,7 @@ ModelSection, OptimizerSection, RegimeSection, + ReferenceModelSection, TokenizerSection, TorchConfig, TrainerSection, @@ -87,6 +88,26 @@ class LMDPOConfig: ) ) + # Loss configuration + reference_model: ReferenceModelSection | None = field( + default_factory=lambda: ReferenceModelSection(name="llama3_1_8b_instruct") + ) + """ + The reference model. If ``None``, the recipe expects to get reference + log-probabilities for chosen and rejected targets as float values in the + data example (fields `reference_score_rejected` and `reference_score_chosen`). + """ + + beta: float = 0.1 + """The coefficient of regularization towards the reference model.""" + + nll_scale: float = 0.0 + """The coefficient of NLL loss added to the DPO loss.""" + + length_normalization: bool = False + """Use length normalized DPO, which uses the average log probability of a sequence as the implicit reward.""" + + @dataclass(kw_only=True) class LMDPODatasetSection(DatasetSection): diff --git a/recipes/lm/dpo/recipe.py b/recipes/lm/dpo/recipe.py index e69de29bb..24d280e0b 100644 --- a/recipes/lm/dpo/recipe.py +++ b/recipes/lm/dpo/recipe.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import cast + +import torch +from torch import Tensor +from torch.nn import Module +from typing_extensions import override + +from fairseq2.composition import register_dataset_family +from fairseq2.datasets import SequenceBatch +from fairseq2.metrics import MetricBag +from fairseq2.metrics.common import ( + add_nll_loss_metric, + add_seq_batch_metrics, + add_dpo_loss_metric, + add_sequence_length_metrics, + add_logps_metrics, + update_nll_loss_metric, + update_seq_batch_metrics, + update_dpo_loss_metric, + update_sequence_length_metrics, + update_logps_metrics, +) +from fairseq2.models.clm import CausalLM +from fairseq2.recipe.base import Recipe, RecipeContext +from fairseq2.recipe.trainer import TrainUnit +from fairseq2.runtime.dependency import DependencyContainer +from fairseq2.task import Task + +from ..common import check_model_vocabulary, Batching, LengthBatching, StaticBatching, _gather_lprobs_avg +from .config import LMDPOConfig +from .dataset import ( + LM_DPO_DATASET, + LMDPODataReadOptions, + LMDPODataset, + LMDPODatasetConfig, + open_lm_dpo_dataset, + PreferenceBatch, +) + + +class LMDPORecipe(Recipe): + @override + def register(self, container: DependencyContainer) -> None: + register_dataset_family( + container, + LM_DPO_DATASET, + LMDPODataset, + LMDPODatasetConfig, + opener=open_lm_dpo_dataset, + ) + + @override + def create_task(self, context: RecipeContext) -> Task: + config = context.get_config_as(LMDPOConfig) + + check_model_vocabulary(context) + + dp_model = context.get_data_parallel_model() + + # Get reference model if configured + reference_model = None + if config.reference_model is not None: + reference_model = context.bootstrap_reference_model("reference_model") + + unit = LMDPOUnit( + dp_model, + reference_model, + beta=config.beta, + nll_scale=config.nll_scale, + length_normalization=config.length_normalization, + ) + + dataset = context.get_dataset_as(LMDPODataset) + + tokenizer = context.get_tokenizer() + + batching: Batching + if config.dataset.batch_size is not None: + batching = StaticBatching(config.dataset.batch_size) + else: + batching = LengthBatching(config.dataset.max_num_tokens) + + read_options = LMDPODataReadOptions( + batching=batching, + example_shuffle_window=config.dataset.example_shuffle_window, + batch_shuffle_window=config.dataset.batch_shuffle_window, + num_accumulate=config.trainer.grad_accumulation.num_batches, + prefetch=config.dataset.prefetch, + source_encode_mode=config.dataset.source_encode_mode, + target_encode_mode=config.dataset.target_encode_mode, + chat_mode=config.dataset.chat_mode, + seed=config.common.seed, + ) + + data_reader = dataset.create_reader( + tokenizer=tokenizer, + gangs=context.gangs, + min_seq_len=config.dataset.min_seq_len, + max_seq_len=config.dataset.max_seq_len, + options=read_options, + ) + + return context.create_trainer(unit, data_reader, [], []) + + @property + @override + def config_kls(self) -> type[object]: + return LMDPOConfig + + +class LMDPOUnit(TrainUnit[PreferenceBatch]): + def __init__( + self, + model: Module, + reference_model: Module | None, + beta: float = 0.1, + nll_scale: float = 1.0, + length_normalization: bool = False, + ) -> None: + self._model = model + self._reference_model = reference_model + self._beta = beta + self._nll_scale = nll_scale + self._length_normalization = length_normalization + + @override + def prepare_metric_bag(self, metric_bag: MetricBag) -> None: # TODO: add metrics + add_nll_loss_metric(metric_bag) + + add_seq_batch_metrics(metric_bag) + + add_dpo_loss_metric(metric_bag) + + add_sequence_length_metrics(metric_bag) + + add_logps_metrics(metric_bag) + + def _compute_dpo_loss( + self, + chosen_logps: Tensor, + ref_chosen_logps: Tensor, + rejected_logps: Tensor, + ref_rejected_logps: Tensor, + ) -> tuple[Tensor, Tensor, Tensor]: + logp_ratio_chosen = self._beta * (chosen_logps - ref_chosen_logps) + logp_ratio_rejected = self._beta * (rejected_logps - ref_rejected_logps) + dpo_loss = - Module.functional.logsigmoid( + logp_ratio_chosen - logp_ratio_rejected + ) + return logp_ratio_chosen, logp_ratio_rejected, dpo_loss.sum() + + @override + def process_batch( # TODO: update metrics + self, batch: PreferenceBatch, metric_bag: MetricBag + ) -> tuple[Tensor, int]: + model = cast(CausalLM, self._model) + reference_model = cast(CausalLM, self._reference_model) if self._reference_model else None + + chosen_batch = batch.chosen + chosen_input_batch, chosen_target_batch = chosen_batch.as_auto_regressive() + + rejected_batch = batch.rejected + rejected_input_batch, rejected_target_batch = ( + rejected_batch.as_auto_regressive() + ) + + if ( + chosen_target_batch.target_mask is None + or rejected_target_batch.target_mask is None + ): + raise RuntimeError("target_mask attributes must exist for DPO loss") + + chosen_seqs, chosen_seqs_layout = chosen_input_batch.as_input() + + nll_loss, chosen_logits = model( + chosen_seqs, + chosen_seqs_layout, + targets=chosen_target_batch.seqs, + target_mask=chosen_target_batch.target_mask, + return_logits=True, + ) + + rejected_seqs, rejected_seqs_layout = rejected_input_batch.as_input() + + rejected_logits = model(rejected_seqs, rejected_seqs_layout) + + chosen_logps, average_chosen_logps = _gather_lprobs_avg( + chosen_logits, chosen_target_batch + ) + rejected_logps, average_rejected_logps = _gather_lprobs_avg( + rejected_logits, rejected_target_batch + ) + + if reference_model is not None: + chosen_seqs, chosen_seqs_layout = chosen_batch.as_input() + rejected_seqs, rejected_seqs_layout = rejected_batch.as_input() + + with torch.no_grad(): + ref_chosen_logits = reference_model( + chosen_seqs, chosen_seqs_layout + ) + ref_rejected_logits = reference_model( + rejected_seqs, rejected_seqs_layout + ) + + ref_chosen_logps, ref_average_chosen_logps = _gather_lprobs_avg( + ref_chosen_logits, chosen_target_batch + ) + ref_rejected_logps, ref_average_rejected_logps = _gather_lprobs_avg( + ref_rejected_logits, rejected_target_batch + ) + elif ( + batch.reference_score_chosen is not None + and batch.reference_score_rejected is not None + ): + # reference scores must exist in the batch if reference model is None + ref_chosen_logps = batch.reference_score_chosen + ref_average_chosen_logps = ( + ref_chosen_logps / chosen_target_batch.target_mask.sum(-1) + ) + ref_rejected_logps = batch.reference_score_rejected + ref_average_rejected_logps = ( + ref_rejected_logps / rejected_target_batch.target_mask.sum(-1) + ) + else: + raise RuntimeError( + "Reference model is not initialized and data batch does not provide reference score, but at least one must exist." + ) + + if self._length_normalization: + _, _, dpo_loss = self._compute_dpo_loss( + average_chosen_logps, + ref_average_chosen_logps, + average_rejected_logps, + ref_average_rejected_logps, + ) + else: + _, _, dpo_loss = self._compute_dpo_loss( + chosen_logps, ref_chosen_logps, rejected_logps, ref_rejected_logps + ) + + update_dpo_loss_metric(metric_bag, dpo_loss, batch) + + update_nll_loss_metric(metric_bag, nll_loss, chosen_batch.num_target_elements) + + update_sequence_length_metrics(metric_bag, batch) + + update_logps_metrics(metric_bag, batch, chosen_logps, rejected_logps) + + update_seq_batch_metrics(metric_bag, chosen_batch) + + loss = ( + dpo_loss + + self._nll_scale + * nll_loss + * chosen_target_batch.batch_size + / chosen_target_batch.num_target_elements + ) # normalization applied locally per-rank + + return loss, chosen_target_batch.batch_size