[1/N] Core Data Models #157
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional

import torch
from forge.data_models.prompt import Prompt


@dataclass
class Completion:
    """A model-generated completion for a given prompt."""

    # The original prompt.
    prompt: Prompt

    # The decoded text returned by the model.
    text: str

    # The encoded text (token ids) that was fed into the model.
    prompt_ids: torch.Tensor

    # The encoded text (token ids) that was generated by the model.
    token_ids: torch.Tensor

    # The log probabilities of the target tokens.
    log_probs: Optional[torch.Tensor] = None
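
A minimal usage sketch (not part of this PR): constructing a Completion from already-tokenized data. The token ids and log probabilities below are made-up illustrative values.

import torch
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt

# Hypothetical token ids; in practice these come from the tokenizer and sampler.
completion = Completion(
    prompt=to_prompt("What is 2 + 2?"),
    text="4",
    prompt_ids=torch.tensor([101, 2054, 2003, 1016, 1009, 1016, 102]),
    token_ids=torch.tensor([1018]),
    log_probs=torch.tensor([-0.05]),
)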
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
import torch.distributed as dist


@dataclass
class DistributedMetric(ABC):
Reviewer comment: I like this concept, and it could be combined with our MetricLogger. But if we're going to handle this in a generic way, I don't think we want to deal with process groups; otherwise you have to set them up with a bunch of services that don't need them. Either we should use these abstractions and return metrics to a MetricService/controller to aggregate, or handle it without formal abstractions, in a service-specific way, as we do now.
    """Metrics that are calculated in a distributed fashion.

    Metrics computed on each rank are wrapped in a DistributedMetric
    according to how they are going to be aggregated. For example, the average
    log prob can be wrapped as `Fraction(Sum((logp * mask).sum()), Sum(mask.sum()))`,
    where `mask` indicates which tokens are valid.
    """

    # We need to pass a context argument for distributed setup in the future.
    @abstractmethod
    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        pass

    @abstractmethod
    def local(self) -> torch.Tensor:
        pass


@dataclass
class SumDistributedMetric(DistributedMetric):
    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor

    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        return _try_clone_and_reduce(self.tensor, op=dist.ReduceOp.SUM, group=group)

    def local(self) -> torch.Tensor:
        return self.tensor


@dataclass
class Fraction:
    numerator: DistributedMetric
    denominator: DistributedMetric

    def reduce(self, group: dist.ProcessGroup | None = None) -> torch.Tensor:
        return self.numerator.reduce(group) / self.denominator.reduce(group)

    def local(self) -> torch.Tensor:
        return self.numerator.local() / self.denominator.local()


def _try_clone_and_reduce(
    tensor: torch.Tensor, op: dist.ReduceOp, group: dist.ProcessGroup | None
) -> torch.Tensor:
    cloned = tensor.detach().clone()
    if dist.is_initialized():
        dist.all_reduce(cloned, op=op, group=group)
    return cloned
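
A quick sketch (not in the PR) of the average-log-prob example from the docstring above: per-rank sums are wrapped in SumDistributedMetric and combined with Fraction. The tensor values are illustrative; when torch.distributed is not initialized, reduce() simply returns the local values.

import torch
from forge.data_models.distributed_metric import Fraction, SumDistributedMetric

log_probs = torch.tensor([-0.2, -1.3, -0.7, 0.0])
mask = torch.tensor([1.0, 1.0, 1.0, 0.0])  # last token is padding

avg_log_prob = Fraction(
    numerator=SumDistributedMetric((log_probs * mask).sum()),
    denominator=SumDistributedMetric(mask.sum()),
)
print(avg_log_prob.local())   # per-rank average
print(avg_log_prob.reduce())  # all-reduced average once dist is initialized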
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence

import torch
from forge.data_models.scored_completion import ScoredCompletion


@dataclass
class Experience:
    """
    The Experience data class to be used by the trainer.

    Experiences are usually generated from a scored completion by running
    various post-processing steps.
    """

    # Concatenated prompt and sample token ids.
    ids: torch.Tensor

    # The mask for the target ids: 0 for prompt tokens, 1 for sample tokens.
    mask: torch.Tensor

    # The weight to apply to the loss of each target token. It's normally computed
    # from the advantage and the reward.
    weights: torch.Tensor

    # The log probabilities of the target tokens; for the prompt part it's set to 0,
    # for the generation part it's computed by the Generator/Sampler.
    log_probs: Optional[torch.Tensor] = None

    # TODO: add more fields as required
    state: str = ""


def from_scored_completion(scored_completion: ScoredCompletion) -> Experience:
    """Converts a ScoredCompletion to an Experience."""
    prompt_ids = scored_completion.completion.prompt_ids
    token_ids = scored_completion.completion.token_ids
    log_probs = scored_completion.completion.log_probs
    ids = torch.cat([prompt_ids, token_ids])
    mask = torch.cat(
        [
            torch.zeros(prompt_ids.shape, dtype=torch.float32),
            torch.ones_like(token_ids, dtype=torch.float32),
        ]
    )
    advantage = scored_completion.score
    weights = mask * advantage
    log_probs = torch.cat(
        [
            torch.zeros(prompt_ids.shape, dtype=torch.float32),
            # TODO: this only works if sample.log_probs is 1
            log_probs,
        ]
    )
    return Experience(ids=ids, mask=mask, weights=weights, log_probs=log_probs)


def from_scored_completions(
    scored_completions: Sequence[ScoredCompletion],
) -> Sequence[Experience]:
    """Converts a sequence of ScoredCompletions to a sequence of Experiences."""
    return [from_scored_completion(sc) for sc in scored_completions]
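
A minimal end-to-end sketch (not part of this PR) of turning a scored completion into a trainer-ready Experience. The token ids and score are made up; note that log_probs must be set on the Completion, since from_scored_completion concatenates it.

import torch
from forge.data_models.completion import Completion
from forge.data_models.experience import from_scored_completion
from forge.data_models.prompt import to_prompt
from forge.data_models.scored_completion import ScoredCompletion

completion = Completion(
    prompt=to_prompt("Say hi"),
    text="hi",
    prompt_ids=torch.tensor([1, 2, 3]),
    token_ids=torch.tensor([4]),
    log_probs=torch.tensor([-0.1]),
)
experience = from_scored_completion(ScoredCompletion(completion=completion, score=1.0))
assert experience.ids.tolist() == [1, 2, 3, 4]
# Prompt tokens get weight 0, sampled tokens get the score as weight.
assert experience.weights.tolist() == [0.0, 0.0, 0.0, 1.0]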
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from forge.data_models.distributed_metric import Fraction
from forge.data_models.minibatch import Minibatch


@dataclass
class LossInput:
Reviewer comment: This is roughly what we're doing now, except it's trainer_logits + target_minibatch, which is a subset of the minibatch you routed for the loss.
    minibatch: Minibatch
    trainer_logits: torch.Tensor


@dataclass
class LossOutput:
    loss: Fraction
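
A rough sketch (not in the PR) of how a loss could be reported as a Fraction, so the per-token average stays correct when reduced across ranks. The module path forge.data_models.loss is an assumption for where this file lands; the tensor values are illustrative.

import torch
from forge.data_models.distributed_metric import Fraction, SumDistributedMetric
from forge.data_models.loss import LossOutput  # assumed module path for this file

per_token_loss = torch.tensor([0.5, 0.2, 0.0])
target_mask = torch.tensor([1.0, 1.0, 0.0])

loss_output = LossOutput(
    loss=Fraction(
        numerator=SumDistributedMetric((per_token_loss * target_mask).sum()),
        denominator=SumDistributedMetric(target_mask.sum()),
    )
)
print(loss_output.loss.local())  # 0.35 on this rank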
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Sequence

import torch
from forge.data_models.experience import Experience


@dataclass
class Minibatch:
    """The minibatch that the trainer will receive."""

    # The input sequence token ids for the trainer forward pass.
    input_ids: torch.Tensor

    # The segment ids for the input sequence token ids. The same segment
    # id represents the same sequence.
    segment_ids: torch.Tensor

    # The targets required for loss computation, usually concatenated prompt and
    # sample token ids.
    target_ids: torch.Tensor

    # The mask for the target ids: 0 for prompt tokens, 1 for sample tokens.
    target_mask: torch.Tensor

    # The weight to apply to the loss of each target token. It's normally computed
    # from the advantage and the reward.
    target_weights: torch.Tensor

    # The log probabilities of the target tokens; for the prompt part it's set to 0,
    # for the generation part it's computed by the sampler.
    target_log_probs: torch.Tensor


def from_experiences(
    exps: Sequence[Experience], max_seq_len: int, pad_val: int = 0
) -> Minibatch:
    """
    Convert a list of experiences to a minibatch.
    """

    def pack_sequence(
        tensors: Sequence[torch.Tensor],
        pad_val: Any,
        dtype: torch.dtype,
        max_len: int,
    ) -> torch.Tensor:
        """Packs multiple tensors along the seq dim."""
        seq = torch.cat(tensors)
        pad_len = max_len - seq.size(0)
        if pad_len < 0:
            raise ValueError(
                f"Sequence length {seq.size(0)} exceeds the maximum length {max_len}"
            )
        return torch.nn.functional.pad(seq, (0, pad_len), value=pad_val)[None, ...].to(
            dtype
        )

    mini_batch = {}
    exp_list = defaultdict(list)
    for i, exp in enumerate(exps):
        input_ids = exp.ids[:-1]
        exp_list["input_ids"].append(input_ids)
        exp_list["target_ids"].append(exp.ids[1:])
        exp_list["segment_ids"].append(torch.ones_like(input_ids) * i)
        exp_list["target_mask"].append(exp.mask[1:])
        exp_list["target_weights"].append(exp.weights[1:])
        exp_list["target_log_probs"].append(exp.log_probs[1:])

    for k, v in exp_list.items():
        _dtype = torch.int64
        if k == "target_mask" or k == "target_weights" or k == "target_log_probs":
            _dtype = torch.float32

        mini_batch[k] = pack_sequence(v, pad_val, _dtype, max_seq_len)

    return Minibatch(**mini_batch)
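
A small usage sketch (not in the PR): packing two experiences into one fixed-length Minibatch. The make_exp helper and all values are hypothetical, purely for illustration.

import torch
from forge.data_models.experience import Experience
from forge.data_models.minibatch import from_experiences

def make_exp(ids, n_prompt):
    # Build a toy Experience: zero weight on prompt tokens, weight 1 on samples.
    ids = torch.tensor(ids)
    mask = torch.tensor([0.0] * n_prompt + [1.0] * (len(ids) - n_prompt))
    return Experience(ids=ids, mask=mask, weights=mask, log_probs=torch.zeros_like(mask))

batch = from_experiences([make_exp([1, 2, 3, 4], 2), make_exp([5, 6, 7], 1)], max_seq_len=8)
print(batch.input_ids)    # shape (1, 8): [1, 2, 3, 5, 6, 0, 0, 0]
print(batch.segment_ids)  # shape (1, 8): [0, 0, 0, 1, 1, 0, 0, 0]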
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any


class Role(Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    NONE = "none"


@dataclass
class Message:
    """A single message in a conversation."""

    chunks: Sequence[str]
    role: Role


@dataclass
class Prompt:
Reviewer comment: In general we should match OpenAI requests here as the type. This way we can use the same formatting for local and API judge calls.
    """A multi-turn prompt (conversation history)."""

    # Multi-turn messages; each turn is a message.
    messages: Sequence[Message]
    metadata: Any | None = None

    @classmethod
    def from_prompt(
        cls, prompt: str, system_instruction: str | None = None
    ) -> "Prompt":
        messages = prompt_to_messages(prompt, system_instruction)
        return Prompt(
            messages=messages,
        )


def prompt_to_messages(
    prompt: str, system_instruction: str | None = None
) -> Sequence[Message]:
    """Converts a prompt string to a sequence of messages."""
    messages = []
    if system_instruction is not None:
        messages.append(Message(chunks=[system_instruction], role=Role.SYSTEM))
    messages.append(
        Message(chunks=[prompt], role=Role.USER),
    )
    return messages


def to_prompt(
    prompt: str, metadata: Any | None = None, system_instruction: str | None = None
) -> Prompt:
    """Converts a prompt string and optional metadata to a Prompt."""
    return Prompt(
        messages=prompt_to_messages(prompt, system_instruction),
        metadata=metadata,
    )
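
A minimal usage sketch (not part of this PR) of the prompt helpers above; the prompt and system instruction strings are illustrative.

from forge.data_models.prompt import Prompt, Role, to_prompt

prompt = to_prompt(
    "Summarize this PR.",
    system_instruction="You are a concise assistant.",
)
assert prompt.messages[0].role is Role.SYSTEM
assert prompt.messages[1].chunks == ["Summarize this PR."]

# Equivalent construction via the classmethod (metadata is not set here).
same_prompt = Prompt.from_prompt("Summarize this PR.", "You are a concise assistant.")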
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

from forge.data_models.completion import Completion


@dataclass
class ScoredCompletion:
    """A completion with an associated score (from a reward model or human)."""

    completion: Completion
    score: float  # akin to reward

    # TODO: add more fields as needed.
Reviewer comment: Currently our Episode is a combination of Completion, ScoredCompletion, and Experience. I think keeping them flat makes customization and logging a bit easier, but either way can work.