Merged
Changes shown from 16 of 35 commits.

Commits (35)
17e0c05  use vllm load_weights() in GRPO (casteryh, Sep 19, 2025)
2657324  Merge branch 'main' into weight-loading (casteryh, Sep 23, 2025)
6b67be9  modify trainer.py (casteryh, Sep 23, 2025)
a6c7aef  modify trainer.py (casteryh, Sep 23, 2025)
7ec461b  fix (casteryh, Sep 23, 2025)
762bf24  lint (casteryh, Sep 23, 2025)
87a7bc2  merge main (casteryh, Sep 23, 2025)
98e6dd3  stash (casteryh, Sep 24, 2025)
830efcf  stash (casteryh, Sep 24, 2025)
5a6245c  stash (casteryh, Sep 24, 2025)
3cf2e32  update yaml (casteryh, Sep 24, 2025)
1189017  tp size must divide head num = 14 (casteryh, Sep 24, 2025)
b9291cb  cleanup (casteryh, Sep 24, 2025)
46e855f  lint (casteryh, Sep 24, 2025)
68b91a4  add _DEPRECATED, switch default to vllm builtin loading. (casteryh, Sep 26, 2025)
d1e7ec6  clean up (casteryh, Sep 26, 2025)
4d61a58  Merge branch 'main' into weight-loading (casteryh, Sep 26, 2025)
d224bea  dcp support (casteryh, Sep 26, 2025)
8c3471c  fix and add test (casteryh, Sep 26, 2025)
0f67dec  rename to _torchstore_utils, add time (casteryh, Sep 26, 2025)
a08c96d  Merge branch 'main' into weight-loading (casteryh, Sep 26, 2025)
ab56d05  fix (casteryh, Sep 26, 2025)
33ec0ef  fix main (casteryh, Sep 26, 2025)
7e05b3a  tweak tmpdir (casteryh, Sep 26, 2025)
aaf60ec  tweak tmpdir (casteryh, Sep 26, 2025)
1c0afa9  fix dcp_path (casteryh, Sep 26, 2025)
3e5a417  debug (casteryh, Sep 26, 2025)
7addf9a  Merge branch 'main' into weight-loading (casteryh, Sep 26, 2025)
446b123  debug (casteryh, Sep 27, 2025)
c02e988  debug (casteryh, Sep 26, 2025)
30fbe46  disable new weight load for 8b for now (casteryh, Sep 27, 2025)
9676aa0  fix config (casteryh, Sep 27, 2025)
0691b43  fix config (casteryh, Sep 27, 2025)
fed8688  fix oom (casteryh, Sep 27, 2025)
26e1334  lint (casteryh, Sep 27, 2025)
68 changes: 68 additions & 0 deletions apps/toy_rl/sumdigits-tp.yaml
@@ -0,0 +1,68 @@
# Toy app Training Configuration

# Global configuration
group_size: 16
batch_size: 64
max_req_tokens: 64
max_res_tokens: 64
model: "Qwen/Qwen2.5-0.5B-Instruct"

# Dataset configuration
dataset:
  model: ${model}

# Policy configuration
policy:
  engine_config:
    model: ${model}
    tensor_parallel_size: 2
    pipeline_parallel_size: 1
    enforce_eager: false
  sampling_config:
    n: ${group_size}
    max_tokens: ${max_res_tokens}
    temperature: 1.0
    top_p: 1.0
  use_vllm_builtin_load: true

# Trainer configuration
trainer:
  model_name: ${model}
  learning_rate: 1e-5
  use_vllm_builtin_load: true

# Reference model configuration
ref_model:
  model_name: ${model}

# Replay buffer configuration
replay_buffer:
  batch_size: ${batch_size}
  max_policy_age: 1 # Async by 1
  dp_size: 1

services:
  dataset:
    procs: 1
    num_replicas: 1
    with_gpus: false
  policy:
    procs: 1
    num_replicas: 1
    with_gpus: true
  trainer:
    procs: 1
    num_replicas: 1
    with_gpus: true
  replay_buffer:
    procs: 1
    num_replicas: 1
    with_gpus: false
  reward_actor:
    procs: 1
    num_replicas: 1
    with_gpus: false
  ref_model:
    procs: 1
    num_replicas: 1
    with_gpus: true
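One non-obvious constraint behind `tensor_parallel_size: 2` (see the commit "tp size must divide head num = 14"): the TP degree has to divide the model's attention head count, and Qwen2.5-0.5B-Instruct has 14 heads. Below is a minimal pre-flight sketch of that check; the `AutoConfig` lookup is an illustrative assumption, not part of this PR.

# Hypothetical pre-flight check (not in this PR): confirm the TP degree divides
# the model's attention head count before launching the policy service.
from transformers import AutoConfig

model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tensor_parallel_size = 2

hf_config = AutoConfig.from_pretrained(model_name)
num_heads = hf_config.num_attention_heads  # 14 for Qwen2.5-0.5B-Instruct
assert num_heads % tensor_parallel_size == 0, (
    f"tensor_parallel_size={tensor_parallel_size} must divide "
    f"num_attention_heads={num_heads}"
)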
32 changes: 28 additions & 4 deletions apps/toy_rl/sumdigits.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
# Usage: python -m apps.toy_rl.sumdigits --config apps/toy_rl/sumdigits.yaml

import asyncio
import random
@@ -18,13 +18,14 @@
import torchstore as ts
from forge.actors.policy import Policy
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.torchstore_utils import get_param_key
from forge.actors.trainer import _qwen3_hf_to_vllm
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown

from forge.losses.grpo_loss import SimpleGRPOLoss
from forge.util.metric_logging import get_metric_logger

from forge.util.ops import selective_log_softmax
from monarch.actor import endpoint
from omegaconf import DictConfig
@@ -250,10 +251,14 @@ async def forward(self, episode: Episode) -> torch.Tensor:
class Trainer(ForgeActor):
"""Reinforce Loss Trainer implementation for policy optimization."""

model_name: str
model_name: str = ""
learning_rate: float = 1e-5
device: torch.device | None = None
state_dict_key: str = "model_state_dict"
use_vllm_builtin_load: bool = True

def __post_init__(self):
super().__init__()

@endpoint
async def setup(self):
Expand Down Expand Up @@ -338,7 +343,14 @@ def train_step(self, episodes: list[Episode]) -> float:
return loss.item()

@endpoint
async def push_weights(self, version: int):
async def push_weights_DEPRECATED(self, policy_version: int): # noqa: N802
Review comment (Member):
    If we're confident in this fix, we should just fully delete the old way. My thinking is as follows:
    1. Gets everyone immediately testing the new version for any bugs 👍
    2. Reduces the chance an end user sees and uses this endpoint 👍
    3. Less code to parse through right now 👍

Reply (@casteryh, Contributor, Author, Sep 26, 2025):
    > Gets everyone immediately testing the new version for any bugs 👍
    Yes, the new one is the default now! I think the plan is to keep the DEPRECATED method just for benchmarking purposes for now? @JenniferWang

"""Update policy model weights with trainer's current weights.
This method pushes weights to torchstore in the vllm format,
which is buggy and not scalable to other models. Deprecated.
"""
return await self._push_weights_DEPRECATED(policy_version)

async def _push_weights_DEPRECATED(self, version: int): # noqa: N802
"""Update policy model weights with trainer's current weights."""
key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
new_sd = _qwen3_hf_to_vllm(
@@ -351,6 +363,16 @@ async def push_weights(self, version: int):
f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
)

@endpoint
async def push_weights(self, policy_version: int) -> None:
"""Push weights to torchstore in HF format."""
if not self.use_vllm_builtin_load:
return await self._push_weights_DEPRECATED(policy_version)
hf_state_dict = self.model.state_dict()
for name, param in hf_state_dict.items():
key = get_param_key(policy_version, name)
await ts.put(key, param)


@dataclass
class RewardActor(ForgeActor):
@@ -440,6 +462,8 @@ async def main(cfg: DictConfig):
    )

    # ---- Setup services ---- #
    print(f"{cfg.policy=}")
    print(f"{cfg.services.policy=}")
    await ts.initialize()
    (
        dataloader,
55 changes: 50 additions & 5 deletions src/forge/actors/policy.py
@@ -43,12 +43,16 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.actors.torchstore_utils import (
    extract_param_name,
    get_param_key,
    get_param_prefix,
)

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.data.sharding import VLLMSharding
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt

from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig

@@ -127,6 +131,7 @@ def create_vllm_config(self) -> VllmConfig:
class Policy(PolicyInterface):
    engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig)
    sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig)
    use_vllm_builtin_load: bool = True
    available_devices: str | None = None
    # Gets set up by setup
    sampling_params: SamplingParams | None = None
@@ -145,6 +150,7 @@ def __post_init__(self):
            self.engine_config = EngineConfig.from_dict(self.engine_config)
        if isinstance(self.sampling_config, Mapping):
            self.sampling_config = SamplingConfig.from_dict(self.sampling_config)
        # No conversion needed for boolean flag

    @classmethod
    async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
@@ -153,6 +159,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
        process_config: ProcessConfig,
        engine_config: EngineConfig | Mapping = EngineConfig(),
        sampling_config: SamplingConfig | Mapping = SamplingConfig(),
        use_vllm_builtin_load: bool = False,
        available_devices: str | None = None,
        **kwargs,
    ) -> "Policy":
@@ -191,6 +198,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
            cls,
            engine_config=engine_config,
            sampling_config=sampling_config,
            use_vllm_builtin_load=use_vllm_builtin_load,
            available_devices=available_devices,
            policy_worker=workers,
        )
@@ -384,7 +392,22 @@ async def update_weights(self, policy_version: int):
            await asyncio.gather(*curr_requests)

        logger.debug(f"Starting weight update on {self.__class__.__name__}")
        await self.policy_worker.update.call(version=policy_version)
        if self.use_vllm_builtin_load:
Review comment (Contributor):
    Eventually, this will be the default right?

Reply (author):
    seems like the plan

            await self.policy_worker.update.call(version=policy_version)
        else:
            await self.policy_worker.update_DEPRECATED.call(version=policy_version)
        self.policy_version = policy_version
        logger.info(f"Weight update completed (now v{self.policy_version})")

    @endpoint
    async def update_weights_DEPRECATED(self, policy_version: int):  # noqa: N802
        # TODO: If generating long sequences, this might be long and will block policy weight updates
        curr_requests = [fut for _, fut in self.requests.values()]
        if curr_requests:
            logger.debug(f"Waiting for {len(curr_requests)} pending requests")
            await asyncio.gather(*curr_requests)

        await self.policy_worker.update_DEPRECATED.call(version=policy_version)
        self.policy_version = policy_version
        logger.info(f"Weight update completed (now v{self.policy_version})")

@@ -496,15 +519,37 @@ async def _load_tensor_parallel_state_dict(
        )

    @endpoint
    async def update(self, version: int):
        """Update model weights by reading state dict from torchstore"""
    async def update_DEPRECATED(self, version: int):  # noqa: N802
        """Update model weights by reading state dict from torchstore.
        Deprecated. This uses manual sharding logic which is buggy."""
        key = f"{self.state_dict_key}{DELIM}{version}"
        model = self.worker.model_runner.model
        current_state_dict = model.state_dict()
        start = time.time()
        await self._load_tensor_parallel_state_dict(current_state_dict, version)
        logger.debug(f"Loaded state dict from {key} in {time.time() - start} seconds")

    @endpoint
    async def update(self, version: int):
        """Update model weights by reading state dict from torchstore"""
        model = self.worker.model_runner.model
        prefix = get_param_prefix(version)
        self.logger.debug(f"{prefix=}")
        matching_keys = await ts.keys(prefix)
        self.logger.debug(f"{matching_keys=}")
        # TODO: find a way to save the original huggingface parameter names.
        hf_names = [extract_param_name(key) for key in matching_keys]
        self.logger.debug(f"{hf_names=}")
        loaded_weights = set()
        # We can't pass a generator since vllm load_weights is not async.
        # Instead, we just call load_weights with one parameter at a time.
        for name in hf_names:
            param = await ts.get(get_param_key(version, name))
            loaded = model.load_weights([(name, param)])
Review comment (Member):
    This is super cool! I didn't realize you could do it per-param :)

Reply (author):
    yeah it's surprisingly good

            del param
            loaded_weights.update(loaded)
        self.logger.info(f"Updated {len(loaded_weights)} parameters")
Review comment (Member):
    nit: I prefer the old debug message that prints out the time it took to update the weights

Reply (author):
    will add it back
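A minimal sketch of how that timing log could be restored inside PolicyWorker.update, per the exchange above; the placement and wording are assumptions, only the surrounding names come from the diff.

        # Hypothetical re-addition of the timing around the per-parameter load loop.
        start = time.time()
        loaded_weights = set()
        for name in hf_names:
            param = await ts.get(get_param_key(version, name))
            loaded = model.load_weights([(name, param)])
            del param
            loaded_weights.update(loaded)
        self.logger.info(
            f"Updated {len(loaded_weights)} parameters in {time.time() - start:.2f} seconds"
        )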


    @endpoint
    async def setup_kv_cache(self):
        """Based on vllm/v1/engine/core.py:EngineCore._initialize_kv_caches
19 changes: 19 additions & 0 deletions src/forge/actors/torchstore_utils.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

KEY_DELIM = "."


def get_param_prefix(policy_version: int) -> str:
    return f"policy_ver_{policy_version}"


def get_param_key(policy_version: int, name: str) -> str:
    return f"policy_ver_{policy_version}{KEY_DELIM}{name}"


def extract_param_name(key: str) -> str:
    return KEY_DELIM.join(key.split(KEY_DELIM)[1:])
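A quick illustration of how these helpers compose. KEY_DELIM is "." and HF parameter names themselves contain dots, which is why extract_param_name re-joins everything after the first delimiter; the parameter name below is just an example value.

# Example round trip through the key helpers (illustrative values only).
name = "model.layers.0.self_attn.q_proj.weight"

prefix = get_param_prefix(3)  # "policy_ver_3"
key = get_param_key(3, name)  # "policy_ver_3.model.layers.0.self_attn.q_proj.weight"
assert key.startswith(prefix)
assert extract_param_name(key) == name  # dots inside the name survive the split/join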
33 changes: 31 additions & 2 deletions src/forge/actors/trainer.py
@@ -36,6 +36,8 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.actors.torchstore_utils import get_param_key

from forge.controller import ForgeActor
from forge.data.utils import batch_to_device

@@ -93,6 +95,7 @@ class RLTrainer(ForgeActor):
    activation_checkpoint: ActivationCheckpoint = field(
        default_factory=ActivationCheckpoint
    )
    use_vllm_builtin_load: bool = True
    compile: Compile = field(default_factory=Compile)
    float8: Float8 = field(default_factory=Float8)
    comm: Comm = field(default_factory=Comm)
@@ -142,7 +145,7 @@ def __post_init__(self):
    async def setup(self):
        # TODO: update ForgeEngine to not use ForgeJobConfig
        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
        for key in {"loss", "state_dict_key", "use_dcp"}:
        for key in {"loss", "state_dict_key", "use_dcp", "use_vllm_builtin_load"}:
            engine_config.pop(key)  # Not part of job config
        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
        self.engine.checkpointer.load(step=self.step)
@@ -247,7 +250,13 @@ def train_step(
        return loss.item()

    @endpoint
    async def push_weights(self, policy_version: int) -> None:
    async def push_weights_DEPRECATED(self, policy_version: int) -> None:  # noqa: N802
        """[Deprecated] This method pushes weights to torchstore in the vllm format,
        which is buggy and not scalable to other models.
        Deprecated in favor of push_weights."""
        return await self._push_weights_DEPRECATED(policy_version)

    async def _push_weights_DEPRECATED(self, policy_version: int) -> None:  # noqa: N802
        # Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now.
        # TODO:
        # 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
@@ -290,6 +299,26 @@ async def push_weights(self, policy_version: int) -> None:

logger.debug(f"Pushed weights to {key} in {end_time - start_time:.2f} seconds")

    @endpoint
    async def push_weights(self, policy_version: int) -> None:
        """Push weights to torchstore in HF format."""
        if not self.use_vllm_builtin_load:
            return await self._push_weights_DEPRECATED(policy_version)

        if "model" not in self.engine.checkpointer.states:
            raise RuntimeError("Model state not found in checkpointer state")

        sd = self.engine.checkpointer.states["model"].state_dict()
        flattened_state_dict, _ = flatten_state_dict(sd)
        if self.engine.checkpointer.sd_adapter is None:
            raise RuntimeError(
                "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
            )
        hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
        for name, param in hf_state_dict.items():
            key = get_param_key(policy_version, name)
            await ts.put(key, param)
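Putting the two sides of this PR together, the push/load handshake works roughly as follows; a condensed sketch using only names from the diff, with error handling, DCP, and the deprecated path omitted.

# Trainer side (RLTrainer.push_weights): publish each HF-named parameter
# under a versioned torchstore key.
hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
for name, param in hf_state_dict.items():
    await ts.put(get_param_key(policy_version, name), param)

# Policy side (PolicyWorker.update): list that version's keys, then stream each
# tensor into vLLM's own load_weights(), one parameter at a time.
matching_keys = await ts.keys(get_param_prefix(policy_version))
for key in matching_keys:
    name = extract_param_name(key)
    param = await ts.get(key)
    model.load_weights([(name, param)])
    del param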

    @endpoint
    async def cleanup(self) -> None:
        if self.engine.checkpointer: