[3/N] Core generator abstraction #159

Conversation
from forge.stores.in_memory_store import InMemoryStore
...
class HuggingFaceTrainer(Trainer):
At first glance, I am a bit skeptical that having multiple Trainers would work in practice. Is the idea that there would be a HF and a Titan trainer? Could you share a bit about why we need the Trainer abstraction?
Yes, the idea is to provide optionality for end users: if they want to compare two runs, set benchmarks, or just try both out.
The plan is to define the abstraction and let the policy and learner actors inject an implementation of the trainer. Which trainer/generator to use can essentially be provided from config.
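To make the injection idea concrete, here is a minimal sketch under stated assumptions: the class names, config keys, and factory are illustrative, not the actual forge API.

from abc import ABC, abstractmethod
from typing import Any


class Trainer(ABC):
    """Minimal interface the learner actor programs against (sketch)."""

    @abstractmethod
    def train_step(self, minibatch: Any) -> dict:
        """Run one optimization step and return metrics (e.g. loss)."""


class HuggingFaceTrainer(Trainer):
    def train_step(self, minibatch: Any) -> dict:
        raise NotImplementedError  # would delegate to transformers / accelerate


class TitanTrainer(Trainer):
    def train_step(self, minibatch: Any) -> dict:
        raise NotImplementedError  # would delegate to torchtitan


# Hypothetical factory: the concrete trainer is chosen from config;
# the learner actor only ever sees the Trainer interface.
_TRAINERS = {"hf": HuggingFaceTrainer, "titan": TitanTrainer}


def build_trainer(cfg: dict) -> Trainer:
    return _TRAINERS[cfg["trainer"]](**cfg.get("trainer_kwargs", {}))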
Wouldn't they have significantly different config signatures / distributed APIs / library requirements, etc? Do you see it working for a V0 of Forge?
Yes, they would certainly have different configs, and we need a config abstraction too.
"Do you see it working for a V0 of Forge?"
The current apps (GRPO) use HF, and I just saw another PR to make the GRPO app run on Titan, so we already have two variations of it.
@allenwang28 , do you know if we will support both Titan and HF after we work on Titan integration? At least in torchtune finding all the bugs/optimizations/features for a single framework was a lot of work.
Having a trainer abstraction feels like a huge benefit -- users may have lots of different reasons for trying different trainers, and the implementations should occupy similar shapes.
class Completion:
    """A model-generated completion for a given prompt."""
I think we already have a dataclass for this; not sure if it's "Episode". But it makes sense to have a Completion class.
...
@dataclass
class DistributedMetric(ABC):
I think we should use https://github.com/meta-pytorch/forge/tree/main/src/forge/data/dataset_metrics
But let's connect and see if you think it makes sense!
...
@dataclass
class Experience:
This is very similar to Completion; I wonder if we should just keep one. But it's also a bit similar to "Episode".
We need to see how this plays with data packing. It could be that the Packer takes list[Experience], but I don't think it makes sense for this class to have concat logic if we are doing packing.
This is modeled so that Completion is what we receive from the generator and Experience is what we feed into the trainer (in between, it goes through scoring, post-processing, etc.).
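A rough sketch of that flow, with purely illustrative field names (not the actual classes from this PR), to make the distinction concrete:

from dataclasses import dataclass
from typing import Callable


@dataclass
class Completion:
    """What comes out of the generator (sketch)."""
    prompt: str
    text: str
    token_ids: list[int]


@dataclass
class Experience:
    """What goes into the trainer after scoring / post-processing (sketch)."""
    token_ids: list[int]
    reward: float


def to_experience(completion: Completion, reward_fn: Callable[[str, str], float]) -> Experience:
    # Scoring and post-processing live between generation and training.
    return Experience(
        token_ids=completion.token_ids,
        reward=reward_fn(completion.prompt, completion.text),
    )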
@dataclass
class LossInput:
    minibatch: Minibatch
    trainer_logits: torch.Tensor
...
@dataclass
class LossOutput:
    loss: Fraction
It's interesting to have a LossInput and a LossOutput class. I am just afraid that these abstractions would add too much hierarchy, i.e. instead of loss(logits, targets, mask) we end up with loss(LossInput(MiniBatch(something_else))).
But LossOutput makes more sense to me because we may be outputting multiple numbers. We should double-check, though: if it's only a couple, having a dataclass might be overkill.
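To make the hierarchy concern concrete, the two call shapes under discussion would look roughly like this; the types and field names are illustrative, and LossInput refers to the dataclass in the diff above:

from dataclasses import dataclass

import torch


# Flat signature: easy to read and to test in isolation.
def loss_fn(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    ...


# Wrapped signature: every argument travels through nested dataclasses.
def loss_fn_wrapped(inputs: "LossInput") -> "LossOutput":
    ...


@dataclass
class LossOutput:
    loss: torch.Tensor
    # Extra diagnostics are the main argument for keeping an output dataclass.
    kl: torch.Tensor | None = None
    entropy: torch.Tensor | None = None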
    Convert a list of experiences to a minibatch.
    """
...
def pack_sequence(
I think we will have to revisit this after we do dataset packing. My idea is to have PackedDataset(iterator), where the iterator is a replay buffer: the replay buffer yields an Experience, and the PackedDataset yields a PackedMiniBatch (or something like that).
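Roughly what that could look like, assuming the Experience sketched earlier and a greedy token-budget packer; all names here are placeholders, not a committed design:

from typing import Iterable, Iterator


class PackedDataset:
    """Wraps an Experience iterator (e.g. a replay buffer) and yields packed minibatches."""

    def __init__(self, experiences: Iterable["Experience"], max_tokens: int) -> None:
        self.experiences = experiences
        self.max_tokens = max_tokens

    def __iter__(self) -> Iterator[list["Experience"]]:
        pack, used = [], 0
        for exp in self.experiences:
            n = len(exp.token_ids)
            if pack and used + n > self.max_tokens:
                yield pack  # one greedily packed minibatch
                pack, used = [], 0
            pack.append(exp)
            used += n
        if pack:
            yield pack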
class Message:
    """A single message in a conversation."""

    chunks: Sequence[str]
What's a chunk? Is it, for example, 2 consecutive messages from a user?
I think the initial idea was to mark each chunk as trainable (mask). I probably overcomplicated this; we can just use str for now.
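For reference, the per-chunk mask idea could have looked like the sketch below; this is only an illustration of the intent, not the actual proposal:

from dataclasses import dataclass
from typing import Sequence


@dataclass
class Message:
    role: str
    chunks: Sequence[str]
    # Parallel to `chunks`: whether each chunk contributes to the training loss.
    trainable: Sequence[bool] = ()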
At least in torchtune, we needed different types of messages/prompt handling. This makes sense to me for a V0, but I wonder if users will want to do more with this: https://github.com/pytorch/torchtune/blob/main/torchtune/data/_messages.py
Maybe @joecummings has some thoughts?
You are right. I thought I left a TODO that we need to support other formats like image (bytes), etc.
In general, I would endorse all or most of these interfaces as useful abstractions for forge.
In my personal experience it's a bit easier to reason about interfaces once they are integrated, so we can see e2e usage. I tend to lean toward only adding interfaces once they're used, but I'm curious whether you have an integration plan?
# TODO: This file should NOT be in the data_models folder/package
...
class Store(ABC):
This looks really close to the KVStore we're building in https://github.com/meta-pytorch/forge/pull/147
ack. Will rebase and remove this once the other PR gets merged.
    pass
...
class WeightsBuffer:
I think this is a reasonable abstraction in forge, although I do wonder if this is something torchstore should support to begin with?
E.g. "in-memory, RDMA, file system, torchstore, etc." -- these are all things torchstore should support OOB.
        return self.store.get(key)
...
class Trainer(ABC):
+1
from forge.data_models.api import Store
...
class InMemoryStore(Store):
This is the second PR where I've seen an in-memory store implementation, and I'm a bit concerned that this boilerplate keeps getting written as a result of torchstore!
Do you think it'd be a good idea to add a 'single process in memory' version of this to torchstore?
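For context, the boilerplate in question is essentially the following; the method names are illustrative and not necessarily the Store interface from this PR:

from typing import Any


class InMemoryStore:
    """Single-process, dict-backed store -- the shape that keeps getting re-implemented."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    def put(self, key: str, value: Any) -> None:
        self._data[key] = value

    def get(self, key: str) -> Any:
        return self._data[key]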
For this PR I reviewed Generator and Store; I think having light interfaces is good for these. I left a discussion point around Generator.update_weights, but otherwise this makes sense. Two points of discussion:
- How flexible are we with updating these as we need to?
- Should we merge these interfaces with the service interfaces (setup/launch, etc.) so users can see the full interface in one spot instead of having to look through multiple levels of interface inheritance?
""" | ||
return [] | ||
|
||
def update_weights(self, weights_handle: WeightsBuffer) -> None: |
Currently we have update_weights in Forge as well, but I don't think it's compatible with efficient weight updates when the policy is using a continuous batching system. Essentially, vLLM decodes a fixed number of tokens per step, then removes completed requests from the stack and adds new requests to fill the stack back up, so the batch is never done. This means you either have to exhaust the request queue before calling update, which isn't very efficient but gives you full control, or you have to update mid-decoding.
Ideally you want to either update as soon as you can (e.g. after the weights have been pre-fetched), or call generate with a policy request (policy.generate(prompt, policy_version)) and then update the policy as soon as all the old-policy requests are gone. You would also want to do this at different times for every replica in the service, as their queues would run out at different times.
All of this is to say, this logic would be much easier to internalize within the generator. If we put it in the controller, we'll have to have a separate loop that reads a bunch of state information about the generator and calls this update.
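One possible shape for "internalize within the generator": tag each request with the policy version it was issued under and swap weights once no in-flight request still belongs to an older policy. This is only a sketch; engine.submit, engine.load_weights, and the store are assumed objects, not actual vLLM or torchstore APIs.

class VersionedGenerator:
    """Sketch: the generator itself decides when to apply new weights."""

    def __init__(self, engine, store) -> None:
        self.engine = engine                 # e.g. a wrapper around a vLLM engine (assumed)
        self.store = store                   # where new weights are published (assumed)
        self.current_version = 0
        self.pending_version: int | None = None
        self.in_flight: dict[str, int] = {}  # request_id -> policy version it was issued under

    def generate(self, prompt, request_id: str):
        self.in_flight[request_id] = self.current_version
        return self.engine.submit(prompt, request_id)  # hypothetical call

    def on_request_finished(self, request_id: str) -> None:
        self.in_flight.pop(request_id, None)
        self._maybe_swap_weights()

    def update_weights(self, version: int) -> None:
        # Pre-fetch eagerly, apply lazily: record the target version and swap
        # once no in-flight request still belongs to an older policy.
        self.pending_version = version
        self._maybe_swap_weights()

    def _maybe_swap_weights(self) -> None:
        if self.pending_version is None:
            return
        if any(v < self.pending_version for v in self.in_flight.values()):
            return  # old-policy requests are still draining
        self.engine.load_weights(self.store.get(self.pending_version))  # hypothetical calls
        self.current_version = self.pending_version
        self.pending_version = None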
Nit: the input parameter would likely be the policy version or a key in the store.
from forge.data_models.prompt import Prompt
...
class VLLMGenerator(Generator):
Shouldn't this RFC just define the interfaces? So instead of VLLMGenerator it would just be a Generator abstract class.
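At the RFC level the artifact would then be just the abstract shape, something like the sketch below; the method signatures are borrowed from the diff above, so this is an illustration rather than the final interface:

from abc import ABC, abstractmethod
from typing import Sequence


class Generator(ABC):
    """Abstract generation interface; vLLM would be one concrete backend (sketch)."""

    @abstractmethod
    def generate(self, prompts: Sequence["Prompt"]) -> Sequence["Completion"]:
        """Produce completions for a batch of prompts."""

    @abstractmethod
    def update_weights(self, weights_handle: "WeightsBuffer") -> None:
        """Swap in new policy weights published by the trainer."""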
from forge.data_models.api import Store
...
class InMemoryStore(Store):
I'm not sure this one is worth adding an abstraction for, since torchstore is as fundamental to our stack as monarch. We'd essentially just be updating this interface to match torchstore's interface.
A split version of this RFC PR