Publishing weights into torchstore from RLTrainer and getting them from the policy engine. #138
Conversation
@pradeepfn please clarify, does put hang or throw?
No torchstore issues.
This looks really great! Glad we can finally have this started. Mostly left questions and a few nits.
src/forge/actors/trainer.py
Why are we starting at 1? Also, we probably want a todo to update this from the checkpoint
Because the policy engine starts at 1. Let's keep this fragile contract as it is; the true version has to come from a config or an external book-keeping entity.
I don't think we can change this without risking breaking checkpoint expectations from titan side. I'd rather just use a separate variable in the trainer for "checkpoint name" (can be a property that's just current_step + 1 for now). This could also be passed in from the controller which would be better.
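A minimal sketch of that separate variable, assuming the trainer tracks a current_step attribute (names are hypothetical, not the actual trainer code):

```python
# Hypothetical sketch: leave Titan's checkpoint numbering untouched and derive a
# separate name used only when publishing weights to torchstore.
class RLTrainer:
    def __init__(self) -> None:
        self.current_step = 0  # assumed attribute; the real trainer may track this differently

    @property
    def checkpoint_name(self) -> int:
        # The policy engine starts counting versions at 1, hence the +1 offset.
        return self.current_step + 1
```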
src/forge/actors/trainer.py
Where is this coming from? When you call this, does it create the sd right then or did it have to be saved in the train step earlier? Does it return the sd on GPU or CPU? Also does it handle blocking the trainer from updating the weights while it's getting them?
I'm accessing the module state-dict prepped by torchtitan as part of checkpoint save.
- This is an in-memory state-dict (Tensor/DTensor).
- It returns tensors with their original storage, i.e. GPU/UVM-backed tensors.
Also does it handle blocking the trainer from updating the weights while it's getting them?
Hmm, it does not block the trainer. However, ForgeEngine drives the trainer via train_step, so there are no race conditions with the current code.
There are improvements to be made to this code. In the ideal case:
- The state-dict gets prepped for both weight-exchange and checkpoint-save purposes.
- Once the initial prep is done, we can cache the prepped state-dict across later training steps for efficiency (if there is an opportunity).
- We move all the model weights and optimizer state to torchstore.
- The policy engine (only) looks up the model weights from torchstore.
- Async checkpointing looks up the model weights and optimizer states for upload to remote persistent storage.
We don't have all the pieces right now, but tapping into the checkpoint state-dict is the right first step.
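A rough sketch of what that first step could look like, assuming a simple async key/value put() on the store and that sharded tensors can be materialized via DTensor.full_tensor(); none of these names are taken from the actual Forge/torchstore code:

```python
import torch
from torch.distributed.tensor import DTensor


async def push_weights(store, state_dict: dict[str, torch.Tensor], version: int) -> None:
    """Publish the checkpointer-prepped state dict under a versioned key prefix."""
    for name, tensor in state_dict.items():
        # Tensors come straight from the checkpoint state, so they may be
        # GPU/UVM-backed DTensors; materialize and move to host before the put.
        if isinstance(tensor, DTensor):
            tensor = tensor.full_tensor()
        await store.put(f"policy_v{version}/{name}", tensor.detach().cpu())
```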
I guess you're right that it should be mostly safe since we control the update from the controller. But since they're async calls they could be overlapped so we'll have to be careful for now.
awesome!
src/forge/actors/trainer.py
When would "model"
not be in the self.engine.checkpointer.states
? In other words, can we update the assertion error to be more informative? Does this fail if the user didn't initialize the trainer properly/what do they need to do to make it work??
Essentially, this only happens if the checkpoint_manager of torchtitan is not initialized prior to calling the push_weights routine. I can update the error message with that in a follow-up PR.
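For example, something along these lines (the attribute path is the one referenced in the review; the message itself is illustrative):

```python
# Sketch of a more descriptive failure message for the existing assertion.
assert "model" in self.engine.checkpointer.states, (
    "Titan checkpointer has no 'model' entry; initialize the trainer's "
    "checkpoint manager during engine setup before calling push_weights()."
)
```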
Is this new rtol/atol expected? Question for @pbontrager
This is never great, but given the bf16/fp16 comments I could see that. This is also an allclose and not a comparison of the mean so we should be safe here. If we can load the hf side with bf16 instead of fp16 we might be able to regain the tighter tolerance.
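For reference, a sketch of the kind of check being discussed; the tensor names and tolerance values are placeholders, not the ones in the test:

```python
import torch


def check_outputs_match(forge_out: torch.Tensor, hf_out: torch.Tensor) -> None:
    # bf16 keeps only ~8 mantissa bits, so an element-wise comparison against an
    # fp16/fp32 reference generally needs looser tolerances than the defaults.
    torch.testing.assert_close(forge_out.float(), hf_out.float(), rtol=1e-2, atol=1e-2)
```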
maybe we should hardcode the trainer config here rather than load from apps/rl/llama3_8b.yaml
Approving this with a few nits
FYI: @joecummings is working on merging this diff because of recent API changes that he pushed. @joecummings can add more insights. Thanks.
src/forge/actors/trainer.py
Question: are there benefits to doing this at the state-dict level rather than at the key level, where we could parallelize the individual put operations per key?
Right now, the benefit is simplicity.
Eventually I imagine we will want to do this on a per-slice level.
cc @LucasLLC
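A sketch of what a per-key variant could look like, again assuming a hypothetical async put() on the store:

```python
import asyncio

import torch


async def push_weights_per_key(store, state_dict: dict[str, torch.Tensor], version: int) -> None:
    # One put per key, issued concurrently, instead of a single put of the whole
    # state dict; a per-slice scheme would break each tensor down further.
    await asyncio.gather(
        *(store.put(f"policy_v{version}/{name}", tensor) for name, tensor in state_dict.items())
    )
```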
Stamping. Only note is to remember to update the Policy or the controller to remove old policies from the store once all generators are updated.
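A hypothetical sketch of that cleanup, assuming the store exposes a delete() by key prefix; the real torchstore API may differ:

```python
async def evict_old_policies(store, current_version: int, keep_last: int = 1) -> None:
    # Drop policy versions that every generator has already moved past, keeping
    # the most recent `keep_last` versions (including the current one).
    for version in range(1, current_version - keep_last + 1):
        await store.delete(f"policy_v{version}")
```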
src/forge/actors/trainer.py
Curious: what is the reasoning for doing this at the learner and not at the generator? The trainer just pushes its weights, and the generator can modify the sd based on its implementation (vLLM, sglang, etc.).
This is a temporary thing I did. It will be moved to generator sd loading, and it will be moved out of the trainer/generator critical path based on efficiency numbers.
To add to Pradeep's answer, vLLM already handles its own hf -> vllm mapping. The only reason we've recreated it is so we can add a sharded loading solution, which we want to eventually upstream. It will be on the generator side like he said.
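For context, a rough illustration of the generator-side direction being described; the remapping table and helper are hypothetical, and only load_weights mirrors vLLM's usual weight-loading entry point:

```python
import torch

# Hypothetical remapping: identity for most keys, with vLLM's own loader handling
# fused projections (qkv_proj, gate_up_proj) internally.
HF_TO_VLLM: dict[str, str] = {
    "lm_head.weight": "lm_head.weight",
}


def load_published_weights(vllm_model, hf_state_dict: dict[str, torch.Tensor]) -> None:
    # Stream (name, tensor) pairs into the engine's loader after renaming.
    weights = ((HF_TO_VLLM.get(name, name), tensor) for name, tensor in hf_state_dict.items())
    vllm_model.load_weights(weights)
```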
Working integration test with Llama8B.
State-dict conversion to HF format happens during weight save/publish into TS using sd_adaptor (mostly). Custom param concatenations that are not supported by the current sd_adaptor impl of the model were implemented directly within the weight-publish function.
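As an illustration of the kind of custom concatenation mentioned above (the key names and layout are hypothetical, not the actual model's):

```python
import torch


def fuse_gate_up(sd: dict[str, torch.Tensor], layer: int) -> torch.Tensor:
    # Example of concatenating two separately stored projections into the single
    # fused parameter expected by the target format.
    prefix = f"model.layers.{layer}.mlp"
    return torch.cat([sd[f"{prefix}.gate_proj.weight"], sd[f"{prefix}.up_proj.weight"]], dim=0)
```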