Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/forge/data_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
31 changes: 31 additions & 0 deletions src/forge/data_models/completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional

import torch
from forge.data_models.prompt import Prompt


@dataclass
class Completion:
    """A model-generated completion for a given prompt.

    Bundles the originating prompt with the decoded output text and the
    token-id tensors for both the prompt and the generated continuation.
    """

    # The original prompt this completion was generated for.
    prompt: Prompt

    # The decoded text returned by the model.
    text: str

    # The encoded text (token ids) that was fed into the model.
    prompt_ids: torch.Tensor

    # The encoded text (token ids) that was generated by the model.
    token_ids: torch.Tensor

    # The log probabilities of the target tokens. Optional: defaults to None
    # when the producer does not supply them.
    log_probs: Optional[torch.Tensor] = None
68 changes: 68 additions & 0 deletions src/forge/data_models/episode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence

import torch
from forge.data_models.scored_completion import ScoredCompletion


@dataclass
class Episode:
    """
    The Episode data class to be used by the trainer.

    Episodes are usually generated from a scored completion by running
    various post-processing steps.
    """

    # Concatenated prompt and sample token ids.
    ids: torch.Tensor

    # The mask for the target ids: 0 for prompt tokens, 1 for sample tokens.
    mask: torch.Tensor

    # The weight to apply to the loss of each target token. It's normally
    # computed from the advantage and the reward.
    weights: torch.Tensor

    # The log probabilities of the target tokens. For the prompt part it's set
    # to 0; for the generation part it's computed by the Generator/Sampler.
    # Optional: defaults to None.
    log_probs: Optional[torch.Tensor] = None

    # TODO: add more fields as required
    state: str = ""


def from_scored_completion(scored_completion: ScoredCompletion) -> Episode:
    """Converts a ScoredCompletion to an Episode.

    The prompt and generated token ids are concatenated into ``ids``. ``mask``
    is 0 over the prompt positions and 1 over the generated positions, and
    ``weights`` is that mask scaled by the completion's score (used here as
    the advantage).

    Args:
        scored_completion: The scored completion to convert.

    Returns:
        An Episode whose ``log_probs`` are zero-padded over the prompt
        positions, or ``None`` when the completion carries no log
        probabilities.
    """
    completion = scored_completion.completion
    prompt_ids = completion.prompt_ids
    token_ids = completion.token_ids
    sample_log_probs = completion.log_probs

    ids = torch.cat([prompt_ids, token_ids])
    mask = torch.cat(
        [
            # Allocate the padding on the same device as the ids so the
            # concatenation does not fail for non-default-device inputs.
            torch.zeros(
                prompt_ids.shape, dtype=torch.float32, device=prompt_ids.device
            ),
            torch.ones_like(token_ids, dtype=torch.float32),
        ]
    )
    advantage = scored_completion.score
    weights = mask * advantage

    # Completion.log_probs is Optional; torch.cat cannot take None, so only
    # build the padded tensor when log probabilities were actually provided.
    log_probs = None
    if sample_log_probs is not None:
        log_probs = torch.cat(
            [
                torch.zeros(
                    prompt_ids.shape,
                    dtype=torch.float32,
                    device=sample_log_probs.device,
                ),
                # TODO: this only works if sample.log_probs is 1
                # (presumably 1-D, matching token_ids -- confirm).
                sample_log_probs,
            ]
        )
    return Episode(ids=ids, mask=mask, weights=weights, log_probs=log_probs)


def from_scored_completions(
    scored_completions: Sequence[ScoredCompletion],
) -> Sequence[Episode]:
    """Build one Episode per ScoredCompletion, preserving input order."""
    return list(map(from_scored_completion, scored_completions))
62 changes: 62 additions & 0 deletions src/forge/data_models/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any


class Role(Enum):
    """The speaker role attached to a Message."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    NONE = "none"


@dataclass
class Message:
    """A single message in a conversation."""

    # The text pieces that make up this message.
    chunks: Sequence[str]
    # The role of the speaker that produced this message.
    role: Role


@dataclass
class Prompt:
    """A multi-turn prompt (conversation history)."""

    # Multi-turn messages, each turn is a message.
    messages: Sequence[Message]

    @classmethod
    def from_prompt(
        cls, prompt: str, system_instruction: str | None = None
    ) -> "Prompt":
        """Build a prompt from a single user string.

        Args:
            prompt: The user text; becomes a USER message.
            system_instruction: Optional system text; when given, it becomes
                a SYSTEM message placed before the user message.

        Returns:
            A new instance of ``cls`` holding the resulting messages.
        """
        messages = prompt_to_messages(prompt, system_instruction)
        # Construct via ``cls`` so subclasses get an instance of their own
        # type from this alternate constructor.
        return cls(
            messages=messages,
        )


def prompt_to_messages(
    prompt: str, system_instruction: str | None = None
) -> Sequence[Message]:
    """Wrap a raw prompt string as a list of conversation messages.

    The result always ends with a USER message holding ``prompt``. When
    ``system_instruction`` is not None, a SYSTEM message holding it comes
    first.
    """
    user_turn = Message(chunks=[prompt], role=Role.USER)
    if system_instruction is None:
        return [user_turn]
    system_turn = Message(chunks=[system_instruction], role=Role.SYSTEM)
    return [system_turn, user_turn]


def to_prompt(prompt: str, system_instruction: str | None = None) -> Prompt:
    """Build a Prompt from a prompt string and an optional system instruction."""
    conversation = prompt_to_messages(prompt, system_instruction)
    return Prompt(messages=conversation)
19 changes: 19 additions & 0 deletions src/forge/data_models/scored_completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

from forge.data_models.completion import Completion


@dataclass
class ScoredCompletion:
    """A completion with an associated score (from a reward model or human)."""

    # The completion being scored.
    completion: Completion
    # The score assigned to the completion; akin to a reward.
    score: float  # akin to reward

    # TODO: add more fields as needed.
Loading