Merged

Commits (35)
17e0c05  use vllm load_weights() in GRPO (casteryh, Sep 19, 2025)
2657324  Merge branch 'main' into weight-loading (casteryh, Sep 23, 2025)
6b67be9  modify trainer.py (casteryh, Sep 23, 2025)
a6c7aef  modify trainer.py (casteryh, Sep 23, 2025)
7ec461b  fix (casteryh, Sep 23, 2025)
762bf24  lint (casteryh, Sep 23, 2025)
87a7bc2  merge main (casteryh, Sep 23, 2025)
98e6dd3  stash (casteryh, Sep 24, 2025)
830efcf  stash (casteryh, Sep 24, 2025)
5a6245c  stash (casteryh, Sep 24, 2025)
3cf2e32  update yaml (casteryh, Sep 24, 2025)
1189017  tp size must divide head num = 14 (casteryh, Sep 24, 2025)
b9291cb  cleanup (casteryh, Sep 24, 2025)
46e855f  lint (casteryh, Sep 24, 2025)
68b91a4  add _DEPRECATED, switch default to vllm builtin loading. (casteryh, Sep 26, 2025)
d1e7ec6  clean up (casteryh, Sep 26, 2025)
4d61a58  Merge branch 'main' into weight-loading (casteryh, Sep 26, 2025)
d224bea  dcp support (casteryh, Sep 26, 2025)
8c3471c  fix and add test (casteryh, Sep 26, 2025)
0f67dec  rename to _torchstore_utils, add time (casteryh, Sep 26, 2025)
a08c96d  Merge branch 'main' into weight-loading (casteryh, Sep 26, 2025)
ab56d05  fix (casteryh, Sep 26, 2025)
33ec0ef  fix main (casteryh, Sep 26, 2025)
7e05b3a  tweak tmpdir (casteryh, Sep 26, 2025)
aaf60ec  tweak tmpdir (casteryh, Sep 26, 2025)
1c0afa9  fix dcp_path (casteryh, Sep 26, 2025)
3e5a417  debug (casteryh, Sep 26, 2025)
7addf9a  Merge branch 'main' into weight-loading (casteryh, Sep 26, 2025)
446b123  debug (casteryh, Sep 27, 2025)
c02e988  debug (casteryh, Sep 26, 2025)
30fbe46  disable new weight load for 8b for now (casteryh, Sep 27, 2025)
9676aa0  fix config (casteryh, Sep 27, 2025)
0691b43  fix config (casteryh, Sep 27, 2025)
fed8688  fix oom (casteryh, Sep 27, 2025)
26e1334  lint (casteryh, Sep 27, 2025)
40 changes: 24 additions & 16 deletions apps/grpo/main.py
@@ -19,6 +19,7 @@
from forge.actors.policy import Policy
from forge.actors.reference_model import ReferenceModel # noqa: F401
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.torchstore_utils import get_param_key
from forge.actors.trainer import _qwen3_hf_to_vllm
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
@@ -185,14 +186,13 @@ async def train_step(self, batch: list[list[Episode]]):
@endpoint
async def push_weights(self, version: int):
"""Update policy model weights with trainer's current weights."""
key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28)
start_time = time.time()
await ts.put_state_dict(new_sd, key)
end_time = time.time()
self.logger.debug(
f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
)
start_time = time.perf_counter()
hf_state_dict = self.model.state_dict()
for name, param in hf_state_dict.items():
key = get_param_key(version, name)
await ts.put(key, param)
end_time = time.perf_counter()
self.logger.debug(f"Pushed weights in {end_time - start_time:.2f} seconds")


@dataclass
@@ -318,7 +318,7 @@ async def main(cfg: DictConfig):
mlogger = get_metric_logger(
"wandb",
freq=1,
project="grpo-training",
project="yuxuanh-grpo-training-debug",
)

# ---- Setup services ---- #
@@ -397,20 +397,28 @@ async def continuous_rollouts():

async def continuous_training():
training_step = 0
policy_version = 0
while True:
batch = await replay_buffer.sample.choose(
curr_policy_version=policy_version
)
batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
if batch is None:
await asyncio.sleep(0.1)
else:
loss = await trainer.train_step.choose(batch)
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
await trainer.push_weights.call(policy_version)
policy_version += 1
await policy.update_weights.call()
start_time = time.perf_counter()
await trainer.push_weights.call(training_step)
mlogger.log(
"push_weights_time/training_step",
time.perf_counter() - start_time,
training_step,
)
Review comment (Contributor):
This is great! I think let's split this diff to

  1. add weight sync counter
  2. add options to do per-tensor weight sync

Reply (Contributor Author):
added this for debugging myself, will do!

start_time = time.perf_counter()
await policy.update_weights.call(training_step)
mlogger.log(
"update_weights_time/training_step",
time.perf_counter() - start_time,
training_step,
)

print("Starting GRPO training loops...")
# TODO: Start multiple rollouts once all serivces support it
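To make the trainer-side change easier to follow, here is a minimal standalone sketch of the per-tensor push pattern that `push_weights` adopts above. It assumes a torchstore-style async store exposing `put(key, tensor)`; the function name `push_weights_per_tensor` and the `store` argument are illustrative, not part of the diff.

```python
# Illustrative sketch (not part of this diff): push each parameter under its own
# versioned key instead of serializing the whole state dict at once.
import time

from forge.actors.torchstore_utils import get_param_key


async def push_weights_per_tensor(model, store, version: int) -> float:
    """Write every parameter of `model` to `store` under a per-version key."""
    start = time.perf_counter()
    state_dict = model.state_dict()  # HF-format parameter names
    for name, param in state_dict.items():
        # One key per parameter: f"policy_ver_{version}.{name}"
        await store.put(get_param_key(version, name), param)
    elapsed = time.perf_counter() - start
    print(f"Pushed {len(state_dict)} parameters in {elapsed:.2f}s")
    return elapsed
```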
63 changes: 26 additions & 37 deletions src/forge/actors/policy.py
@@ -17,7 +17,6 @@
import torch
import torchstore as ts
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore.state_dict_utils import DELIM
from vllm.config import VllmConfig

from vllm.engine.arg_utils import EngineArgs
@@ -40,11 +39,17 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.actors.torchstore_utils import (
extract_param_name,
get_param_key,
get_param_prefix,
)

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from forge.util.async_utils import make_sync_generator


@dataclass
@@ -364,16 +369,16 @@ async def run(self):
fut.set_result(request_output)

@endpoint
async def update_weights(self):
async def update_weights(self, policy_version: int):
Review comment (Contributor):
You probably want to rebase on this #181
I'll address the comments and merge the PR ASAP
# TODO: If generating long sequences, this might be long and will block policy weight updates
curr_requests = [fut for _, fut in self.requests.values()]
if curr_requests:
self.logger.debug(f"Waiting for {len(curr_requests)} pending requests")
await asyncio.gather(*curr_requests)

self.logger.debug(f"Starting weight update on {self.__class__.__name__}")
await self.policy_worker.update.call(version=self.weights_version)
self.weights_version += 1
await self.policy_worker.update.call(version=policy_version)
self.weights_version = policy_version
self.logger.info(f"Weight update completed (now v{self.weights_version})")

@endpoint
@@ -395,7 +400,6 @@ async def stop(self):
@dataclass
class PolicyWorker(ForgeActor):
vllm_config: VllmConfig
state_dict_key: str = "model_state_dict"

@endpoint
async def setup(self):
@@ -407,41 +411,26 @@ async def setup(self):
async def execute_model(self, schedule: SchedulerOutput):
return self.worker.execute_model(schedule)

async def _load_tensor_parallel_state_dict(
self, current_state_dict: dict, version: int
):
"""
Load full state dict from torchstore into tensor parallel model with deterministic sharding.
"""
sharding = VLLMSharding(
self.vllm_config.parallel_config.tensor_parallel_size, self.rank
)

for param_name in current_state_dict.keys():
current_tensor = current_state_dict[param_name]

# Load the full tensor from torchstore
# TODO: only get the part of the tensor that is needed
stored_tensor = await ts.get(
f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}"
)
sharding.load_from_source_to_target(
param_name,
stored_tensor,
current_tensor,
)

@endpoint
async def update(self, version: int):
"""Update model weights by reading state dict from torchstore"""
key = f"{self.state_dict_key}{DELIM}{version}"
model = self.worker.model_runner.model
current_state_dict = model.state_dict()
start = time.time()
await self._load_tensor_parallel_state_dict(current_state_dict, version)
self.logger.debug(
f"Loaded state dict from {key} in {time.time() - start} seconds"
)
prefix = get_param_prefix(version)
self.logger.debug(f"{prefix=}")
matching_keys = await ts.keys(prefix)
self.logger.debug(f"{matching_keys=}")
# TODO: find a way to save the original huggingface parameter names.
hf_names = [extract_param_name(key) for key in matching_keys]
self.logger.debug(f"{hf_names=}")
loaded_weights = set()
# We can't pass a generator since vllm load_weights is not async.
# Instead, we just call load_weights with one parameter at a time.
for name in hf_names:
param = await ts.get(get_param_key(version, name))
loaded = model.load_weights([(name, param)])
Review comment (Member):
This is super cool! I didn't realize you could do it per-param :)

Reply (Contributor Author):
yeah it's surprisingly good

del param
loaded_weights.update(loaded)
self.logger.info(f"Updated {len(loaded_weights)} parameters")
Review comment (Member):
nit: I prefer the old debug message that prints out the time it took to update the weights

Reply (Contributor Author):
will add it back


@endpoint
async def setup_kv_cache(self):
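To complement the diff above, below is a minimal sketch of the worker-side flow: discover the keys for a given policy version, recover the original parameter names, and feed vLLM's `model.load_weights` one `(name, tensor)` pair at a time (it takes an iterable of pairs and returns the names it loaded, which is what makes per-parameter loading work here). The sketch also restores a timing debug message, as suggested in the review thread; the function name and the `logger` argument are illustrative, not part of the diff.

```python
# Illustrative sketch (not part of this diff): per-parameter weight update on the
# policy worker, with a timing log around the whole load loop.
import time

import torchstore as ts

from forge.actors.torchstore_utils import (
    extract_param_name,
    get_param_key,
    get_param_prefix,
)


async def update_weights_per_tensor(model, version: int, logger) -> set[str]:
    start = time.perf_counter()
    prefix = get_param_prefix(version)        # e.g. "policy_ver_7"
    matching_keys = await ts.keys(prefix)     # all params pushed for this version
    hf_names = [extract_param_name(key) for key in matching_keys]

    loaded_weights: set[str] = set()
    # vLLM's load_weights is synchronous, so fetch and load one tensor at a time
    # instead of handing it an async generator.
    for name in hf_names:
        param = await ts.get(get_param_key(version, name))
        loaded_weights.update(model.load_weights([(name, param)]))
        del param  # free the full tensor before fetching the next one

    logger.debug(
        f"Loaded {len(loaded_weights)} parameters for version {version} "
        f"in {time.perf_counter() - start:.2f} seconds"
    )
    return loaded_weights
```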
19 changes: 19 additions & 0 deletions src/forge/actors/torchstore_utils.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

KEY_DELIM = "."


def get_param_prefix(policy_version: int) -> str:
return f"policy_ver_{policy_version}"


def get_param_key(policy_version: int, name: str) -> str:
return f"policy_ver_{policy_version}{KEY_DELIM}{name}"


def extract_param_name(key: str) -> str:
return KEY_DELIM.join(key.split(KEY_DELIM)[1:])
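
A quick round-trip illustration of how the three helpers above compose (the parameter name is hypothetical):

```python
# Round-trip example for the key scheme above.
from forge.actors.torchstore_utils import (
    extract_param_name,
    get_param_key,
    get_param_prefix,
)

name = "model.layers.0.self_attn.q_proj.weight"  # hypothetical HF parameter name
key = get_param_key(3, name)
# key == "policy_ver_3.model.layers.0.self_attn.q_proj.weight"
assert key.startswith(get_param_prefix(3))
assert extract_param_name(key) == name
```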