Skip to content
97 changes: 97 additions & 0 deletions recipes/lm/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -6,6 +6,14 @@

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TypeAlias
from fairseq2.datasets import SyncMode, SequenceBatch


import torch
from torch import Tensor

from fairseq2.logging import log
from fairseq2.models.clm import CausalLM
from fairseq2.nn import Embedding
Expand Down Expand Up @@ -48,3 +56,92 @@
return None

return embed


def _gather_lprobs_avg(logits: Tensor, target: SequenceBatch) -> tuple[Tensor, Tensor]:
assert target.target_mask is not None
logprobs = torch.log_softmax(logits, dim=-1)
per_token_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze(-1)
total_logps = (per_token_logps * target.target_mask).sum(dim=-1) # [Batch, 1]
assert target.target_mask is not None
average_logps = total_logps / target.target_mask.sum(-1)

return total_logps, average_logps


@dataclass
class StaticBatching:
"""Specifies batching where each batch has the same number of examples."""

batch_size: int
"""The number of examples in each batch."""


@dataclass
class LengthBatching:
"""Specifies batching where each batch has a maximum number of elements."""

max_num_elements: int
"""The maximum number of elements (e.g. tokens) in each batch."""


Batching: TypeAlias = StaticBatching | LengthBatching


@dataclass(kw_only=True)
class DataReadOptions:
batching: Batching = field(default_factory=lambda: StaticBatching(1))
"""The batching strategy for returned examples."""

example_shuffle_window: int = 0
"""
The size of the sliding window for shuffling examples. If ``1``, no
shuffling is performed; if ``0``, true shuffling is performed by loading the
entire dataset.
"""

batch_shuffle_window: int = 0
"""
The size of the sliding window for shuffling batches. If ``1``, no
shuffling is performed; if ``0``, true shuffling is performed by loading the
entire dataset.
"""

drop_remainder: bool = False
"""
If ``True``, drops the last set of batches if they have in total fewer
examples than requested.
"""

sync_batches: bool = True
"""
If ``True``, ensures that each process in the gang reads the same number of
batches. Typically used when the amount of data to be read can vary per
process (e.g. due to unbalanced sharding or non-static batching) and it is
critical for each process to iterate over the same number of batches (e.g.
during training).
"""

sync_mode: SyncMode = SyncMode.UNTIL_FIRST
"""
The data synchronization mode among processes in the gang. Only effective if
:attr:`sync_batches` is ``True``.
"""

max_num_batches: int | None = None
"""The maximum number of batches to return."""

num_accumulate: int = 1
"""
The number of batches to accumulate in each iteration. Typically used with
gradient accumulation during training.
"""

prefetch: int = 1
"""The number of batches to prefetch in background."""

npc: int = 10
"""The reference number of parallel calls that data reader can do."""

seed: int = 2
"""The seed to initialize the random number generators used internally."""
Empty file added recipes/lm/dpo/README.md
Empty file.
Empty file added recipes/lm/dpo/__init__.py
Empty file.
15 changes: 15 additions & 0 deletions recipes/lm/dpo/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.recipe.cli import main

from .recipe import LMDPORecipe

recipe = LMDPORecipe()

main(recipe)
152 changes: 152 additions & 0 deletions recipes/lm/dpo/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass, field

from fairseq2.recipe.config import (
ADAMW_OPTIMIZER,
COSINE_ANNEALING_LR,
AdamWConfig,
CommonSection,
CosineAnnealingLRConfig,
DatasetSection,
GangSection,
LRSchedulerSection,
ModelSection,
OptimizerSection,
RegimeSection,
ReferenceModelSection,
TokenizerSection,
TorchConfig,
TrainerSection,
ActivationCheckpointingConfig,
)

from .dataset import LM_DPO_DATASET


@dataclass(kw_only=True)
class LMDPOConfig:
model: ModelSection = field(
default_factory=lambda: ModelSection(name="llama3_1_8b_instruct")
)

tokenizer: TokenizerSection = field(
default_factory=lambda: TokenizerSection(name="llama3_instruct")
)

dataset: LMDPODatasetSection = field(
default_factory=lambda: LMDPODatasetSection(family=LM_DPO_DATASET),
)

gang: GangSection = field(default_factory=lambda: GangSection())

trainer: TrainerSection = field(
default_factory=lambda: TrainerSection(
data_parallelism="fsdp",
max_grad_norm=1.0,
activation_checkpointing=ActivationCheckpointingConfig(mode="layerwise"),
)
)

optimizer: OptimizerSection = field(
default_factory=lambda: OptimizerSection(
name=ADAMW_OPTIMIZER,
config=AdamWConfig(
lr=5.5e-06, betas=(0.9, 0.95), weight_decay=0.1, impl="fused"
),
)
)

lr_scheduler: LRSchedulerSection | None = field(
default_factory=lambda: LRSchedulerSection(
name=COSINE_ANNEALING_LR, config=CosineAnnealingLRConfig(final_lr_scale=0.2)
)
)

regime: RegimeSection = field(
default_factory=lambda: RegimeSection(
num_steps=5000,
validate_every_n_steps=100,
checkpoint_every_n_steps=1000,
keep_last_n_checkpoints=1,
publish_metrics_every_n_steps=10,
export_hugging_face=True,
)
)

# The memory efficient SDPA implementation in PyTorch is numerically not
# stable when used with padded inputs.
common: CommonSection = field(
default_factory=lambda: CommonSection(
torch=TorchConfig(default_sdpa="torch_math")
)
)

# Loss configuration
reference_model: ReferenceModelSection | None = field(
default_factory=lambda: ReferenceModelSection(name="llama3_1_8b_instruct")
)
"""
The reference model. If ``None``, the recipe expects to get reference
log-probabilities for chosen and rejected targets as float values in the
data example (fields `reference_score_rejected` and `reference_score_chosen`).
"""

beta: float = 0.1
"""The coefficient of regularization towards the reference model."""

nll_scale: float = 0.0
"""The coefficient of NLL loss added to the DPO loss."""

length_normalization: bool = False
"""Use length normalized DPO, which uses the average log probability of a sequence as the implicit reward."""



@dataclass(kw_only=True)
class LMDPODatasetSection(DatasetSection):
path: str | None = None

source_encode_mode: str = "prompt"
"""The encode mode for the prompt, determines what special tokens to add."""

target_encode_mode: str = "prompt_response"
"""The encode mode for the target, determines what special tokens to add."""

mask_source_tokens: bool = True
"""If ``False``, calculates loss on the `src` tokens as well as the `tgt` tokens."""

min_seq_len: int = 1
"""The minimum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``.
Shorter sequences will be dropped."""

max_seq_len: int = 8192
"""The maximum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``.
Longer sequences will be dropped."""

max_num_tokens: int = 8192 * 2
"""The maximum number of total `src`, `tgt_chosen`, and `tgt_rejected` tokens per batch."""

batch_size: int | None = None
"""If not ``None``, ignores `max_num_tokens` and each batch will have `batch_size` examples."""

example_shuffle_window: int = 10_000
"""The size of the sliding window for shuffling examples."""

batch_shuffle_window: int = 1_000
"""The size of the sliding window for shuffling batches."""

num_prefetch: int = 4
"""The number of batches to prefetch in background."""

extras: dict[str, object] = field(default_factory=dict)
"""The dataset-specific extra options."""

chat_mode: bool = False
"""If True, dataset jsonl must have 'chat' field with openai-like messages List[Dict] entries"""
Loading
Loading