
Conversation

@casteryh casteryh commented Sep 19, 2025

Add a flag use_vllm_builtin_load to the policy and trainer.
When set to true, the vllm builtin load_weights() method is used to exchange weights.
In particular, this works correctly with TP.
Tested with the sumdigits example at tp_size = 2.
https://meta.wandb.io/torchforge/sumdigits-training/runs/f9jb060e/panel/2pa6e8ptg?nw=nwuseryuxuanh
avg_reward >0.9 in 3k steps
[W&B chart: avg_reward over training steps, captured 9/23/2025 10:22 PM]
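For orientation, a minimal sketch of what the flag gates (hedged: _update_vllm_builtin is a hypothetical name; _update_hf_nonsharded appears in the diff below):

    # Hedged sketch, not this PR's exact control flow: the flag selects
    # between the new vllm builtin path and the older HF state-dict path.
    @endpoint
    async def update_weights(self, policy_version: int):
        if self.use_vllm_builtin_load:
            await self._update_vllm_builtin(policy_version)  # hypothetical name
        else:
            await self._update_hf_nonsharded(policy_version)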

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Sep 19, 2025
@casteryh changed the title from "[WIP][not for land] use vllm load_weights() in GRPO" to "[WIP] use vllm builtin load_weights" Sep 23, 2025
@casteryh changed the title from "[WIP] use vllm builtin load_weights" to "[WIP] vllm builtin load_weights()" Sep 23, 2025
@casteryh changed the title from "[WIP] vllm builtin load_weights()" to "Weight loading working correctly with tp: use vllm builtin load_weights()" Sep 24, 2025
@casteryh marked this pull request as ready for review September 24, 2025 05:49

logger.debug(f"Starting weight update on {self.__class__.__name__}")
await self.policy_worker.update.call(version=policy_version)
if self.use_vllm_builtin_load:
Contributor

Eventually, this will be the default right?

Contributor Author

seems like the plan

logger.debug(f"Loaded state dict from {key} in {time.time() - start} seconds")

@endpoint
async def _update_hf_nonsharded(self, version: int):
Contributor

Why is this specific to hf??

Contributor Author (@casteryh, Sep 25, 2025)

This just means we are pushing/reading the state dict in the Hugging Face format, not titan, not vllm.

Contributor

Perhaps call it update_DEPRECATED and update. I'd like to keep the DEPRECATED one just for A/B testing and delete it before the PTC.

Contributor

Could you explain the choice between using get_state_dict/put_state_dict and the get/put API?

Contributor

I am also confused at the load_weights API -- will it handle sharding itself? If so, should we call this function on the driver worker (0) once?

Contributor Author

> I am also confused at the load_weights API -- will it handle sharding itself? If so, should we call this function on the driver worker (0) once?

Every worker (rank) has to call load_weights(). When it does, each worker figures out its own rank and reads only its own shard.

Moreover, load_weights() supports incremental updates: if there is only one tensor in the passed-in weights, it will update that part specifically (it even handles concatenated weights). For example, if you pass in a single kv pair "model.layers.0.q_proj.xxx" -> full_tensor (I am making up the fqn, but you get the point), it will actually update the q_proj part of the fused qkv_proj weight.
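A minimal sketch of that per-tensor pattern, assuming model is the vllm nn.Module already constructed on this rank (the fqn is made up, as in the comment above):

    import torch

    # Each TP rank runs this itself; vllm maps the HF-format name onto its
    # fused qkv_proj parameter and copies only this rank's shard of it.
    name = "model.layers.0.self_attn.q_proj.weight"  # illustrative fqn
    full_tensor = torch.randn(4096, 4096)            # full, unsharded weight

    # A single (name, tensor) pair updates just the q_proj slice of the
    # fused weight; load_weights returns the names of the params it touched.
    loaded = model.load_weights([(name, full_tensor)])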

@vidhyav (Contributor) commented Sep 24, 2025

Is the source of the problem also that we have deconstructed the vllm engine and are using it piecemeal instead of using it as a whole?

Comment on lines 410 to 414
mlogger.log(
    "push_weights_time/training_step",
    time.perf_counter() - start_time,
    training_step,
)
Contributor

This is great! I think let's split this diff into:

  1. add a weight sync counter
  2. add options to do per-tensor weight sync

Contributor Author

added this for debugging myself, will do!


@endpoint
async def update_weights(self):
async def update_weights(self, policy_version: int):
Contributor

You probably want to rebase on this #181
I'll address the comments and merge the PR ASAP

else:
await self._push_weights_sharded(policy_version)

async def _push_weights_sharded(self, policy_version: int) -> None:
Contributor

I have some confusion: in my fix, I feel that this path is not actually sharded. The difference is just whether we process the state dict or not. Basically, we just need to skip the _qwen3_hf_to_vllm path and keep the rest as is?

Contributor Author

Yep, I am bad at naming things; it actually has nothing to do with sharding at this point.
Maybe we just call this push_weights_vllm vs push_weights_hf (or push_weights_DEPRECATED vs push_weights, if you will).
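In code terms, the difference is whether the HF-to-vllm key transform runs before pushing; a hedged sketch (prepare_push is a made-up helper name; _qwen3_hf_to_vllm is from this thread):

    # Hedged sketch: the vllm path pushes HF-format keys untouched and lets
    # the policy-side load_weights() do the fusing/renaming; the old path
    # pre-converts the keys into vllm's fused layout.
    def prepare_push(state_dict: dict, use_vllm_builtin_load: bool) -> dict:
        if use_vllm_builtin_load:
            return state_dict
        return _qwen3_hf_to_vllm(state_dict)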

logger.debug(f"Loaded state dict from {key} in {time.time() - start} seconds")

@endpoint
async def _update_hf_nonsharded(self, version: int):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps call it update_DEPRECATED and update. I'd like to keep the DEPRECATED one just for A/B testing and delete it before the PTC.

logger.debug(f"Loaded state dict from {key} in {time.time() - start} seconds")

@endpoint
async def _update_hf_nonsharded(self, version: int):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain the choice between using get_state_dict/get_state_dict and the get/put API?

logger.debug(f"Loaded state dict from {key} in {time.time() - start} seconds")

@endpoint
async def _update_hf_nonsharded(self, version: int):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also confused at the load_weights API -- will it handle sharding itself? If so, should we call this function on the driver worker (0) once?

@casteryh (Contributor Author) commented Sep 26, 2025

@JenniferWang

> Could you explain the choice between using get_state_dict/put_state_dict and the get/put API?

get_state_dict fetches all the weights at once, which we probably don't want.
I don't have a problem with put_state_dict per se, but if I am not using get_state_dict, I am not sure how to properly read something written with put_state_dict.
I'd rather have a flat kv structure where I can control everything myself.
Also, the torchstore API explicitly says get_state_dict and put_state_dict are for testing purposes (last time I checked the codebase); I have no idea why they're used here in the first place.
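A sketch of the flat kv pattern being argued for, assuming torchstore's async get/put pair discussed above and the get_param_key helper that appears later in this diff:

    import torchstore as ts

    async def push_flat(version: int, hf_state_dict: dict) -> None:
        # Trainer side: one key per (version, parameter), no monolithic blob.
        for name, tensor in hf_state_dict.items():
            await ts.put(get_param_key(version, name), tensor)

    async def read_one(version: int, name: str):
        # Policy side: read back exactly one tensor at a time.
        return await ts.get(get_param_key(version, name))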

@casteryh (Contributor Author) commented Sep 26, 2025

@vidhyav

> Is the source of the problem also that we have deconstructed the vllm engine and are using it piecemeal instead of using it as a whole?

What we are doing here is using the vllm workers but letting monarch handle the collective operations instead of vllm's default method. That said, we can still use the load_weights() method on the vllm workers.

  • Ordinary vllm weight loading: vllm uses its own collectives and calls load_weights() on each worker.
  • This PR / our approach: monarch calls load_weights() on each worker.
  • Previous buggy approach: each vllm worker (wrapped in a monarch actor) tried to circumvent the load_weights() method altogether and directly operate on the state_dict() of the underlying nn.Module.
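Under the actor conventions visible elsewhere in this PR, the middle approach looks roughly like this; a sketch that mirrors the per-parameter loop shown further down the diff (the attribute chain to the underlying module is an assumption):

    @endpoint
    async def update(self, version: int):
        # Monarch broadcasts this endpoint to every vllm worker actor; each
        # rank calls the engine's own load_weights() locally instead of
        # writing into the nn.Module's state_dict directly.
        model = self.worker.model_runner.model  # assumed path to the module
        for name in hf_names:
            param = await ts.get(get_param_key(version, name))
            model.load_weights([(name, param)])
            del param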


@endpoint
async def push_weights(self, version: int):
async def push_weights_DEPRECATED(self, policy_version: int): # noqa: N802
Member

If we're confident in this fix, we should just fully delete the old way. My thinking is as follows:

  1. Gets everyone immediately testing the new version for any bugs 👍
  2. Reduces the chance an end user sees and uses this endpoint 👍
  3. Less code to parse through right now 👍

Contributor Author (@casteryh, Sep 26, 2025)

> Gets everyone immediately testing the new version for any bugs 👍

Yes, the new one is the default now! I think the plan is to keep the DEPRECATED method just for benchmarking purposes for now? @JenniferWang

# Instead, we just call load_weights with one parameter at a time.
for name in hf_names:
    param = await ts.get(get_param_key(version, name))
    loaded = model.load_weights([(name, param)])
Member

This is super cool! I didn't realize you could do it per-param :)

Contributor Author

yeah it's surprisingly good

    loaded = model.load_weights([(name, param)])
    del param
    loaded_weights.update(loaded)
self.logger.info(f"Updated {len(loaded_weights)} parameters")
Member

nit: I prefer the old debug message that prints out the time it took to update the weights

Contributor Author

will add it back
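For reference, a minimal sketch of the timing message being restored, in the style of the debug lines quoted elsewhere in this thread (names assumed from the surrounding diff):

    import time

    start = time.perf_counter()
    loaded = model.load_weights([(name, param)])
    logger.debug(
        f"Loaded {len(loaded)} weights in {time.perf_counter() - start:.3f} seconds"
    )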

@joecummings (Member) left a comment

This seems reasonable!

@parse
def _main(cfg):
    asyncio.run(main(cfg))
    with TemporaryDirectory(prefix="forge_run_", dir="/dev/shm") as dcp_path:
Member

?

Contributor Author

yeah this won't work lemme revert

@casteryh merged commit 5f19d68 into meta-pytorch:main Sep 27, 2025
5 checks passed
@casteryh deleted the weight-loading branch September 27, 2025 05:23