Skip to content

Commit 17e0c05

Browse files
committed
use vllm load_weights() in GRPO
1 parent d4fb5e1 commit 17e0c05

File tree

3 files changed

+69
-53
lines changed

3 files changed

+69
-53
lines changed

apps/grpo/main.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from forge.actors.policy import Policy
2020
from forge.actors.reference_model import ReferenceModel # noqa: F401
2121
from forge.actors.replay_buffer import ReplayBuffer
22+
from forge.actors.torchstore_utils import get_param_key
2223
from forge.actors.trainer import _qwen3_hf_to_vllm
2324
from forge.cli.config import parse
2425
from forge.controller.actor import ForgeActor
@@ -185,14 +186,13 @@ async def train_step(self, batch: list[list[Episode]]):
185186
@endpoint
186187
async def push_weights(self, version: int):
187188
"""Update policy model weights with trainer's current weights."""
188-
key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
189-
new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28)
190-
start_time = time.time()
191-
await ts.put_state_dict(new_sd, key)
192-
end_time = time.time()
193-
self.logger.debug(
194-
f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
195-
)
189+
start_time = time.perf_counter()
190+
hf_state_dict = self.model.state_dict()
191+
for name, param in hf_state_dict.items():
192+
key = get_param_key(version, name)
193+
await ts.put(key, param)
194+
end_time = time.perf_counter()
195+
self.logger.debug(f"Pushed weights in {end_time - start_time:.2f} seconds")
196196

197197

198198
@dataclass
@@ -318,7 +318,7 @@ async def main(cfg: DictConfig):
318318
mlogger = get_metric_logger(
319319
"wandb",
320320
freq=1,
321-
project="grpo-training",
321+
project="yuxuanh-grpo-training-debug",
322322
)
323323

324324
# ---- Setup services ---- #
@@ -397,20 +397,28 @@ async def continuous_rollouts():
397397

398398
async def continuous_training():
399399
training_step = 0
400-
policy_version = 0
401400
while True:
402-
batch = await replay_buffer.sample.choose(
403-
curr_policy_version=policy_version
404-
)
401+
batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
405402
if batch is None:
406403
await asyncio.sleep(0.1)
407404
else:
408405
loss = await trainer.train_step.choose(batch)
409406
training_step += 1
410407
mlogger.log("loss/training_step", loss, training_step)
411-
await trainer.push_weights.call(policy_version)
412-
policy_version += 1
413-
await policy.update_weights.call()
408+
start_time = time.perf_counter()
409+
await trainer.push_weights.call(training_step)
410+
mlogger.log(
411+
"push_weights_time/training_step",
412+
time.perf_counter() - start_time,
413+
training_step,
414+
)
415+
start_time = time.perf_counter()
416+
await policy.update_weights.call(training_step)
417+
mlogger.log(
418+
"update_weights_time/training_step",
419+
time.perf_counter() - start_time,
420+
training_step,
421+
)
414422

415423
print("Starting GRPO training loops...")
416424
# TODO: Start multiple rollouts once all services support it

src/forge/actors/policy.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import torch
1818
import torchstore as ts
1919
from monarch.actor import current_rank, endpoint, ProcMesh
20-
from torchstore.state_dict_utils import DELIM
2120
from vllm.config import VllmConfig
2221

2322
from vllm.engine.arg_utils import EngineArgs
@@ -40,11 +39,17 @@
4039
from vllm.v1.structured_output import StructuredOutputManager
4140
from vllm.worker.worker_base import WorkerWrapperBase
4241

43-
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
42+
from forge.actors.torchstore_utils import (
43+
extract_param_name,
44+
get_param_key,
45+
get_param_prefix,
46+
)
4447

48+
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
4549
from forge.data.sharding import VLLMSharding
4650
from forge.interfaces import Policy as PolicyInterface
4751
from forge.types import ProcessConfig
52+
from forge.util.async_utils import make_sync_generator
4853

4954

5055
@dataclass
@@ -364,16 +369,16 @@ async def run(self):
364369
fut.set_result(request_output)
365370

366371
@endpoint
367-
async def update_weights(self):
372+
async def update_weights(self, policy_version: int):
368373
# TODO: If generating long sequences, this might be long and will block policy weight updates
369374
curr_requests = [fut for _, fut in self.requests.values()]
370375
if curr_requests:
371376
self.logger.debug(f"Waiting for {len(curr_requests)} pending requests")
372377
await asyncio.gather(*curr_requests)
373378

374379
self.logger.debug(f"Starting weight update on {self.__class__.__name__}")
375-
await self.policy_worker.update.call(version=self.weights_version)
376-
self.weights_version += 1
380+
await self.policy_worker.update.call(version=policy_version)
381+
self.weights_version = policy_version
377382
self.logger.info(f"Weight update completed (now v{self.weights_version})")
378383

379384
@endpoint
@@ -395,7 +400,6 @@ async def stop(self):
395400
@dataclass
396401
class PolicyWorker(ForgeActor):
397402
vllm_config: VllmConfig
398-
state_dict_key: str = "model_state_dict"
399403

400404
@endpoint
401405
async def setup(self):
@@ -407,41 +411,26 @@ async def setup(self):
407411
async def execute_model(self, schedule: SchedulerOutput):
408412
return self.worker.execute_model(schedule)
409413

410-
async def _load_tensor_parallel_state_dict(
411-
self, current_state_dict: dict, version: int
412-
):
413-
"""
414-
Load full state dict from torchstore into tensor parallel model with deterministic sharding.
415-
"""
416-
sharding = VLLMSharding(
417-
self.vllm_config.parallel_config.tensor_parallel_size, self.rank
418-
)
419-
420-
for param_name in current_state_dict.keys():
421-
current_tensor = current_state_dict[param_name]
422-
423-
# Load the full tensor from torchstore
424-
# TODO: only get the part of the tensor that is needed
425-
stored_tensor = await ts.get(
426-
f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}"
427-
)
428-
sharding.load_from_source_to_target(
429-
param_name,
430-
stored_tensor,
431-
current_tensor,
432-
)
433-
434414
@endpoint
435415
async def update(self, version: int):
436416
"""Update model weights by reading state dict from torchstore"""
437-
key = f"{self.state_dict_key}{DELIM}{version}"
438417
model = self.worker.model_runner.model
439-
current_state_dict = model.state_dict()
440-
start = time.time()
441-
await self._load_tensor_parallel_state_dict(current_state_dict, version)
442-
self.logger.debug(
443-
f"Loaded state dict from {key} in {time.time() - start} seconds"
444-
)
418+
prefix = get_param_prefix(version)
419+
self.logger.debug(f"{prefix=}")
420+
matching_keys = await ts.keys(prefix)
421+
self.logger.debug(f"{matching_keys=}")
422+
# TODO: find a way to save the original huggingface parameter names.
423+
hf_names = [extract_param_name(key) for key in matching_keys]
424+
self.logger.debug(f"{hf_names=}")
425+
loaded_weights = set()
426+
# We can't pass a generator since vllm load_weights is not async.
427+
# Instead, we just call load_weights with one parameter at a time.
428+
for name in hf_names:
429+
param = await ts.get(get_param_key(version, name))
430+
loaded = model.load_weights([(name, param)])
431+
del param
432+
loaded_weights.update(loaded)
433+
self.logger.info(f"Updated {len(loaded_weights)} parameters")
445434

446435
@endpoint
447436
async def setup_kv_cache(self):
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
KEY_DELIM = "."
8+
9+
10+
def get_param_prefix(policy_version: int) -> str:
    """Return the torchstore key prefix shared by every parameter of *policy_version*.

    NOTE(review): consumers use this as a raw string prefix for key matching;
    "policy_ver_1" would also prefix-match "policy_ver_10..." keys — confirm the
    key-listing call cannot see multiple versions at once.
    """
    return "policy_ver_" + str(policy_version)
12+
13+
14+
def get_param_key(policy_version: int, name: str) -> str:
    """Build the torchstore key for one named parameter of *policy_version*.

    Derives the version prefix from get_param_prefix() instead of repeating the
    "policy_ver_" literal, so the prefix used for writes can never drift from
    the prefix used for key listing/extraction.
    """
    return f"{get_param_prefix(policy_version)}{KEY_DELIM}{name}"
16+
17+
18+
def extract_param_name(key: str) -> str:
    """Inverse of get_param_key(): drop the version prefix, keep the parameter name.

    Parameter names may themselves contain KEY_DELIM ("."), so only the first
    delimiter is treated as the prefix separator; a key with no delimiter
    yields "".
    """
    _prefix, _sep, param_name = key.partition(KEY_DELIM)
    return param_name

0 commit comments

Comments
 (0)