Adds TitanRefModel in place of HF based Reference Model #94
Conversation
src/forge/actors/reference_actor.py
Outdated
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))

@endpoint
async def forward(self, token_ids: list[int]) -> torch.Tensor:
We can keep things bs=1 for now, but we're going to need to add our own batching solution where we take n requests off the queue at a time and process them as a batch.
The batching itself is an easy add using the queue in ReferenceActor (the one in the PR that's unused), or we can surface it as a common API in Replica (just processes_single_request atm).
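For illustration, here is a minimal sketch of that kind of batching loop, assuming an asyncio.Queue of (token_ids, future) pairs on the actor; the names _batch_worker, request_queue, and forward_batch are hypothetical, not APIs from this PR:

```python
import asyncio

async def _batch_worker(self, max_batch_size: int = 8):
    while True:
        # Block for the first request, then opportunistically drain whatever else is waiting.
        token_ids, fut = await self.request_queue.get()
        batch = [(token_ids, fut)]
        while len(batch) < max_batch_size and not self.request_queue.empty():
            batch.append(self.request_queue.get_nowait())

        inputs = [ids for ids, _ in batch]
        outputs = await self.forward_batch(inputs)  # hypothetical batched forward
        for (_, fut), out in zip(batch, outputs):
            fut.set_result(out)
```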
if request_output.finished:
    _, fut = self.requests.pop(request_output.request_id)
-   fut.set_result(request_output.outputs)
+   fut.set_result(request_output)
Adopted from #97
Why this instead of raw outputs?
Pragmatically: Less merge conflict with Philip's PR
I don't have a strong preference, but it does make the output self-contained, which is nice when we need to pass the results around.
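For context, a small consumer-side sketch (not code from this PR) of what "self-contained" buys: resolving the future with the full vLLM RequestOutput means the awaiting side gets the request id, prompt tokens, and completions together rather than just the raw outputs list.

```python
request_output = await fut  # resolved via fut.set_result(request_output)
prompt_ids = request_output.prompt_token_ids
for completion in request_output.outputs:  # vLLM CompletionOutput objects
    response_ids = completion.token_ids
    response_text = completion.text
```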
    from .replay_buffer import ReplayBuffer

    return ReplayBuffer
elif name == "TitanRefModel":
🙃
if request_output.finished:
    _, fut = self.requests.pop(request_output.request_id)
-   fut.set_result(request_output.outputs)
+   fut.set_result(request_output)
Why this instead of raw outputs?
Post-land review
@dataclass
class TitanRefModel(ForgeActor):
On naming: I would just call this ReferenceModel; we don't use Titan anywhere else and we haven't generalized it yet.
# Refer to titan JobConfig for enabling more ForgeEngine configuration
model: Model = field(default_factory=Model)
parallelism: Parallelism = field(default_factory=Parallelism)
I feel like we would need Checkpoint and Compile too. Or is checkpoint not needed for loading?
Not needed by default; Checkpoint is used when we want to write out or load from a checkpoint.
The former doesn't happen with the reference model; the latter does unlock custom ref models (or using checkpoints as ref models).
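If checkpoint loading is wanted later, one hypothetical extension (not part of this PR, and assuming torchtitan exposes a Checkpoint config dataclass alongside Model and Parallelism) would be another field on the actor that gets forwarded into the engine config:

```python
from dataclasses import dataclass, field

@dataclass
class TitanRefModel(ForgeActor):
    model: Model = field(default_factory=Model)
    parallelism: Parallelism = field(default_factory=Parallelism)
    # Load-only in practice: the reference model reads a checkpoint but never writes one.
    checkpoint: Checkpoint = field(default_factory=Checkpoint)
```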
if parallel_dims.pp_enabled:
    raise NotImplementedError("PP not implemented yet")
else:
    # (jackkhuu) Not sure if either context are needed for inference here
Can you separate out the request call from forward? Also look at the updated RefActor in main to make sure they're in sync
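One possible shape for that separation, as a rough sketch only (process_request and the direct self.model call are placeholders, not this PR's API): keep the @endpoint thin and request-oriented, and keep forward purely about running the model.

```python
@endpoint
async def process_request(self, token_ids: list[int]) -> torch.Tensor:
    # Request-side concerns (queueing, batching, futures) stay here.
    input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
    return await self.forward(input_ids)

async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
    # Model-side concerns only: run the titan-backed model and return logits.
    with torch.inference_mode():
        return self.model(input_ids)  # placeholder for however the engine exposes the module
```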
return sequence_log_probs


"""
Can we comment out everything below this?
)  # Remove batch dimension for single response


def compute_sequence_logprobs(
You should update this to match the one from main. This should probably go in the trainer.py file
This was copy-pasted from main (at the time, Joe's). compute_logprobs in the file is the implementation from #97. Follow-up changes are in #118.
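For reference, a minimal sketch of the standard per-token log-prob computation this family of functions implements; the actual compute_logprobs in main/#97 may differ in slicing and temperature handling:

```python
import torch
import torch.nn.functional as F

def compute_logprobs(
    logits: torch.Tensor,        # [batch, seq_len, vocab]
    response_ids: torch.Tensor,  # [batch, response_len]
    temperature: float = 1.0,
) -> torch.Tensor:
    # The logit at position t predicts token t+1, so align by taking the
    # response_len positions just before each response token.
    response_len = response_ids.size(1)
    logits = logits[:, -response_len - 1 : -1, :] / temperature
    logprobs = F.log_softmax(logits, dim=-1)
    # Per-token log-probabilities of the sampled response tokens.
    return torch.gather(logprobs, 2, response_ids.unsqueeze(-1)).squeeze(-1)
```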
Replaces the existing HF-based RefModel with a torchtitan-backed TitanRefModel.

Note: This PR persists the old implementation using AutoModelForCausalLM until GRPO is ready for full migration to torchtitan.

python apps/grpo/main.py

Example Output