diff --git a/src/forge/data_models/__init__.py b/src/forge/data_models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/forge/data_models/api.py b/src/forge/data_models/api.py
new file mode 100644
index 000000000..bc4a5981a
--- /dev/null
+++ b/src/forge/data_models/api.py
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from collections.abc import Iterator, Sequence
+
+import torch
+
+from forge.data_models.completion import Completion
+from forge.data_models.loss import LossOutput
+from forge.data_models.minibatch import Minibatch
+
+from forge.data_models.prompt import Prompt
+from forge.data_models.scored_completion import ScoredCompletion
+from torch.utils.data import DataLoader, IterableDataset
+
+
+# TODO: This file should not be in the data_models folder/package.
+
+
+class Trainer(ABC):
+    @abstractmethod
+    def accumulate_gradients(self, minibatch: Minibatch) -> LossOutput:
+        """
+        Accumulates gradients; called once per minibatch.
+        """
+        pass
+
+    @abstractmethod
+    def apply_gradients(self) -> None:
+        """
+        Applies the accumulated gradients to the model parameters.
+        """
+        pass
+
+    @abstractmethod
+    def snapshot_weights(
+        self,
+    ) -> dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]]:  # TODO: RDMA buffer
+        """
+        Returns a snapshot of the current model weights.
+
+        The mapping is from parameter name to a tuple of (weights buffer, dtype, shape of the original tensor).
+        """
+        # TODO: NEEDS fixing: the return value should be a remote handle, like an RDMA buffer handle
+        pass
+
+
+class Generator(ABC):
+    @abstractmethod
+    def generate(self, prompt: Prompt, **kwargs) -> list[Completion]:
+        """
+        Generate completions for a given prompt.
+        Args:
+            prompt: The input prompt.
+            **kwargs: Additional model-specific generation parameters.
+        Returns:
+            A list of Completion objects.
+        """
+        pass
+
+    @abstractmethod
+    def update_weights(
+        self, weights_handle: dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]]
+    ):
+        """
+        Update the weights of the model.
+        Args:
+            weights_handle: A dictionary of weights to update.
+        """
+        # TODO: NEEDS fixing: the weights_handle should be a remote handle, like an RDMA buffer handle
+        pass
+
+
+class Scorer(ABC):
+    @abstractmethod
+    def score(self, completion: Completion) -> ScoredCompletion:
+        pass
+
+    def score_batch(
+        self, completions: Sequence[Completion]
+    ) -> Sequence[ScoredCompletion]:
+        """
+        Optionally override for efficient batch scoring.
+        """
+        return [self.score(c) for c in completions]
+
+
+class PromptDataset(IterableDataset):
+    """
+    Users should inherit from this and implement __iter__.
+    """
+
+    def __iter__(self) -> Iterator[Prompt]:
+        """
+        Defines how to generate or yield Prompt objects.
+        """
+        raise NotImplementedError
+
+
+class PromptDataLoader(DataLoader):
+    """
+    Subclass of DataLoader that handles batching, parallelism, and other data-serving concerns.
+    """
+
+    def __init__(self, dataset: PromptDataset, **kwargs):
+        super().__init__(dataset, **kwargs)
diff --git a/src/forge/data_models/completion.py b/src/forge/data_models/completion.py
new file mode 100644
index 000000000..eca4f62fe
--- /dev/null
+++ b/src/forge/data_models/completion.py
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from forge.data_models.prompt import Prompt
+
+
+@dataclass
+class Completion:
+    """A model-generated completion for a given prompt."""
+
+    # The original prompt.
+    prompt: Prompt
+
+    # The decoded text returned by the model.
+    text: str
+
+    # The prompt token ids that were fed into the model.
+    prompt_ids: torch.Tensor
+
+    # The token ids that were generated by the model.
+    token_ids: torch.Tensor
+
+    # The log probabilities of the target tokens.
+    log_probs: Optional[torch.Tensor] = None
diff --git a/src/forge/data_models/distributed_metric.py b/src/forge/data_models/distributed_metric.py
new file mode 100644
index 000000000..5fe6f0fb9
--- /dev/null
+++ b/src/forge/data_models/distributed_metric.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+import torch
+import torch.distributed as dist
+
+
+@dataclass
+class DistributedMetric(ABC):
+    """Metrics that are calculated in a distributed fashion.
+
+    Metrics computed on each rank are wrapped in a DistributedMetric
+    according to how they will be aggregated. For example, the average log prob
+    can be wrapped as `Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))` where
+    `mask` indicates which tokens are valid.
+    """
+
+    # We will need to pass a context argument for distributed setup in the future.
+    @abstractmethod
+    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
+        pass
+
+    @abstractmethod
+    def local(self) -> torch.Tensor:
+        pass
+
+
+@dataclass
+class SumDistributedMetric(DistributedMetric):
+    tensor: torch.Tensor
+
+    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
+        return _try_clone_and_reduce(self.tensor, op=dist.ReduceOp.SUM, group=group)
+
+    def local(self) -> torch.Tensor:
+        return self.tensor
+
+
+@dataclass
+class Fraction:
+    numerator: DistributedMetric
+    denominator: DistributedMetric
+
+    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
+        return self.numerator.reduce(group) / self.denominator.reduce(group)
+
+    def local(self) -> torch.Tensor:
+        return self.numerator.local() / self.denominator.local()
+
+
+def _try_clone_and_reduce(
+    tensor: torch.Tensor, op: dist.ReduceOp, group: dist.ProcessGroup | None
+) -> torch.Tensor:
+    cloned = tensor.detach().clone()
+    if dist.is_initialized():
+        dist.all_reduce(cloned, op=op, group=group)
+    return cloned
diff --git a/src/forge/data_models/experience.py b/src/forge/data_models/experience.py
new file mode 100644
index 000000000..34a183eba
--- /dev/null
+++ b/src/forge/data_models/experience.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Optional, Sequence
+
+import torch
+from forge.data_models.scored_completion import ScoredCompletion
+
+
+@dataclass
+class Experience:
+    """
+    The Experience data class to be used by the trainer.
+
+    Experiences are usually generated from a scored completion by running various post-processing steps.
+    """
+
+    # Concatenated prompt and sample token ids.
+    ids: torch.Tensor
+
+    # The mask for the target ids, 0 for prompt tokens, 1 for sample tokens.
+    mask: torch.Tensor
+
+    # The weight to apply to the loss of each target token. It's normally computed
+    # from the advantage and the reward.
+    weights: torch.Tensor
+
+    # The log probabilities of the target tokens; for the prompt part it's set to 0,
+    # for the generated part it's computed by the Generator/Sampler.
+    log_probs: Optional[torch.Tensor] = None
+
+    # TODO: add more fields as required
+    state: str = ""
+
+
+def from_scored_completion(scored_completion: ScoredCompletion) -> Experience:
+    """Converts a ScoredCompletion to an Experience."""
+    prompt_ids = scored_completion.completion.prompt_ids
+    token_ids = scored_completion.completion.token_ids
+    log_probs = scored_completion.completion.log_probs
+    ids = torch.cat([prompt_ids, token_ids])
+    mask = torch.cat(
+        [
+            torch.zeros(prompt_ids.shape, dtype=torch.float32),
+            torch.ones_like(token_ids, dtype=torch.float32),
+        ]
+    )
+    advantage = scored_completion.score
+    weights = mask * advantage
+    log_probs = torch.cat(
+        [
+            torch.zeros(prompt_ids.shape, dtype=torch.float32),
+            # TODO: this only works if sample.log_probs is 1
+            log_probs,
+        ]
+    )
+    return Experience(ids=ids, mask=mask, weights=weights, log_probs=log_probs)
+
+
+def from_scored_completions(
+    scored_completions: Sequence[ScoredCompletion],
+) -> Sequence[Experience]:
+    """Converts a sequence of ScoredCompletions to a sequence of Experiences."""
+    return [from_scored_completion(sc) for sc in scored_completions]
diff --git a/src/forge/data_models/loss.py b/src/forge/data_models/loss.py
new file mode 100644
index 000000000..9806938e5
--- /dev/null
+++ b/src/forge/data_models/loss.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+import torch
+from forge.data_models.distributed_metric import Fraction
+from forge.data_models.minibatch import Minibatch
+
+
+@dataclass
+class LossInput:
+    minibatch: Minibatch
+    trainer_logits: torch.Tensor
+
+
+@dataclass
+class LossOutput:
+    loss: Fraction
diff --git a/src/forge/data_models/minibatch.py b/src/forge/data_models/minibatch.py
new file mode 100644
index 000000000..1bcbc4ef7
--- /dev/null
+++ b/src/forge/data_models/minibatch.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+import torch
+from forge.data_models.experience import Experience
+
+
+@dataclass
+class Minibatch:
+    """The minibatch that the trainer will receive."""
+
+    # The input sequence token ids for the trainer forward pass.
+    input_ids: torch.Tensor
+
+    # The segment ids for the input sequence token ids. Tokens with the
+    # same segment id belong to the same sequence.
+    segment_ids: torch.Tensor
+
+    # The targets required for loss computation, usually the concatenated prompt and
+    # sample token ids.
+    target_ids: torch.Tensor
+
+    # The mask for the target ids, 0 for prompt tokens, 1 for sample tokens.
+    target_mask: torch.Tensor
+
+    # The weight to apply to the loss of each target token. It's normally computed
+    # from the advantage and the reward.
+    target_weights: torch.Tensor
+
+    # The log probabilities of the target tokens; for the prompt part it's set to 0,
+    # for the generated part it's computed by the sampler.
+    target_log_probs: torch.Tensor
+
+
+def from_experiences(
+    exps: Sequence[Experience], max_seq_len: int, pad_val: int = 0
+) -> Minibatch:
+    """
+    Convert a list of experiences to a minibatch.
+    """
+
+    def pack_sequence(
+        tensors: Sequence[torch.Tensor],
+        pad_val: Any,
+        dtype: torch.dtype,
+        max_len: int,
+    ) -> torch.Tensor:
+        """Packs multiple tensors along the seq dim."""
+        seq = torch.cat(tensors)
+        pad_len = max_len - seq.size(0)
+        if pad_len < 0:
+            raise ValueError(
+                f"Sequence length {seq.size(0)} exceeds the maximum length {max_len}"
+            )
+        return torch.nn.functional.pad(seq, (0, pad_len), value=pad_val)[None, ...].to(
+            dtype
+        )
+
+    mini_batch = {}
+    exp_list = defaultdict(list)
+    for i, exp in enumerate(exps):
+        input_ids = exp.ids[:-1]
+        exp_list["input_ids"].append(input_ids)
+        exp_list["target_ids"].append(exp.ids[1:])
+        exp_list["segment_ids"].append(torch.ones_like(input_ids) * i)
+        exp_list["target_mask"].append(exp.mask[1:])
+        exp_list["target_weights"].append(exp.weights[1:])
+        exp_list["target_log_probs"].append(exp.log_probs[1:])
+
+    for k, v in exp_list.items():
+        _dtype = torch.int64
+        if k in ("target_mask", "target_weights", "target_log_probs"):
+            _dtype = torch.float32
+
+        mini_batch[k] = pack_sequence(v, pad_val, _dtype, max_seq_len)
+
+    return Minibatch(**mini_batch)
diff --git a/src/forge/data_models/prompt.py b/src/forge/data_models/prompt.py
new file mode 100644
index 000000000..741b097aa
--- /dev/null
+++ b/src/forge/data_models/prompt.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections.abc import Sequence
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any
+
+
+class Role(Enum):
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"
+    NONE = "none"
+
+
+@dataclass
+class Message:
+    """A single message in a conversation."""
+
+    chunks: Sequence[str]
+    role: Role
+
+
+@dataclass
+class Prompt:
+    """A multi-turn prompt (conversation history)."""
+
+    # Multi-turn messages, each turn is a message.
+    messages: Sequence[Message]
+    metadata: Any | None = None
+
+    @classmethod
+    def from_prompt(
+        cls, prompt: str, system_instruction: str | None = None
+    ) -> "Prompt":
+        messages = prompt_to_messages(prompt, system_instruction)
+        return cls(
+            messages=messages,
+        )
+
+
+def prompt_to_messages(
+    prompt: str, system_instruction: str | None = None
+) -> Sequence[Message]:
+    """Convert a prompt string to a sequence of messages."""
+    messages = []
+    if system_instruction is not None:
+        messages.append(Message(chunks=[system_instruction], role=Role.SYSTEM))
+    messages.append(
+        Message(chunks=[prompt], role=Role.USER),
+    )
+    return messages
+
+
+def to_prompt(
+    prompt: str, metadata: Any | None = None, system_instruction: str | None = None
+) -> Prompt:
+    """Converts a prompt string to a Prompt."""
+    return Prompt(
+        messages=prompt_to_messages(prompt, system_instruction),
+        metadata=metadata,
+    )
diff --git a/src/forge/data_models/scored_completion.py b/src/forge/data_models/scored_completion.py
new file mode 100644
index 000000000..c1b41b8ae
--- /dev/null
+++ b/src/forge/data_models/scored_completion.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+from forge.data_models.completion import Completion
+
+
+@dataclass
+class ScoredCompletion:
+    """A completion with an associated score (from a reward model or human)."""
+
+    completion: Completion
+    score: float  # akin to reward
+
+    # TODO: add more fields as needed.
diff --git a/src/forge/generators/__init__.py b/src/forge/generators/__init__.py
new file mode 100644
index 000000000..2e41cd717
--- /dev/null
+++ b/src/forge/generators/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/src/forge/generators/vllm_generator.py b/src/forge/generators/vllm_generator.py
new file mode 100644
index 000000000..d5d3312a2
--- /dev/null
+++ b/src/forge/generators/vllm_generator.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from forge.data_models.api import Generator
+
+from forge.data_models.completion import Completion
+from forge.data_models.prompt import Prompt
+
+
+class VLLMGenerator(Generator):
+    def __init__(self, model_path: str):
+        self.model_path = model_path
+
+    def generate(self, prompt: Prompt, **kwargs) -> list[Completion]:
+        """
+        Generate completions for a given prompt using vLLM.
+        """
+        # TODO: stub; call into a vLLM engine and return real completions.
+        return []
+
+    def update_weights(
+        self, weights_handle: dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]]
+    ):
+        # TODO: NEEDS fixing: the weights_handle should be a remote handle, like an RDMA buffer handle
+        return {}
diff --git a/src/forge/trainers/__init__.py b/src/forge/trainers/__init__.py
new file mode 100644
index 000000000..2e41cd717
--- /dev/null
+++ b/src/forge/trainers/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
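Note on the weights handoff: until the remote (RDMA) buffer handle from the TODOs above exists, the snapshot_weights/update_weights pair can be exercised with a plain in-process dict. Below is a minimal sketch of the wire format declared in api.py (parameter name -> (flat weights buffer, dtype, original shape)); the helper names snapshot_from_state_dict and load_snapshot are hypothetical and not part of this PR.

import torch


# Hypothetical helpers (not part of this PR) illustrating the snapshot wire
# format from api.py: parameter name -> (flat weights buffer, dtype, shape).
def snapshot_from_state_dict(
    state_dict: dict[str, torch.Tensor],
) -> dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]]:
    # Detach and flatten each parameter into a contiguous 1-D buffer.
    return {
        name: (param.detach().reshape(-1).clone(), param.dtype, param.shape)
        for name, param in state_dict.items()
    }


def load_snapshot(
    weights_handle: dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]],
) -> dict[str, torch.Tensor]:
    # Rebuild each tensor from its flat buffer, dtype, and original shape.
    return {
        name: buf.to(dtype).reshape(shape)
        for name, (buf, dtype, shape) in weights_handle.items()
    }

Under these assumptions, a Trainer.snapshot_weights could return snapshot_from_state_dict(model.state_dict()), and VLLMGenerator.update_weights could load_snapshot the handle and copy the tensors into its engine.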
diff --git a/src/forge/trainers/huggingface_trainer.py b/src/forge/trainers/huggingface_trainer.py
new file mode 100644
index 000000000..d5865643e
--- /dev/null
+++ b/src/forge/trainers/huggingface_trainer.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from forge.data_models.api import Trainer
+
+from forge.data_models.distributed_metric import Fraction, SumDistributedMetric
+
+from forge.data_models.loss import LossOutput
+from forge.data_models.minibatch import Minibatch
+
+
+class HuggingFaceTrainer(Trainer):
+    def __init__(self, model_path: str):
+        super().__init__()
+        self.model_path = model_path
+
+    def accumulate_gradients(self, minibatch: Minibatch) -> LossOutput:
+        # Placeholder loss; a real implementation would run a forward/backward pass.
+        return LossOutput(
+            loss=Fraction(
+                SumDistributedMetric(torch.zeros(1)),
+                SumDistributedMetric(torch.ones(1)),
+            )
+        )
+
+    def apply_gradients(self) -> None:
+        pass
+
+    def snapshot_weights(
+        self,
+    ) -> dict[str, tuple[torch.Tensor, torch.dtype, torch.Size]]:
+        # TODO: NEEDS fixing: the weights should be a remote handle, like an RDMA buffer handle
+        return {}
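For orientation, the pieces above are meant to compose into a single RL step: prompts flow through a Generator, completions are scored, scored completions become Experiences, Experiences are packed into a Minibatch, and the Trainer consumes it before shipping fresh weights back to the Generator. A minimal sketch, assuming concrete Trainer/Generator/Scorer subclasses exist; training_step and its arguments are illustrative, not part of this PR.

from forge.data_models.api import Generator, Scorer, Trainer
from forge.data_models.experience import from_scored_completion
from forge.data_models.minibatch import from_experiences
from forge.data_models.prompt import to_prompt


def training_step(
    trainer: Trainer,
    generator: Generator,
    scorer: Scorer,
    prompts: list[str],
    max_seq_len: int,
) -> None:
    # 1. Sample completions for every prompt.
    completions = [c for p in prompts for c in generator.generate(to_prompt(p))]
    # 2. Score them (reward model, heuristics, or human labels).
    scored = scorer.score_batch(completions)
    # 3. Post-process into Experiences, then pack them into one Minibatch.
    #    max_seq_len must cover the total packed (prompt + sample) length.
    experiences = [from_scored_completion(sc) for sc in scored]
    minibatch = from_experiences(experiences, max_seq_len)
    # 4. Train, then ship fresh weights back to the generator.
    trainer.accumulate_gradients(minibatch)
    trainer.apply_gradients()
    generator.update_weights(trainer.snapshot_weights())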