diff --git a/recipes/lm/common.py b/recipes/lm/common.py
index d0198ffc1..bff6a2f3c 100644
--- a/recipes/lm/common.py
+++ b/recipes/lm/common.py
@@ -6,6 +6,14 @@
 
 from __future__ import annotations
 
+from dataclasses import dataclass, field
+from typing import TypeAlias
+
+import torch
+from torch import Tensor
+
+from fairseq2.datasets import SequenceBatch, SyncMode
+
 from fairseq2.logging import log
 from fairseq2.models.clm import CausalLM
 from fairseq2.nn import Embedding
@@ -48,3 +56,95 @@ def _maybe_get_embed(model: CausalLM) -> Embedding | None:
         return None
 
     return embed
+
+
+def _gather_lprobs_avg(logits: Tensor, target: SequenceBatch) -> tuple[Tensor, Tensor]:
+    assert target.target_mask is not None
+
+    # Log-probability of each target token under the model: [Batch, Seq]
+    logprobs = torch.log_softmax(logits, dim=-1)
+    per_token_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze(-1)
+
+    # Summed and length-averaged log-probabilities per sequence: [Batch]
+    total_logps = (per_token_logps * target.target_mask).sum(dim=-1)
+    average_logps = total_logps / target.target_mask.sum(-1)
+
+    return total_logps, average_logps
+
+
+@dataclass
+class StaticBatching:
+    """Specifies batching where each batch has the same number of examples."""
+
+    batch_size: int
+    """The number of examples in each batch."""
+
+
+@dataclass
+class LengthBatching:
+    """Specifies batching where each batch has a maximum number of elements."""
+
+    max_num_elements: int
+    """The maximum number of elements (e.g. tokens) in each batch."""
+
+
+Batching: TypeAlias = StaticBatching | LengthBatching
+
+
+@dataclass(kw_only=True)
+class DataReadOptions:
+    batching: Batching = field(default_factory=lambda: StaticBatching(1))
+    """The batching strategy for returned examples."""
+
+    example_shuffle_window: int = 0
+    """
+    The size of the sliding window for shuffling examples. If ``1``, no
+    shuffling is performed; if ``0``, true shuffling is performed by loading the
+    entire dataset.
+    """
+
+    batch_shuffle_window: int = 0
+    """
+    The size of the sliding window for shuffling batches. If ``1``, no
+    shuffling is performed; if ``0``, true shuffling is performed by loading the
+    entire dataset.
+    """
+
+    drop_remainder: bool = False
+    """
+    If ``True``, drops the last set of batches if they have in total fewer
+    examples than requested.
+    """
+
+    sync_batches: bool = True
+    """
+    If ``True``, ensures that each process in the gang reads the same number of
+    batches. Typically used when the amount of data to be read can vary per
+    process (e.g. due to unbalanced sharding or non-static batching) and it is
+    critical for each process to iterate over the same number of batches (e.g.
+    during training).
+    """
+
+    sync_mode: SyncMode = SyncMode.UNTIL_FIRST
+    """
+    The data synchronization mode among processes in the gang. Only effective if
+    :attr:`sync_batches` is ``True``.
+    """
+
+    max_num_batches: int | None = None
+    """The maximum number of batches to return."""
+
+    num_accumulate: int = 1
+    """
+    The number of batches to accumulate in each iteration. Typically used with
+    gradient accumulation during training.
+ """ + + prefetch: int = 1 + """The number of batches to prefetch in background.""" + + npc: int = 10 + """The reference number of parallel calls that data reader can do.""" + + seed: int = 2 + """The seed to initialize the random number generators used internally.""" diff --git a/recipes/lm/dpo/README.md b/recipes/lm/dpo/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/recipes/lm/dpo/__init__.py b/recipes/lm/dpo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/recipes/lm/dpo/__main__.py b/recipes/lm/dpo/__main__.py new file mode 100644 index 000000000..b545194de --- /dev/null +++ b/recipes/lm/dpo/__main__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.recipe.cli import main + +from .recipe import LMDPORecipe + +recipe = LMDPORecipe() + +main(recipe) diff --git a/recipes/lm/dpo/config.py b/recipes/lm/dpo/config.py new file mode 100644 index 000000000..8c03a6f5b --- /dev/null +++ b/recipes/lm/dpo/config.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass, field + +from fairseq2.recipe.config import ( + ADAMW_OPTIMIZER, + COSINE_ANNEALING_LR, + AdamWConfig, + CommonSection, + CosineAnnealingLRConfig, + DatasetSection, + GangSection, + LRSchedulerSection, + ModelSection, + OptimizerSection, + RegimeSection, + ReferenceModelSection, + TokenizerSection, + TorchConfig, + TrainerSection, + ActivationCheckpointingConfig, +) + +from .dataset import LM_DPO_DATASET + + +@dataclass(kw_only=True) +class LMDPOConfig: + model: ModelSection = field( + default_factory=lambda: ModelSection(name="llama3_1_8b_instruct") + ) + + tokenizer: TokenizerSection = field( + default_factory=lambda: TokenizerSection(name="llama3_instruct") + ) + + dataset: LMDPODatasetSection = field( + default_factory=lambda: LMDPODatasetSection(family=LM_DPO_DATASET), + ) + + gang: GangSection = field(default_factory=lambda: GangSection()) + + trainer: TrainerSection = field( + default_factory=lambda: TrainerSection( + data_parallelism="fsdp", + max_grad_norm=1.0, + activation_checkpointing=ActivationCheckpointingConfig(mode="layerwise"), + ) + ) + + optimizer: OptimizerSection = field( + default_factory=lambda: OptimizerSection( + name=ADAMW_OPTIMIZER, + config=AdamWConfig( + lr=5.5e-06, betas=(0.9, 0.95), weight_decay=0.1, impl="fused" + ), + ) + ) + + lr_scheduler: LRSchedulerSection | None = field( + default_factory=lambda: LRSchedulerSection( + name=COSINE_ANNEALING_LR, config=CosineAnnealingLRConfig(final_lr_scale=0.2) + ) + ) + + regime: RegimeSection = field( + default_factory=lambda: RegimeSection( + num_steps=5000, + validate_every_n_steps=100, + checkpoint_every_n_steps=1000, + keep_last_n_checkpoints=1, + publish_metrics_every_n_steps=10, + export_hugging_face=True, + ) + ) + + # The memory efficient SDPA implementation in PyTorch is numerically not + # stable when used with padded inputs. 
+    common: CommonSection = field(
+        default_factory=lambda: CommonSection(
+            torch=TorchConfig(default_sdpa="torch_math")
+        )
+    )
+
+    # Loss configuration
+    reference_model: ReferenceModelSection | None = field(
+        default_factory=lambda: ReferenceModelSection(name="llama3_1_8b_instruct")
+    )
+    """
+    The reference model. If ``None``, the recipe expects to get reference
+    log-probabilities for the chosen and rejected targets as float values in
+    the data example (fields ``reference_score_chosen`` and
+    ``reference_score_rejected``).
+    """
+
+    beta: float = 0.1
+    """The coefficient of regularization towards the reference model."""
+
+    nll_scale: float = 0.0
+    """The coefficient of the NLL loss added to the DPO loss."""
+
+    length_normalization: bool = False
+    """
+    If ``True``, uses length-normalized DPO, where the implicit reward is the
+    average log-probability of a sequence rather than its sum.
+    """
+
+
+@dataclass(kw_only=True)
+class LMDPODatasetSection(DatasetSection):
+    path: str | None = None
+
+    source_encode_mode: str = "prompt"
+    """The encode mode for the prompt; determines what special tokens to add."""
+
+    target_encode_mode: str = "prompt_response"
+    """The encode mode for the target; determines what special tokens to add."""
+
+    mask_source_tokens: bool = True
+    """
+    If ``False``, calculates the loss on the ``src`` tokens as well as the
+    ``tgt`` tokens.
+    """
+
+    min_seq_len: int = 1
+    """
+    The minimum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``.
+    Shorter sequences will be dropped.
+    """
+
+    max_seq_len: int = 8192
+    """
+    The maximum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``.
+    Longer sequences will be dropped.
+    """
+
+    max_num_tokens: int = 8192 * 2
+    """
+    The maximum number of total ``src``, ``tgt_chosen``, and ``tgt_rejected``
+    tokens per batch.
+    """
+
+    batch_size: int | None = None
+    """
+    If not ``None``, ignores ``max_num_tokens`` and each batch will have
+    ``batch_size`` examples.
+    """
+
+    example_shuffle_window: int = 10_000
+    """The size of the sliding window for shuffling examples."""
+
+    batch_shuffle_window: int = 1_000
+    """The size of the sliding window for shuffling batches."""
+
+    num_prefetch: int = 4
+    """The number of batches to prefetch in background."""
+
+    extras: dict[str, object] = field(default_factory=dict)
+    """The dataset-specific extra options."""
+
+    chat_mode: bool = False
+    """
+    If ``True``, each dataset JSONL example must have ``chat_chosen`` and
+    ``chat_rejected`` fields containing OpenAI-style message lists
+    (``list[dict]``).
+    """
diff --git a/recipes/lm/dpo/dataset.py b/recipes/lm/dpo/dataset.py
new file mode 100644
index 000000000..2d2be83a0
--- /dev/null
+++ b/recipes/lm/dpo/dataset.py
@@ -0,0 +1,382 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Final, cast
+
+import torch
+
+from fairseq2.assets.download_manager import AssetDownloadManager
+from fairseq2.data import (
+    CollateOptionsOverride,
+    Collater,
+    SequenceData,
+    create_bucket_sizes,
+)
+from fairseq2.data.data_pipeline import DataPipeline, read_sequence
+from fairseq2.data.text import read_text
+from fairseq2.data.tokenizers import Tokenizer
+from fairseq2.data.tokenizers.hg import HuggingFaceTokenEncoder
+from fairseq2.datasets import DataPipelineReader, PreferenceBatch, SequenceBatch
+from fairseq2.error import NotSupportedError, raise_operational_system_error
+from fairseq2.gang import Gangs
+from fairseq2.runtime.dependency import get_dependency_resolver
+from fairseq2.utils.uri import Uri
+
+from ..common import DataReadOptions, LengthBatching, StaticBatching
+
+LM_DPO_DATASET: Final = "lm_dpo"
+
+
+@dataclass(kw_only=True)
+class LMDPODataReadOptions(DataReadOptions):
+    sample: bool = False
+    """
+    If ``True``, instruction sources (e.g. JSONL files) will be sampled in
+    proportion to their weights.
+    """
+
+    mask_source_tokens: bool = True
+    """
+    If ``False``, calculates the loss on the source tokens (prompt) as well as
+    the target tokens.
+    """
+
+    source_encode_mode: str = "prompt"
+    """The tokenizer mode to encode the source text."""
+
+    target_encode_mode: str = "prompt_response"
+    """The tokenizer mode to encode the target text."""
+
+    chat_mode: bool = False
+    """If ``True``, examples are encoded with the tokenizer's chat template."""
+
+
+class LMDPODataset:
+    """Represents an LM preference optimization (DPO) dataset."""
+
+    def __init__(self, sources: dict[str, list[LMDPODataSource]]) -> None:
+        self._sources = sources
+
+    def _create_path_reader(
+        self, path: str, gangs: Gangs, shuffle_window: int, seed: int
+    ) -> DataPipeline:
+        download_manager = get_dependency_resolver().resolve(AssetDownloadManager)
+
+        uri = Uri.maybe_parse(path)
+        if uri:
+            local_path = download_manager.download_dataset(uri)
+        else:
+            local_path = Path(path)
+
+        if not local_path.is_dir():
+            files = [local_path]
+        else:
+            try:
+                files = [f for f in local_path.glob("**/*.jsonl") if not f.is_dir()]
+            except OSError as ex:
+                raise_operational_system_error(ex)
+
+        files.sort()
+
+        builder = read_sequence(files)
+
+        def read_file(file: Path) -> DataPipeline:
+            return read_text(file).map(json.loads).and_return()
+
+        builder.yield_from(read_file)
+
+        if shuffle_window != 1:
+            builder.shuffle(shuffle_window, seed=seed)
+
+        builder.shard(gangs.dp.rank, gangs.dp.size, allow_uneven=True)
+
+        return builder.and_return()
+
+    def create_reader(
+        self,
+        tokenizer: Tokenizer,
+        gangs: Gangs,
+        min_seq_len: int,
+        max_seq_len: int,
+        options: LMDPODataReadOptions | None = None,
+    ) -> DataPipelineReader[PreferenceBatch]:
+        if options is None:
+            options = LMDPODataReadOptions()
+
+        seed = options.seed
+
+        pipelines = []
+
+        weights = []
+
+        # `_sources` maps each source name to a list of data sources; iterate
+        # over the lists rather than the dictionary keys.
+        for sources in self._sources.values():
+            for source in sources:
+                pipeline = self._create_path_reader(
+                    source.path, gangs, options.example_shuffle_window, seed
+                )
+
+                seed += 1
+
+                pipelines.append(pipeline)
+
+                weights.append(source.weight)
+
+        seed += gangs.dp.rank
+
+        # Honor the `sample` option: sample sources by weight, otherwise read
+        # them round-robin.
+        if options.sample:
+            builder = DataPipeline.sample(pipelines, weights, seed)
+        else:
+            builder = DataPipeline.round_robin(pipelines)
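+
+        # In chat mode, examples arrive as OpenAI-style message lists and are
+        # encoded with the Hugging Face chat template, which also yields the
+        # per-token assistant mask that serves as the loss target mask.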
+        if options.chat_mode:
+            # No encode mode is passed here since `apply_chat_template`
+            # handles the special tokens itself.
+            encoder = tokenizer.create_encoder()
+
+            if not isinstance(encoder, HuggingFaceTokenEncoder):
+                raise RuntimeError(
+                    "A Hugging Face tokenizer must be used when `chat_mode` is set."
+                )
+
+            def encoding_chat(example: dict[str, Any]) -> dict[str, Any]:
+                id_ = example.get("id")
+
+                chat_chosen = example.get("chat_chosen")
+                chat_rejected = example.get("chat_rejected")
+
+                encoded_output_chosen = encoder.apply_chat_template(
+                    chat_chosen,
+                    return_dict=True,
+                    return_assistant_tokens_mask=True,
+                    return_tensors="pt",
+                )
+
+                indices_chosen = encoded_output_chosen["input_ids"][0]
+                target_mask_chosen = encoded_output_chosen["assistant_masks"][
+                    0
+                ].bool()
+
+                encoded_output_rejected = encoder.apply_chat_template(
+                    chat_rejected,
+                    return_dict=True,
+                    return_assistant_tokens_mask=True,
+                    return_tensors="pt",
+                )
+
+                indices_rejected = encoded_output_rejected["input_ids"][0]
+                target_mask_rejected = encoded_output_rejected["assistant_masks"][
+                    0
+                ].bool()
+
+                if not options.mask_source_tokens:
+                    # No source masking, i.e. the masks are all ones.
+                    target_mask_chosen.fill_(True)
+                    target_mask_rejected.fill_(True)
+
+                total_tokens = len(indices_chosen) + len(indices_rejected)
+
+                return {
+                    "id": id_,
+                    "indices_chosen": indices_chosen,
+                    "indices_rejected": indices_rejected,
+                    "reference_score_chosen": example.get("reference_score_chosen"),
+                    "reference_score_rejected": example.get(
+                        "reference_score_rejected"
+                    ),
+                    "target_mask_chosen": target_mask_chosen,
+                    "target_mask_rejected": target_mask_rejected,
+                    "total_tokens": total_tokens,
+                }
+
+            builder.map(encoding_chat)
+        else:
+            # Encode the source and target texts.
+            source_encoder = tokenizer.create_encoder(mode=options.source_encode_mode)
+            target_encoder = tokenizer.create_encoder(mode=options.target_encode_mode)
+
+            builder.map(source_encoder, selector="src")
+            builder.map(target_encoder, selector="tgt_chosen")
+            builder.map(target_encoder, selector="tgt_rejected")
+
+            def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]:
+                id_ = example.get("id")
+
+                source_indices = example["src"]
+                target_indices_chosen = example["tgt_chosen"]
+                target_indices_rejected = example["tgt_rejected"]
+
+                indices_chosen = torch.cat([source_indices, target_indices_chosen])
+                indices_rejected = torch.cat([source_indices, target_indices_rejected])
+
+                if options.mask_source_tokens:
+                    source_len = len(source_indices)
+                    target_mask_chosen = (
+                        torch.arange(len(indices_chosen)) >= source_len
+                    )
+                    target_mask_rejected = (
+                        torch.arange(len(indices_rejected)) >= source_len
+                    )
+                else:
+                    target_mask_chosen = torch.full([len(indices_chosen)], True)
+                    target_mask_rejected = torch.full([len(indices_rejected)], True)
+
+                # The source tokens count twice since they are prepended to
+                # both the chosen and the rejected targets.
+                total_tokens = (
+                    2 * len(source_indices)
+                    + len(target_indices_chosen)
+                    + len(target_indices_rejected)
+                )
+
+                return {
+                    "id": id_,
+                    "indices_prompt": source_indices,
+                    "indices_chosen": indices_chosen,
+                    "indices_rejected": indices_rejected,
+                    "reference_score_chosen": example.get("reference_score_chosen"),
+                    "reference_score_rejected": example.get(
+                        "reference_score_rejected"
+                    ),
+                    "target_mask_chosen": target_mask_chosen,
+                    "target_mask_rejected": target_mask_rejected,
+                    "total_tokens": total_tokens,
+                }
+
+            builder.map(cat_source_and_target)
+
+        batching = options.batching
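+
+        # With length batching, examples are bucketed so that the combined
+        # `total_tokens` of a batch stays below `max_num_elements`; with
+        # static batching, every batch holds a fixed number of examples.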
selector="total_tokens", + min_data_len=min_seq_len, + skip_above_max_examples=True, + drop_remainder=options.drop_remainder, + ) + elif isinstance(batching, StaticBatching): + # Filter out long examples + def skip(example: dict[str, Any]) -> bool: + chosen_len = len(example["indices_chosen"]) + rejected_len = len(example["indices_rejected"]) + + if chosen_len > max_seq_len or rejected_len > max_seq_len: + return False + + return chosen_len >= min_seq_len and rejected_len >= min_seq_len + + builder.filter(skip) + + # Bucket `batch_size` examples + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) + else: + raise NotSupportedError(f"`{batching}` is not supported.") + + # Shuffle buckets + if options.batch_shuffle_window != 1: + builder.shuffle(options.batch_shuffle_window, seed=seed) + + seed += 1 + + # Collate bucketed examples into a batch. + target_mask_collate_opts = [ + CollateOptionsOverride("target_mask_chosen", pad_value=False), + CollateOptionsOverride("target_mask_rejected", pad_value=False), + ] + + collater = Collater(pad_value=0, overrides=target_mask_collate_opts) + + builder.map(collater, num_parallel_calls=options.npc) + + # Return only the first `max_num_batches`. + if options.max_num_batches is not None: + builder.take(options.max_num_batches) + + # Prefetch `prefetch` batches in background. + builder.prefetch(options.prefetch) + + # Wrap examples with `PreferenceBatch`. + def to_batch(example: dict[str, Any]) -> PreferenceBatch: + indices_chosen = cast(SequenceData, example["indices_chosen"]) + indices_rejected = cast(SequenceData, example["indices_rejected"]) + + seqs_chosen, seq_chosen_lens = ( + indices_chosen["seqs"], + indices_chosen["seq_lens"], + ) + seqs_rejected, seq_rejected_lens = ( + indices_rejected["seqs"], + indices_rejected["seq_lens"], + ) + + target_mask_chosen = example["target_mask_chosen"]["seqs"] + target_mask_rejected = example["target_mask_rejected"]["seqs"] + + batch_chosen = SequenceBatch( + seqs_chosen, + seq_chosen_lens, + target_mask=target_mask_chosen, + example=example, + ) + + batch_rejected = SequenceBatch( + seqs_rejected, + seq_rejected_lens, + target_mask=target_mask_rejected, + example=example, + ) + + batch_reference_scores_chosen = None + if all(example["reference_score_chosen"]): + batch_reference_scores_chosen = torch.Tensor( + example["reference_score_chosen"] + ) + batch_reference_scores_rejected = None + if all(example["reference_score_rejected"]): + batch_reference_scores_rejected = torch.Tensor( + example["reference_score_rejected"] + ) + + return PreferenceBatch( + batch_chosen, + batch_rejected, + batch_reference_scores_chosen, + batch_reference_scores_rejected, + ) + + pipeline = builder.map(to_batch).and_return() + + return DataPipelineReader[PreferenceBatch]( + pipeline, + gangs, + num_accumulate=options.num_accumulate, + drop_remainder=options.drop_remainder, + sync=options.sync_batches, + sync_mode=options.sync_mode, + ) + + +@dataclass +class LMDPODataSource: #TODO: to confirm after recipe.py + path: str + split: str | None = None + weight: float = 1.0 + + +@dataclass +class LMDPODatasetConfig: #TODO: to confirm after recipe.py + sources: dict[str, list[LMDPODataSource]] = field(default_factory=dict) + + +def open_lm_dpo_dataset(config: LMDPODatasetConfig) -> LMDPODataset: + return LMDPODataset(config.sources) diff --git a/recipes/lm/dpo/recipe.py b/recipes/lm/dpo/recipe.py new file mode 100644 index 000000000..24d280e0b --- /dev/null +++ b/recipes/lm/dpo/recipe.py @@ -0,0 +1,268 @@ +# 
+def open_lm_dpo_dataset(config: LMDPODatasetConfig) -> LMDPODataset:
+    return LMDPODataset(config.sources)
diff --git a/recipes/lm/dpo/recipe.py b/recipes/lm/dpo/recipe.py
new file mode 100644
index 000000000..24d280e0b
--- /dev/null
+++ b/recipes/lm/dpo/recipe.py
@@ -0,0 +1,268 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from typing import cast
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+from typing_extensions import override
+
+from fairseq2.composition import register_dataset_family
+from fairseq2.datasets import PreferenceBatch
+from fairseq2.metrics import MetricBag
+from fairseq2.metrics.common import (
+    add_dpo_loss_metric,
+    add_logps_metrics,
+    add_nll_loss_metric,
+    add_seq_batch_metrics,
+    add_sequence_length_metrics,
+    update_dpo_loss_metric,
+    update_logps_metrics,
+    update_nll_loss_metric,
+    update_seq_batch_metrics,
+    update_sequence_length_metrics,
+)
+from fairseq2.models.clm import CausalLM
+from fairseq2.recipe.base import Recipe, RecipeContext
+from fairseq2.recipe.trainer import TrainUnit
+from fairseq2.runtime.dependency import DependencyContainer
+from fairseq2.task import Task
+
+from ..common import (
+    Batching,
+    LengthBatching,
+    StaticBatching,
+    _gather_lprobs_avg,
+    check_model_vocabulary,
+)
+from .config import LMDPOConfig
+from .dataset import (
+    LM_DPO_DATASET,
+    LMDPODataReadOptions,
+    LMDPODataset,
+    LMDPODatasetConfig,
+    open_lm_dpo_dataset,
+)
+
+
+class LMDPORecipe(Recipe):
+    @override
+    def register(self, container: DependencyContainer) -> None:
+        register_dataset_family(
+            container,
+            LM_DPO_DATASET,
+            LMDPODataset,
+            LMDPODatasetConfig,
+            opener=open_lm_dpo_dataset,
+        )
+
+    @override
+    def create_task(self, context: RecipeContext) -> Task:
+        config = context.get_config_as(LMDPOConfig)
+
+        check_model_vocabulary(context)
+
+        dp_model = context.get_data_parallel_model()
+
+        # Get the reference model if configured.
+        reference_model = None
+        if config.reference_model is not None:
+            reference_model = context.bootstrap_reference_model("reference_model")
+
+        unit = LMDPOUnit(
+            dp_model,
+            reference_model,
+            beta=config.beta,
+            nll_scale=config.nll_scale,
+            length_normalization=config.length_normalization,
+        )
+
+        dataset = context.get_dataset_as(LMDPODataset)
+
+        tokenizer = context.get_tokenizer()
+
+        batching: Batching
+        if config.dataset.batch_size is not None:
+            batching = StaticBatching(config.dataset.batch_size)
+        else:
+            batching = LengthBatching(config.dataset.max_num_tokens)
+
+        read_options = LMDPODataReadOptions(
+            batching=batching,
+            example_shuffle_window=config.dataset.example_shuffle_window,
+            batch_shuffle_window=config.dataset.batch_shuffle_window,
+            num_accumulate=config.trainer.grad_accumulation.num_batches,
+            prefetch=config.dataset.num_prefetch,
+            mask_source_tokens=config.dataset.mask_source_tokens,
+            source_encode_mode=config.dataset.source_encode_mode,
+            target_encode_mode=config.dataset.target_encode_mode,
+            chat_mode=config.dataset.chat_mode,
+            seed=config.common.seed,
+        )
+
+        data_reader = dataset.create_reader(
+            tokenizer=tokenizer,
+            gangs=context.gangs,
+            min_seq_len=config.dataset.min_seq_len,
+            max_seq_len=config.dataset.max_seq_len,
+            options=read_options,
+        )
+
+        return context.create_trainer(unit, data_reader, [], [])
+
+    @property
+    @override
+    def config_kls(self) -> type[object]:
+        return LMDPOConfig
+
+
+class LMDPOUnit(TrainUnit[PreferenceBatch]):
+    def __init__(
+        self,
+        model: Module,
+        reference_model: Module | None,
+        beta: float = 0.1,
+        nll_scale: float = 1.0,
+        length_normalization: bool = False,
+    ) -> None:
+        self._model = model
+        self._reference_model = reference_model
+        self._beta = beta
+        self._nll_scale = nll_scale
+        self._length_normalization = length_normalization
+
+    @override
+    def prepare_metric_bag(self, metric_bag: MetricBag) -> None:
+        add_nll_loss_metric(metric_bag)
+
+        add_seq_batch_metrics(metric_bag)
+
+        add_dpo_loss_metric(metric_bag)
+
+        add_sequence_length_metrics(metric_bag)
+
+        add_logps_metrics(metric_bag)
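+
+    # The DPO objective (Rafailov et al., 2023): given a prompt with a chosen
+    # and a rejected response, the per-example loss is
+    #
+    #   -log sigmoid(beta * ((logp_chosen - ref_logp_chosen)
+    #                        - (logp_rejected - ref_logp_rejected)))
+    #
+    # i.e. the policy is rewarded for raising the likelihood of the chosen
+    # response, relative to the reference model, more than that of the
+    # rejected response.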
+    def _compute_dpo_loss(
+        self,
+        chosen_logps: Tensor,
+        ref_chosen_logps: Tensor,
+        rejected_logps: Tensor,
+        ref_rejected_logps: Tensor,
+    ) -> tuple[Tensor, Tensor, Tensor]:
+        logp_ratio_chosen = self._beta * (chosen_logps - ref_chosen_logps)
+        logp_ratio_rejected = self._beta * (rejected_logps - ref_rejected_logps)
+
+        dpo_loss = -torch.nn.functional.logsigmoid(
+            logp_ratio_chosen - logp_ratio_rejected
+        )
+
+        return logp_ratio_chosen, logp_ratio_rejected, dpo_loss.sum()
+
+    @override
+    def process_batch(
+        self, batch: PreferenceBatch, metric_bag: MetricBag
+    ) -> tuple[Tensor, int]:
+        model = cast(CausalLM, self._model)
+
+        reference_model = (
+            cast(CausalLM, self._reference_model) if self._reference_model else None
+        )
+
+        chosen_batch = batch.chosen
+        chosen_input_batch, chosen_target_batch = chosen_batch.as_auto_regressive()
+
+        rejected_batch = batch.rejected
+        rejected_input_batch, rejected_target_batch = (
+            rejected_batch.as_auto_regressive()
+        )
+
+        if (
+            chosen_target_batch.target_mask is None
+            or rejected_target_batch.target_mask is None
+        ):
+            raise RuntimeError("`target_mask` attributes must exist for the DPO loss.")
+
+        chosen_seqs, chosen_seqs_layout = chosen_input_batch.as_input()
+
+        nll_loss, chosen_logits = model(
+            chosen_seqs,
+            chosen_seqs_layout,
+            targets=chosen_target_batch.seqs,
+            target_mask=chosen_target_batch.target_mask,
+            return_logits=True,
+        )
+
+        rejected_seqs, rejected_seqs_layout = rejected_input_batch.as_input()
+
+        rejected_logits = model(rejected_seqs, rejected_seqs_layout)
+
+        chosen_logps, average_chosen_logps = _gather_lprobs_avg(
+            chosen_logits, chosen_target_batch
+        )
+        rejected_logps, average_rejected_logps = _gather_lprobs_avg(
+            rejected_logits, rejected_target_batch
+        )
+
+        if reference_model is not None:
+            # Feed the same shifted input sequences to the reference model so
+            # that its logits align with the target batches.
+            chosen_seqs, chosen_seqs_layout = chosen_input_batch.as_input()
+            rejected_seqs, rejected_seqs_layout = rejected_input_batch.as_input()
+
+            with torch.no_grad():
+                ref_chosen_logits = reference_model(chosen_seqs, chosen_seqs_layout)
+                ref_rejected_logits = reference_model(
+                    rejected_seqs, rejected_seqs_layout
+                )
+
+            ref_chosen_logps, ref_average_chosen_logps = _gather_lprobs_avg(
+                ref_chosen_logits, chosen_target_batch
+            )
+            ref_rejected_logps, ref_average_rejected_logps = _gather_lprobs_avg(
+                ref_rejected_logits, rejected_target_batch
+            )
+        elif (
+            batch.reference_score_chosen is not None
+            and batch.reference_score_rejected is not None
+        ):
+            # Reference scores must exist in the batch if the reference model
+            # is `None`.
+            ref_chosen_logps = batch.reference_score_chosen
+            ref_average_chosen_logps = (
+                ref_chosen_logps / chosen_target_batch.target_mask.sum(-1)
+            )
+            ref_rejected_logps = batch.reference_score_rejected
+            ref_average_rejected_logps = (
+                ref_rejected_logps / rejected_target_batch.target_mask.sum(-1)
+            )
+        else:
+            raise RuntimeError(
+                "Either a reference model must be initialized or the data batch "
+                "must provide reference scores, but neither is available."
+            )
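+
+        # With length normalization, the implicit reward is the average
+        # per-token log-probability of a sequence rather than its sum, which
+        # keeps the reward comparable across responses of different lengths.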
+        if self._length_normalization:
+            _, _, dpo_loss = self._compute_dpo_loss(
+                average_chosen_logps,
+                ref_average_chosen_logps,
+                average_rejected_logps,
+                ref_average_rejected_logps,
+            )
+        else:
+            _, _, dpo_loss = self._compute_dpo_loss(
+                chosen_logps, ref_chosen_logps, rejected_logps, ref_rejected_logps
+            )
+
+        update_dpo_loss_metric(metric_bag, dpo_loss, batch)
+
+        update_nll_loss_metric(metric_bag, nll_loss, chosen_batch.num_target_elements)
+
+        update_sequence_length_metrics(metric_bag, batch)
+
+        update_logps_metrics(metric_bag, batch, chosen_logps, rejected_logps)
+
+        update_seq_batch_metrics(metric_bag, chosen_batch)
+
+        # The normalization is applied locally per rank.
+        loss = (
+            dpo_loss
+            + self._nll_scale
+            * nll_loss
+            * chosen_target_batch.batch_size
+            / chosen_target_batch.num_target_elements
+        )
+
+        return loss, chosen_target_batch.batch_size
diff --git a/recipes/lm/sft/dataset.py b/recipes/lm/sft/dataset.py
index 29e527289..065070d4b 100644
--- a/recipes/lm/sft/dataset.py
+++ b/recipes/lm/sft/dataset.py
@@ -29,86 +29,13 @@
 from fairseq2.gang import Gangs
 from fairseq2.utils.uri import Uri
 
-LM_SFT_DATASET: Final = "lm_sft"
-
-
-@dataclass
-class StaticBatching:
-    """Specifies batching where each batch has the same number of examples."""
-
-    batch_size: int
-    """The number of examples in each batch."""
-
-
-@dataclass
-class LengthBatching:
-    """Specifies batching where each batch has a maximum number of elements."""
-
-    max_num_elements: int
-    """The maximum number of elements (e.g. tokens) in each batch."""
+from ..common import Batching, DataReadOptions, LengthBatching, StaticBatching
 
-
-Batching: TypeAlias = StaticBatching | LengthBatching
+LM_SFT_DATASET: Final = "lm_sft"
 
 
 @dataclass(kw_only=True)
-class DataReadOptions:
-    batching: Batching = field(default_factory=lambda: StaticBatching(1))
-    """The batching strategy for returned examples."""
-
-    example_shuffle_window: int = 0
-    """
-    The size of the sliding window for shuffling examples. If ``1``, no
-    shuffling is performed; if ``0``, true shuffling is performed by loading the
-    entire dataset.
-    """
-
-    batch_shuffle_window: int = 0
-    """
-    The size of the sliding window for shuffling batches. If ``1``, no
-    shuffling is performed; if ``0``, true shuffling is performed by loading the
-    entire dataset.
-    """
-
-    drop_remainder: bool = False
-    """
-    If ``True``, drops the last set of batches if they have in total fewer
-    examples than requested.
-    """
-
-    sync_batches: bool = True
-    """
-    If ``True``, ensures that each process in the gang reads the same number of
-    batches. Typically used when the amount of data to be read can vary per
-    process (e.g. due to unbalanced sharding or non-static batching) and it is
-    critical for each process to iterate over the same number of batches (e.g.
-    during training).
-    """
-
-    sync_mode: SyncMode = SyncMode.UNTIL_FIRST
-    """
-    The data synchronization mode among processes in the gang. Only effective if
-    :attr:`sync_batches` is ``True``.
-    """
-
-    max_num_batches: int | None = None
-    """The maximum number of batches to return."""
-
-    num_accumulate: int = 1
-    """
-    The number of batches to accumulate in each iteration. Typically used with
-    gradient accumulation during training.
- """ - - prefetch: int = 1 - """The number of batches to prefetch in background.""" - - npc: int = 10 - """The reference number of parallel calls that data reader can do.""" - - seed: int = 2 - """The seed to initialize the random number generators used internally.""" - +class LMSFTDataReadOptions(DataReadOptions): sample: bool = False """ If ``True``, instruction sources (e.g. JSONL files) will be sampled in @@ -173,10 +100,10 @@ def create_reader( gangs: Gangs, min_seq_len: int, max_seq_len: int, - options: DataReadOptions | None = None, + options: LMSFTDataReadOptions | None = None, ) -> DataPipelineReader[SequenceBatch]: if options is None: - options = DataReadOptions() + options = LMSFTDataReadOptions() sources = self._sources[split] diff --git a/recipes/lm/sft/recipe.py b/recipes/lm/sft/recipe.py index 1e9781251..db92dc17c 100644 --- a/recipes/lm/sft/recipe.py +++ b/recipes/lm/sft/recipe.py @@ -33,7 +33,7 @@ from .dataset import ( LM_SFT_DATASET, Batching, - DataReadOptions, + LMSFTDataReadOptions, LengthBatching, LMSFTDataset, LMSFTDatasetConfig, @@ -73,7 +73,7 @@ def create_task(self, context: RecipeContext) -> Task: else: batching = LengthBatching(config.dataset.max_num_tokens) - read_options = DataReadOptions( + read_options = LMSFTDataReadOptions( batching=batching, example_shuffle_window=config.dataset.example_shuffle_window, batch_shuffle_window=config.dataset.batch_shuffle_window, @@ -107,7 +107,7 @@ def create_task(self, context: RecipeContext) -> Task: valid_batching = LengthBatching(max_num_tokens) - read_options = DataReadOptions( + read_options = LMSFTDataReadOptions( batching=valid_batching, prefetch=config.dataset.prefetch, source_encode_mode=config.dataset.source_encode_mode, diff --git a/src/fairseq2/datasets/batch.py b/src/fairseq2/datasets/batch.py index 28618fe2c..057dd1369 100644 --- a/src/fairseq2/datasets/batch.py +++ b/src/fairseq2/datasets/batch.py @@ -9,6 +9,7 @@ from collections.abc import Sequence from typing import final +import torch from torch import Tensor from typing_extensions import override @@ -662,3 +663,34 @@ def __repr__(self) -> str: ) return f"Seq2SeqBatch({s})" + + +@final +class PreferenceBatch(SupportsDeviceTransfer): + """Represents a preference optimization dataset batch.""" + + chosen: SequenceBatch + rejected: SequenceBatch + reference_score_chosen: torch.Tensor | None + reference_score_rejected: torch.Tensor | None + + @property + def batch_size(self) -> int: + """The size of the batch dimension.""" + return self.chosen.batch_size + + @override + def to(self, device: Device, *, non_blocking: bool = False) -> None: + self.chosen.to(device, non_blocking=non_blocking) + + self.rejected.to(device, non_blocking=non_blocking) + + if self.reference_score_chosen is not None: + self.reference_score_chosen = self.reference_score_chosen.to( + device, non_blocking=non_blocking + ) + + if self.reference_score_rejected is not None: + self.reference_score_rejected = self.reference_score_rejected.to( + device, non_blocking=non_blocking + )