Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
17e0c05
use vllm load_weights() in GRPO
casteryh Sep 19, 2025
2657324
Merge branch 'main' into weight-loading
casteryh Sep 23, 2025
6b67be9
modify trainer.py
casteryh Sep 23, 2025
a6c7aef
modify trainer.py
casteryh Sep 23, 2025
7ec461b
fix
casteryh Sep 23, 2025
762bf24
lint
casteryh Sep 23, 2025
87a7bc2
merge main
casteryh Sep 23, 2025
98e6dd3
stash
casteryh Sep 24, 2025
830efcf
stash
casteryh Sep 24, 2025
5a6245c
stash
casteryh Sep 24, 2025
3cf2e32
update yaml
casteryh Sep 24, 2025
1189017
tp size must divide head num = 14
casteryh Sep 24, 2025
b9291cb
cleanup
casteryh Sep 24, 2025
46e855f
lint
casteryh Sep 24, 2025
68b91a4
add _DEPRECATED, switch default to vllm builtin loading.
casteryh Sep 26, 2025
d1e7ec6
clean up
casteryh Sep 26, 2025
4d61a58
Merge branch 'main' into weight-loading
casteryh Sep 26, 2025
d224bea
dcp support
casteryh Sep 26, 2025
8c3471c
fix and add test
casteryh Sep 26, 2025
0f67dec
rename to _torchstore_utils, add time
casteryh Sep 26, 2025
a08c96d
Merge branch 'main' into weight-loading
casteryh Sep 26, 2025
ab56d05
fix
casteryh Sep 26, 2025
33ec0ef
fix main
casteryh Sep 26, 2025
7e05b3a
tweak tmpdir
casteryh Sep 26, 2025
aaf60ec
tweak tmpdir
casteryh Sep 26, 2025
1c0afa9
fix dcp_path
casteryh Sep 26, 2025
3e5a417
debug
casteryh Sep 26, 2025
7addf9a
Merge branch 'main' into weight-loading
casteryh Sep 26, 2025
446b123
debug
casteryh Sep 27, 2025
c02e988
debug
casteryh Sep 26, 2025
30fbe46
disable new weight load for 8b for now
casteryh Sep 27, 2025
9676aa0
fix config
casteryh Sep 27, 2025
0691b43
fix config
casteryh Sep 27, 2025
fed8688
fix oom
casteryh Sep 27, 2025
26e1334
lint
casteryh Sep 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import asyncio
import uuid
from dataclasses import dataclass
from tempfile import TemporaryDirectory
from typing import Any, Callable

import torch
Expand Down Expand Up @@ -358,9 +359,7 @@ async def continuous_training():
loss = await trainer.train_step.route(inputs, targets)
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
await trainer.push_weights.fanout(
training_step, vllm_tp_DEPRECATED=policy_tp_size
)
await trainer.push_weights.fanout(training_step)
await policy.update_weights.fanout(training_step)

print("Starting GRPO training loops...")
Expand Down Expand Up @@ -394,6 +393,10 @@ async def continuous_training():

@parse
def _main(cfg):
asyncio.run(main(cfg))
with TemporaryDirectory(prefix="forge_run_", dir="/dev/shm") as dcp_path:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this won't work lemme revert

print(f"Using DCP path: {dcp_path}")
cfg.trainer.dcp_path = dcp_path
print(cfg)
asyncio.run(main(cfg))

_main() # @parse grabs the cfg from CLI
3 changes: 3 additions & 0 deletions apps/grpo/qwen3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ policy:

# Trainer configuration
trainer:
use_dcp: true
model:
name: qwen3
flavor: 8B
Expand All @@ -48,6 +49,7 @@ trainer:
max_norm: 1.0
steps: 1000000
dtype: bfloat16
gc_freq: 1
compile:
enable: false
parallelism:
Expand Down Expand Up @@ -86,6 +88,7 @@ ref_model:
hf_assets_path: hf://${model}
training:
dtype: bfloat16
gc_freq: 1
compile:
enable: false
parallelism:
Expand Down
68 changes: 68 additions & 0 deletions apps/toy_rl/sumdigits-tp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Toy app Training Configuration

# Global configuration
group_size: 16
batch_size: 64
max_req_tokens: 64
max_res_tokens: 64
model: "Qwen/Qwen2.5-0.5B-Instruct"

# Dataset configuration
dataset:
model: ${model}

# Policy configuration
policy:
engine_config:
model: ${model}
tensor_parallel_size: 2
pipeline_parallel_size: 1
enforce_eager: false
sampling_config:
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
top_p: 1.0
use_vllm_builtin_load: true

# Trainer configuration
trainer:
model_name: ${model}
learning_rate: 1e-5
use_vllm_builtin_load: true

# Reference model configuration
ref_model:
model_name: ${model}

# Replay buffer configuration
replay_buffer:
batch_size: ${batch_size}
max_policy_age: 1 # Async by 1
dp_size: 1

services:
dataset:
procs: 1
num_replicas: 1
with_gpus: false
policy:
procs: 1
num_replicas: 1
with_gpus: true
trainer:
procs: 1
num_replicas: 1
with_gpus: true
replay_buffer:
procs: 1
num_replicas: 1
with_gpus: false
reward_actor:
procs: 1
num_replicas: 1
with_gpus: false
ref_model:
procs: 1
num_replicas: 1
with_gpus: true
36 changes: 32 additions & 4 deletions apps/toy_rl/sumdigits.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
# Usage: python -m apps.toy_rl.sumdigits --config apps/toy_rl/sumdigits.yaml

import asyncio
import random
Expand All @@ -18,13 +18,14 @@
import torchstore as ts
from forge.actors.policy import Policy
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.torchstore_utils import get_param_key
from forge.actors.trainer import _qwen3_hf_to_vllm
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown

from forge.losses.grpo_loss import SimpleGRPOLoss
from forge.util.metric_logging import get_metric_logger

from forge.util.ops import selective_log_softmax
from monarch.actor import endpoint
from omegaconf import DictConfig
Expand Down Expand Up @@ -250,10 +251,14 @@ async def forward(self, episode: Episode) -> torch.Tensor:
class Trainer(ForgeActor):
"""Reinforce Loss Trainer implementation for policy optimization."""

model_name: str
model_name: str = ""
learning_rate: float = 1e-5
device: torch.device | None = None
state_dict_key: str = "model_state_dict"
use_vllm_builtin_load: bool = True

def __post_init__(self):
super().__init__()

@endpoint
async def setup(self):
Expand Down Expand Up @@ -338,7 +343,18 @@ def train_step(self, episodes: list[Episode]) -> float:
return loss.item()

@endpoint
async def push_weights(self, version: int, vllm_tp_DEPRECATED: int) -> None:
async def push_weights_DEPRECATED( # noqa: N802
self, policy_version: int, vllm_tp_DEPRECATED: int = 1
):
"""Update policy model weights with trainer's current weights.
This method pushes weights to torchstore in the vllm format,
which is buggy and not scalable to other models. Deprecated.
"""
return await self._push_weights_DEPRECATED(policy_version, vllm_tp_DEPRECATED)

async def _push_weights_DEPRECATED( # noqa: N802
self, version: int, vllm_tp_DEPRECATED: int
) -> None:
"""Update policy model weights with trainer's current weights."""
key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
new_sd = _qwen3_hf_to_vllm(
Expand All @@ -353,6 +369,16 @@ async def push_weights(self, version: int, vllm_tp_DEPRECATED: int) -> None:
f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
)

@endpoint
async def push_weights(self, policy_version: int) -> None:
"""Push weights to torchstore in HF format."""
if not self.use_vllm_builtin_load:
return await self._push_weights_DEPRECATED(policy_version)
hf_state_dict = self.model.state_dict()
for name, param in hf_state_dict.items():
key = get_param_key(policy_version, name)
await ts.put(key, param)


@dataclass
class RewardActor(ForgeActor):
Expand Down Expand Up @@ -444,6 +470,8 @@ async def main(cfg: DictConfig):
)

# ---- Setup services ---- #
print(f"{cfg.policy=}")
print(f"{cfg.services.policy=}")
await ts.initialize()
(
dataloader,
Expand Down
37 changes: 37 additions & 0 deletions src/forge/actors/_torchstore_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.metadata import Metadata as DcpMeta

KEY_DELIM = "."


@dataclass
class DcpHandle:
checkpoint_id: str = ""
metadata: DcpMeta | None = None


def load_tensor_from_dcp(handle: DcpHandle, param_name) -> torch.Tensor:
tensor_meta = handle.metadata.state_dict_metadata[param_name]
buffer = torch.empty(tensor_meta.size, dtype=tensor_meta.properties.dtype)
dcp.load(checkpoint_id=handle.checkpoint_id, state_dict={param_name: buffer})
return buffer


def get_param_prefix(policy_version: int) -> str:
return f"policy_ver_{policy_version}"


def get_param_key(policy_version: int, name: str) -> str:
return f"policy_ver_{policy_version}{KEY_DELIM}{name}"


def extract_param_name(key: str) -> str:
return KEY_DELIM.join(key.split(KEY_DELIM)[1:])
73 changes: 68 additions & 5 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,18 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.actors._torchstore_utils import (
DcpHandle,
extract_param_name,
get_param_key,
get_param_prefix,
load_tensor_from_dcp,
)

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.data.sharding import VLLMSharding
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt

from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig

Expand Down Expand Up @@ -126,6 +132,7 @@ def create_vllm_config(self) -> VllmConfig:
class Policy(PolicyInterface):
engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig)
sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig)
use_vllm_builtin_load: bool = True
available_devices: str | None = None
# Gets set up by setup
sampling_params: SamplingParams | None = None
Expand All @@ -144,6 +151,7 @@ def __post_init__(self):
self.engine_config = EngineConfig.from_dict(self.engine_config)
if isinstance(self.sampling_config, Mapping):
self.sampling_config = SamplingConfig.from_dict(self.sampling_config)
# No conversion needed for boolean flag

@classmethod
async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
Expand Down Expand Up @@ -196,6 +204,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
sampling_config=sampling_config,
available_devices=available_devices,
policy_worker=workers,
**kwargs,
)
policy._policy_proc = policy_proc
policy._worker_procs = worker_procs
Expand Down Expand Up @@ -387,7 +396,22 @@ async def update_weights(self, policy_version: int):
await asyncio.gather(*curr_requests)

logger.debug(f"Starting weight update on {self.__class__.__name__}")
await self.policy_worker.update.call(version=policy_version)
if self.use_vllm_builtin_load:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually, this will be the default right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like the plan

await self.policy_worker.update.call(version=policy_version)
else:
await self.policy_worker.update_DEPRECATED.call(version=policy_version)
self.policy_version = policy_version
logger.info(f"Weight update completed (now v{self.policy_version})")

@endpoint
async def update_weights_DEPRECATED(self, policy_version: int): # noqa: N802
# TODO: If generating long sequences, this might be long and will block policy weight updates
curr_requests = [fut for _, fut in self.requests.values()]
if curr_requests:
logger.debug(f"Waiting for {len(curr_requests)} pending requests")
await asyncio.gather(*curr_requests)

await self.policy_worker.update_DEPRECATED.call(version=policy_version)
self.policy_version = policy_version
logger.info(f"Weight update completed (now v{self.policy_version})")

Expand Down Expand Up @@ -454,7 +478,11 @@ def _extract_logprobs(self, one_sample: CompletionOutput) -> torch.Tensor | None
class PolicyWorker(ForgeActor):
vllm_config: VllmConfig
state_dict_key: str = "model_state_dict"
# TODO: remove this later since no plumbing exists to change this value.
# Also, whether to use dcp or not can be inferred from torchstore get() call.
use_dcp: bool = True
# Cache hf param names on first update call.
hf_param_names = []

# used for tesing purposes only
_test_prev_params = {}
Expand Down Expand Up @@ -509,8 +537,9 @@ async def _load_tensor_parallel_state_dict(
)

@endpoint
async def update(self, version: int):
"""Update model weights by reading state dict from torchstore"""
async def update_DEPRECATED(self, version: int): # noqa: N802
"""Update model weights by reading state dict from torchstore.
Deprecated. This uses manual sharding logic which is buggy."""
key = f"{self.state_dict_key}{DELIM}{version}"
model = self.worker.model_runner.model
current_state_dict = model.state_dict()
Expand All @@ -520,6 +549,40 @@ async def update(self, version: int):
f"Loaded state dict from {key} in {time.perf_counter() - start} seconds"
)

@endpoint
async def update(self, version: int):
"""Update model weights by reading state dict from torchstore"""
model = self.worker.model_runner.model
prefix = get_param_prefix(version)
logger.debug(f"{prefix=}")
matching_keys = await ts.keys(prefix)
logger.debug(f"{matching_keys=}")
if not self.hf_param_names:
self.hf_param_names = [extract_param_name(key) for key in matching_keys]
loaded_weights = set()
# We can't pass a generator since vllm load_weights is not async.
# Instead, we just call load_weights with one parameter at a time.
start = time.perf_counter()
for name in self.hf_param_names:
param_key = get_param_key(version, name)
tensor_or_handle = await ts.get(param_key)
if isinstance(tensor_or_handle, torch.Tensor):
param = tensor_or_handle
elif isinstance(tensor_or_handle, DcpHandle):
param = load_tensor_from_dcp(tensor_or_handle, name)
logger.debug(f"Loaded {name} from DCP with handle {tensor_or_handle}")
else:
raise RuntimeError(
f"Unexpected type for {param_key}: {type(tensor_or_handle)}"
)
loaded = model.load_weights([(name, param)])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is super cool! I didn't realize you could do it per-param :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it's surprisingly good

del param
loaded_weights.update(loaded)
logger.info(
f"[PolicyWorker::update] Updated {len(loaded_weights)} parameters, took {time.perf_counter() - start} seconds"
)
logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}")

@endpoint
async def setup_kv_cache(self):
"""Based on vllm/v1/engine/core.py:EngineCore._initialize_kv_caches
Expand Down
Loading
Loading