@dataclass
class Completion:
    """A single generation produced by the model in response to a prompt."""

    # Prompt the completion was generated for.
    prompt: Prompt

    # Decoded text returned by the model.
    text: str

    # Token ids of the encoded prompt that were fed into the model.
    prompt_ids: torch.Tensor

    # Token ids that were generated by the model.
    token_ids: torch.Tensor

    # Log probabilities of the target tokens; left as None when not supplied.
    log_probs: Optional[torch.Tensor] = None
@dataclass
class Episode:
    """
    The Episode data class to be used by the trainer.

    Episodes are usually generated from a scored completion and running various
    post processing steps.
    """

    # Concatenated prompt and sample token ids.
    ids: torch.Tensor

    # The mask for the target ids, 0 for prompt tokens, 1 for sample tokens.
    mask: torch.Tensor

    # The weight to apply to the loss of each target token. It's normally computed
    # from the advantage and the reward.
    weights: torch.Tensor

    # The log probabilities of the target tokens, for prompt part it's set to 0,
    # for generation part it's computed from the Generator/Sampler. None when the
    # source completion carried no log probabilities.
    log_probs: Optional[torch.Tensor] = None

    # TODO: add more fields as required
    state: str = ""


def from_scored_completion(scored_completion: ScoredCompletion) -> Episode:
    """Converts a ScoredCompletion to an Episode.

    Args:
        scored_completion: The scored completion whose ``completion`` supplies
            ``prompt_ids``, ``token_ids`` and (optionally) ``log_probs``, and
            whose ``score`` is used as the per-token advantage.

    Returns:
        An Episode whose ``ids`` are the prompt ids followed by the generated
        ids, whose ``mask`` is 0 over the prompt and 1 over the generated
        tokens, and whose ``weights`` are ``mask * score``. ``log_probs`` is
        zero-padded over the prompt part, or None if the completion has none.
    """
    completion = scored_completion.completion
    prompt_ids = completion.prompt_ids
    token_ids = completion.token_ids
    log_probs = completion.log_probs

    ids = torch.cat([prompt_ids, token_ids])

    # Build both halves of the mask with *_like so each lives on the same
    # device as the ids it describes; a bare torch.zeros(shape) would land on
    # the default device and make torch.cat fail for non-CPU inputs.
    mask = torch.cat(
        [
            torch.zeros_like(prompt_ids, dtype=torch.float32),
            torch.ones_like(token_ids, dtype=torch.float32),
        ]
    )

    # The score is used directly as the advantage for every generated token.
    advantage = scored_completion.score
    weights = mask * advantage

    # Completion.log_probs is optional; only pad it when it is present instead
    # of handing None to torch.cat.
    if log_probs is not None:
        log_probs = torch.cat(
            [
                torch.zeros(
                    prompt_ids.shape,
                    dtype=torch.float32,
                    device=log_probs.device,
                ),
                # TODO: this only works if sample.log_probs is 1
                # (presumably 1-D, matching prompt_ids; confirm)
                log_probs,
            ]
        )

    return Episode(ids=ids, mask=mask, weights=weights, log_probs=log_probs)


def from_scored_completions(
    scored_completions: Sequence[ScoredCompletion],
) -> Sequence[Episode]:
    """Converts a sequence of ScoredCompletion to a sequence of Episodes.

    Args:
        scored_completions: The scored completions to convert, in order.

    Returns:
        A list with one Episode per input, in the same order (empty for an
        empty input).
    """
    return [from_scored_completion(sc) for sc in scored_completions]
b/src/forge/data_models/prompt.py new file mode 100644 index 000000000..55f538c0e --- /dev/null +++ b/src/forge/data_models/prompt.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum +from typing import Any + + +class Role(Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + NONE = "none" + + +@dataclass +class Message: + """A single message in a conversation.""" + + chunks: Sequence[str] + role: Role + + +@dataclass +class Prompt: + """A multi-turn prompt (conversation history).""" + + # Multi-turn messages, each turn is a message. + messages: Sequence[Message] + + @classmethod + def from_prompt( + cls, prompt: str, system_instruction: str | None = None + ) -> "Prompt": + messages = prompt_to_messages(prompt, system_instruction) + return Prompt( + messages=messages, + ) + + +def prompt_to_messages( + prompt: str, system_instruction: str | None = None +) -> Sequence[Message]: + """Convert a prompt to a sequence of messages.""" + messages = [] + if system_instruction is not None: + messages.append(Message(chunks=[system_instruction], role=Role.SYSTEM)) + messages.append( + Message(chunks=[prompt], role=Role.USER), + ) + return messages + + +def to_prompt(prompt: str, system_instruction: str | None = None) -> Prompt: + """Converts a prompt to a sequence of messages.""" + return Prompt( + messages=prompt_to_messages(prompt, system_instruction), + ) diff --git a/src/forge/data_models/scored_completion.py b/src/forge/data_models/scored_completion.py new file mode 100644 index 000000000..f41ff7b59 --- /dev/null +++ b/src/forge/data_models/scored_completion.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
@dataclass
class ScoredCompletion:
    """A completion paired with a score (from a reward model or human)."""

    # The completion being scored.
    completion: Completion
    # The score given to the completion; akin to reward.
    score: float

    # TODO: add more fields as needed.