Weight loading working correctly with tp: use vllm builtin load_weights() #184
Conversation
logger.debug(f"Starting weight update on {self.__class__.__name__}") | ||
await self.policy_worker.update.call(version=policy_version) | ||
if self.use_vllm_builtin_load: |
Eventually, this will be the default right?
seems like the plan
src/forge/actors/policy.py
Outdated
logger.debug(f"Loaded state dict from {key} in {time.time() - start} seconds") | ||
|
||
@endpoint | ||
async def _update_hf_nonsharded(self, version: int): |
Why is this specific to hf??
This just means we are pushing/reading the state dict in the Hugging Face format: not titan, not vllm.
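For illustration, the same attention weight goes by different names under the two conventions (the exact fqns below are assumptions, following the q_proj/qkv_proj example later in this thread):

```python
# Illustrative fqns only; both names are assumptions.
hf_name = "model.layers.0.self_attn.q_proj.weight"      # HF format: separate q/k/v projections
vllm_name = "model.layers.0.self_attn.qkv_proj.weight"  # vllm format: fused qkv projection
```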
Perhaps call it update_DEPRECATED and update. I'd like to keep the DEPRECATED one just for A/B testing and delete it before the PTC.
Could you explain the choice between using get_state_dict/get_state_dict and the get/put API?
I am also confused by the load_weights API -- will it handle sharding itself? If so, should we call this function on the driver worker (0) only once?
> I am also confused by the load_weights API -- will it handle sharding itself? If so, should we call this function on the driver worker (0) only once?
Every worker (rank) has to call load_weights. When load_weights() is called, each worker figures out its own rank and reads only its own shard.
Moreover, load_weights() supports incremental updating, i.e., if there is only one tensor in the passed-in weights, it will update just that part (it even handles concatenated weights). For example, if you pass in a single kv pair "model.layers.0.q_proj.xxx" -> full_tensor (I am making up the fqn, but you get the point), it will actually update the q_proj part of the fused qkv_proj weight.
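A minimal sketch of that pattern, running on every rank (the helper and variable names here are assumptions; load_weights returning the set of updated parameter names matches the snippet later in this PR):

```python
# Runs on every TP rank; each rank slices out its own shard internally.
loaded_weights = set()
for name in hf_param_names:                # HF-format fqns (assumed to be known)
    full_tensor = fetch_full_tensor(name)  # hypothetical helper: full, unsharded tensor
    # Passing a single (name, tensor) pair updates just that parameter,
    # even when it maps into a fused weight such as qkv_proj.
    loaded = model.load_weights([(name, full_tensor)])
    loaded_weights.update(loaded)
    del full_tensor                        # free the full tensor eagerly
```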
Is the source of the problem also that we have deconstructed the vllm engine and are using it piecemeal instead of using it as a whole?
apps/grpo/main.py
Outdated
```python
mlogger.log(
    "push_weights_time/training_step",
    time.perf_counter() - start_time,
    training_step,
)
```
This is great! I think let's split this diff into:
- add a weight sync counter
- add options to do per-tensor weight sync
added this for debugging myself, will do!
```diff
 @endpoint
-async def update_weights(self):
+async def update_weights(self, policy_version: int):
```
You probably want to rebase on #181; I'll address the comments and merge that PR ASAP.
src/forge/actors/trainer.py
Outdated
```python
else:
    await self._push_weights_sharded(policy_version)

async def _push_weights_sharded(self, policy_version: int) -> None:
```
I have some confusion: in my fix, I feel that this is not actually sharded. The difference is just whether we process the state dict or not. Basically, we only need to skip the _qwen3_hf_to_vllm path and keep the rest as is?
Yep, I am bad at naming things; it actually has nothing to do with sharding at this point. Maybe we just call this push_weights_vllm vs push_weights_hf (or push_weights_DEPRECATED vs push_weights, if you will).
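A rough sketch of the two paths under those names (only _qwen3_hf_to_vllm comes from the diff; put_state_dict and the surrounding signatures are assumptions):

```python
async def push_weights_hf(self, policy_version: int) -> None:
    # Old path: rewrite the HF-format state dict into vllm's (fused) naming first.
    sd = _qwen3_hf_to_vllm(self.model.state_dict())
    await put_state_dict(sd, policy_version)  # hypothetical transport helper

async def push_weights_vllm(self, policy_version: int) -> None:
    # New path: push HF names as-is; vllm's load_weights() does the name
    # mapping (including fused params) on the policy side.
    await put_state_dict(self.model.state_dict(), policy_version)
```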
> I am also confused by the load_weights API -- will it handle sharding itself? If so, should we call this function on the driver worker (0) only once?
What we are doing here is using vllm workers and letting monarch handle the collective operations instead of vllm's default method. That said, we can still use
apps/toy_rl/sumdigits.py
Outdated
```diff
 @endpoint
-async def push_weights(self, version: int):
+async def push_weights_DEPRECATED(self, policy_version: int):  # noqa: N802
```
If we're confident in this fix, we should just fully delete the old way. My thinking is as follows:
- Gets everyone immediately testing the new version for any bugs 👍
- Reduces the chance an end user sees and uses this endpoint 👍
- Less code to parse through right now 👍
> Gets everyone immediately testing the new version for any bugs 👍
Yes, the new one is the default now! I think the plan is to keep the DEPRECATED method just for benchmarking purposes for now? @JenniferWang
```python
# Instead, we just call load_weights with one parameter at a time.
for name in hf_names:
    param = await ts.get(get_param_key(version, name))
    loaded = model.load_weights([(name, param)])
```
This is super cool! I didn't realize you could do it per-param :)
yeah it's surprisingly good
src/forge/actors/policy.py
Outdated
```python
loaded = model.load_weights([(name, param)])
del param
loaded_weights.update(loaded)
self.logger.info(f"Updated {len(loaded_weights)} parameters")
```
nit: I prefer the old debug message that prints out the time it took to update the weights
will add it back
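A small sketch of what adding it back could look like, extending the loop shown above (the exact message wording is an assumption):

```python
start = time.perf_counter()
loaded_weights = set()
for name in hf_names:
    param = await ts.get(get_param_key(version, name))
    loaded = model.load_weights([(name, param)])
    del param
    loaded_weights.update(loaded)
self.logger.debug(
    f"Updated {len(loaded_weights)} parameters in {time.perf_counter() - start:.2f} seconds"
)
```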
This seems reasonable!
apps/grpo/main.py
Outdated
```diff
 @parse
 def _main(cfg):
-    asyncio.run(main(cfg))
+    with TemporaryDirectory(prefix="forge_run_", dir="/dev/shm") as dcp_path:
```
?
yeah, this won't work, lemme revert
Add a flag use_vllm_builtin_load to policy and trainer.

When set to true, it will use vllm's builtin load_weights() method to exchange weights. In particular, this works correctly with TP.
Tested with the sumdigits example at tp_size = 2:
https://meta.wandb.io/torchforge/sumdigits-training/runs/f9jb060e/panel/2pa6e8ptg?nw=nwuseryuxuanh
avg_reward > 0.9 in 3k steps
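For context, a sketch of how the flag might gate the policy-side update, based on the diff snippets above (_update_vllm_builtin is a hypothetical name for the new path):

```python
@endpoint
async def update(self, policy_version: int):
    logger.debug(f"Starting weight update on {self.__class__.__name__}")
    if self.use_vllm_builtin_load:
        await self._update_vllm_builtin(policy_version)           # new: vllm load_weights() path
    else:
        await self._update_hf_nonsharded(version=policy_version)  # legacy HF-format path
```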