Merged

35 commits
17e0c05  use vllm load_weights() in GRPO (casteryh, Sep 19, 2025)
2657324  Merge branch 'main' into weight-loading (casteryh, Sep 23, 2025)
6b67be9  modify trainer.py (casteryh, Sep 23, 2025)
a6c7aef  modify trainer.py (casteryh, Sep 23, 2025)
7ec461b  fix (casteryh, Sep 23, 2025)
762bf24  lint (casteryh, Sep 23, 2025)
87a7bc2  merge main (casteryh, Sep 23, 2025)
98e6dd3  stash (casteryh, Sep 24, 2025)
830efcf  stash (casteryh, Sep 24, 2025)
5a6245c  stash (casteryh, Sep 24, 2025)
3cf2e32  update yaml (casteryh, Sep 24, 2025)
1189017  tp size must divide head num = 14 (casteryh, Sep 24, 2025)
b9291cb  cleanup (casteryh, Sep 24, 2025)
46e855f  lint (casteryh, Sep 24, 2025)
68b91a4  add _DEPRECATED, switch default to vllm builtin loading. (casteryh, Sep 26, 2025)
d1e7ec6  clean up (casteryh, Sep 26, 2025)
4d61a58  Merge branch 'main' into weight-loading (casteryh, Sep 26, 2025)
d224bea  dcp support (casteryh, Sep 26, 2025)
8c3471c  fix and add test (casteryh, Sep 26, 2025)
0f67dec  rename to _torchstore_utils, add time (casteryh, Sep 26, 2025)
a08c96d  Merge branch 'main' into weight-loading (casteryh, Sep 26, 2025)
ab56d05  fix (casteryh, Sep 26, 2025)
33ec0ef  fix main (casteryh, Sep 26, 2025)
7e05b3a  tweak tmpdir (casteryh, Sep 26, 2025)
aaf60ec  tweak tmpdir (casteryh, Sep 26, 2025)
1c0afa9  fix dcp_path (casteryh, Sep 26, 2025)
3e5a417  debug (casteryh, Sep 26, 2025)
7addf9a  Merge branch 'main' into weight-loading (casteryh, Sep 26, 2025)
446b123  debug (casteryh, Sep 27, 2025)
c02e988  debug (casteryh, Sep 26, 2025)
30fbe46  disable new weight load for 8b for now (casteryh, Sep 27, 2025)
9676aa0  fix config (casteryh, Sep 27, 2025)
0691b43  fix config (casteryh, Sep 27, 2025)
fed8688  fix oom (casteryh, Sep 27, 2025)
26e1334  lint (casteryh, Sep 27, 2025)
68 changes: 68 additions & 0 deletions apps/toy_rl/sumdigits-tp.yaml
@@ -0,0 +1,68 @@
# Toy app Training Configuration

# Global configuration
group_size: 16
batch_size: 64
max_req_tokens: 64
max_res_tokens: 64
model: "Qwen/Qwen2.5-0.5B-Instruct"

# Dataset configuration
dataset:
  model: ${model}

# Policy configuration
policy:
  engine_config:
    model: ${model}
    tensor_parallel_size: 2
    pipeline_parallel_size: 1
    enforce_eager: false
  sampling_config:
    n: ${group_size}
    max_tokens: ${max_res_tokens}
    temperature: 1.0
    top_p: 1.0
  use_vllm_builtin_load: true

# Trainer configuration
trainer:
  model_name: ${model}
  learning_rate: 1e-5
  use_vllm_builtin_load: true

# Reference model configuration
ref_model:
  model_name: ${model}

# Replay buffer configuration
replay_buffer:
  batch_size: ${batch_size}
  max_policy_age: 1  # Async by 1
  dp_size: 1

services:
  dataset:
    procs: 1
    num_replicas: 1
    with_gpus: false
  policy:
    procs: 1
    num_replicas: 1
    with_gpus: true
  trainer:
    procs: 1
    num_replicas: 1
    with_gpus: true
  replay_buffer:
    procs: 1
    num_replicas: 1
    with_gpus: false
  reward_actor:
    procs: 1
    num_replicas: 1
    with_gpus: false
  ref_model:
    procs: 1
    num_replicas: 1
    with_gpus: true
26 changes: 22 additions & 4 deletions apps/toy_rl/sumdigits.py
@@ -4,30 +4,32 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
# Usage: python -m apps.toy_rl.sumdigits --config apps/toy_rl/sumdigits.yaml

import asyncio
import random
import time
import uuid
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any

import torch
import torch.nn.functional as F
import torchstore as ts
from forge.actors.policy import Policy
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.torchstore_utils import get_param_key
from forge.actors.trainer import _qwen3_hf_to_vllm
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown

from forge.losses.grpo_loss import SimpleGRPOLoss
from forge.util.metric_logging import get_metric_logger

from forge.util.ops import selective_log_softmax
from monarch.actor import endpoint
from omegaconf import DictConfig
from torch.functional import _return_counts

from torchstore.state_dict_utils import DELIM
from transformers import AutoModelForCausalLM
@@ -250,10 +252,14 @@ async def forward(self, episode: Episode) -> torch.Tensor:
class Trainer(ForgeActor):
    """Reinforce Loss Trainer implementation for policy optimization."""

    model_name: str
    model_name: str = ""
    learning_rate: float = 1e-5
    device: torch.device | None = None
    state_dict_key: str = "model_state_dict"
    use_vllm_builtin_load: bool = False

    def __post_init__(self):
        super().__init__()

    @endpoint
    async def setup(self):
Expand Down Expand Up @@ -340,6 +346,9 @@ def train_step(self, episodes: list[Episode]) -> float:
@endpoint
async def push_weights(self, version: int):
"""Update policy model weights with trainer's current weights."""
if self.use_vllm_builtin_load:
await self._push_weights_hf_nonsharded(version)
return None
key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
new_sd = _qwen3_hf_to_vllm(
self.model.state_dict(), num_layers=self.model.config.num_hidden_layers
@@ -351,6 +360,13 @@ async def push_weights(self, version: int):
            f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
        )

    async def _push_weights_hf_nonsharded(self, policy_version: int) -> None:
        """Push weights to torchstore in HF format, non-sharded."""
        hf_state_dict = self.model.state_dict()
        for name, param in hf_state_dict.items():
            key = get_param_key(policy_version, name)
            await ts.put(key, param)


@dataclass
class RewardActor(ForgeActor):
@@ -440,6 +456,8 @@ async def main(cfg: DictConfig):
    )

    # ---- Setup services ---- #
    print(f"{cfg.policy=}")
    print(f"{cfg.services.policy=}")
    await ts.initialize()
    (
        dataloader,
41 changes: 37 additions & 4 deletions src/forge/actors/policy.py
@@ -43,12 +43,16 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.actors.torchstore_utils import (
extract_param_name,
get_param_key,
get_param_prefix,
)

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.data.sharding import VLLMSharding
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt

from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig

@@ -127,6 +131,8 @@ def create_vllm_config(self) -> VllmConfig:
class Policy(PolicyInterface):
    engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig)
    sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig)
    use_vllm_builtin_load: bool = False
    test_blah_blah: int = 0
    available_devices: str | None = None
    # Gets set up by setup
    sampling_params: SamplingParams | None = None
@@ -145,6 +151,7 @@ def __post_init__(self):
            self.engine_config = EngineConfig.from_dict(self.engine_config)
        if isinstance(self.sampling_config, Mapping):
            self.sampling_config = SamplingConfig.from_dict(self.sampling_config)
        # No conversion needed for boolean flag

    @classmethod
    async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
@@ -153,6 +160,7 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
        process_config: ProcessConfig,
        engine_config: EngineConfig | Mapping = EngineConfig(),
        sampling_config: SamplingConfig | Mapping = SamplingConfig(),
        use_vllm_builtin_load: bool = False,
        available_devices: str | None = None,
        **kwargs,
    ) -> "Policy":
@@ -191,6 +199,7 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
            cls,
            engine_config=engine_config,
            sampling_config=sampling_config,
            use_vllm_builtin_load=use_vllm_builtin_load,
            available_devices=available_devices,
            policy_worker=workers,
        )
@@ -384,7 +393,10 @@ async def update_weights(self, policy_version: int):
        await asyncio.gather(*curr_requests)

        logger.debug(f"Starting weight update on {self.__class__.__name__}")
        await self.policy_worker.update.call(version=policy_version)
        if self.use_vllm_builtin_load:

Contributor: Eventually, this will be the default right?

Contributor Author (casteryh): seems like the plan

            await self.policy_worker._update_hf_nonsharded.call(version=policy_version)
        else:
            await self.policy_worker._update_sharded.call(version=policy_version)
        self.policy_version = policy_version
        logger.info(f"Weight update completed (now v{self.policy_version})")

@@ -496,7 +508,7 @@ async def _load_tensor_parallel_state_dict(
        )

    @endpoint
    async def update(self, version: int):
    async def _update_sharded(self, version: int):
        """Update model weights by reading state dict from torchstore"""
        key = f"{self.state_dict_key}{DELIM}{version}"
        model = self.worker.model_runner.model

@@ -505,6 +517,27 @@ async def _update_sharded(self, version: int):

        await self._load_tensor_parallel_state_dict(current_state_dict, version)
        logger.debug(f"Loaded state dict from {key} in {time.time() - start} seconds")

    @endpoint
    async def _update_hf_nonsharded(self, version: int):

Contributor: Why is this specific to hf??

Contributor Author (casteryh, Sep 25, 2025): this just means we are pushing/reading the state dict using the Hugging Face format. not titan, not vllm.

Contributor: Perhaps call it update_DEPRECATED and update. I'd like to keep the DEPRECATED one just for A/B testing and delete it before the PTC.

Contributor: Could you explain the choice between using get_state_dict/get_state_dict and the get/put API?

Contributor: I am also confused at the load_weights API -- will it handle sharding itself? If so, should we call this function on the driver worker (0) once?

Contributor Author (casteryh): Every worker (rank) has to call load_weights(). When load_weights() is called, each worker figures out its own rank and reads only its own shard.

Moreover, load_weights() supports incremental updating, i.e., if there is only one tensor in the passed-in weights, it will update that part specifically (it even handles concatenated weights). For example, if you pass in a single kv pair "model.layers.0.q_proj.xxx" -> full_tensor (the fqn is made up, but you get the point), it will actually update the q_proj part of the fused qkv_proj weight.
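A minimal sketch of the incremental pattern described above, assuming `model` is the vLLM model object (self.worker.model_runner.model in this PR) and that load_weights() takes an iterable of (name, tensor) pairs and returns the set of parameter names it touched, as the loop below relies on; the parameter name and shape are illustrative only:

import torch

# Hypothetical HF-format name and tensor; in this PR they come from torchstore.
name = "model.layers.0.self_attn.q_proj.weight"
param = torch.randn(896, 896)  # illustrative shape only

# Passing a single pair updates just that slice of vLLM's fused qkv_proj weight
# and returns the vLLM-side parameter names that were touched.
loaded = model.load_weights([(name, param)])
print(f"touched params: {loaded}")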

"""Update model weights by reading state dict from torchstore"""
model = self.worker.model_runner.model
prefix = get_param_prefix(version)
self.logger.debug(f"{prefix=}")
matching_keys = await ts.keys(prefix)
self.logger.debug(f"{matching_keys=}")
# TODO: find a way to save the original huggingface parameter names.
hf_names = [extract_param_name(key) for key in matching_keys]
self.logger.debug(f"{hf_names=}")
loaded_weights = set()
# We can't pass a generator since vllm load_weights is not async.
# Instead, we just call load_weights with one parameter at a time.
for name in hf_names:
param = await ts.get(get_param_key(version, name))
loaded = model.load_weights([(name, param)])
Member: This is super cool! I didn't realize you could do it per-param :)

Contributor Author (casteryh): yeah it's surprisingly good

            del param
            loaded_weights.update(loaded)
        self.logger.info(f"Updated {len(loaded_weights)} parameters")
Member: nit: I prefer the old debug message that prints out the time it took to update the weights

Contributor Author (casteryh): will add it back
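A minimal sketch of how that timing log could be restored, assuming it wraps the per-parameter loop inside _update_hf_nonsharded; the placement and use of time.perf_counter() are assumptions, not part of this PR:

import time

start = time.perf_counter()
for name in hf_names:
    param = await ts.get(get_param_key(version, name))
    loaded_weights.update(model.load_weights([(name, param)]))
    del param
self.logger.info(
    f"Updated {len(loaded_weights)} parameters in {time.perf_counter() - start:.2f} seconds"
)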


    @endpoint
    async def setup_kv_cache(self):
        """Based on vllm/v1/engine/core.py:EngineCore._initialize_kv_caches
19 changes: 19 additions & 0 deletions src/forge/actors/torchstore_utils.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

KEY_DELIM = "."


def get_param_prefix(policy_version: int) -> str:
    return f"policy_ver_{policy_version}"


def get_param_key(policy_version: int, name: str) -> str:
    return f"policy_ver_{policy_version}{KEY_DELIM}{name}"


def extract_param_name(key: str) -> str:
    return KEY_DELIM.join(key.split(KEY_DELIM)[1:])
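A quick illustration of how these helpers compose (the parameter name is just an example; the expected values follow directly from the definitions above):

# Round-tripping a parameter name through the torchstore key helpers.
key = get_param_key(3, "model.embed_tokens.weight")
assert get_param_prefix(3) == "policy_ver_3"
assert key == "policy_ver_3.model.embed_tokens.weight"
assert extract_param_name(key) == "model.embed_tokens.weight"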
27 changes: 26 additions & 1 deletion src/forge/actors/trainer.py
@@ -36,6 +36,8 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.actors.torchstore_utils import get_param_key

from forge.controller import ForgeActor
from forge.data.utils import batch_to_device

@@ -93,6 +95,7 @@ class RLTrainer(ForgeActor):
    activation_checkpoint: ActivationCheckpoint = field(
        default_factory=ActivationCheckpoint
    )
    use_vllm_builtin_load: bool = False
    compile: Compile = field(default_factory=Compile)
    float8: Float8 = field(default_factory=Float8)
    comm: Comm = field(default_factory=Comm)
@@ -142,7 +145,7 @@ def __post_init__(self):
    async def setup(self):
        # TODO: update ForgeEngine to not use ForgeJobConfig
        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
        for key in {"loss", "state_dict_key", "use_dcp"}:
        for key in {"loss", "state_dict_key", "use_dcp", "use_vllm_builtin_load"}:
            engine_config.pop(key)  # Not part of job config
        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
        self.engine.checkpointer.load(step=self.step)
@@ -248,6 +251,12 @@ def train_step(

    @endpoint
    async def push_weights(self, policy_version: int) -> None:
        if self.use_vllm_builtin_load:
            await self._push_weights_hf_nonsharded(policy_version)
        else:
            await self._push_weights_sharded(policy_version)

    async def _push_weights_sharded(self, policy_version: int) -> None:

Contributor: I have some confusion: in my fix, I feel that this is not sharded. The difference is whether to process the state dict or not. Basically we just need to skip this path _qwen3_hf_to_vllm and keep the rest as is?

Contributor Author (casteryh): yep, I am bad at naming things. it actually has nothing to do with sharding at this point. maybe we just call this push_weights_vllm vs push_weights_hf (or push_weights_DEPRECATED vs push_weights if you will)

        # Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now.
        # TODO:
        # 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].

@@ -290,6 +299,22 @@ async def push_weights(self, policy_version: int) -> None:

logger.debug(f"Pushed weights to {key} in {end_time - start_time:.2f} seconds")

    async def _push_weights_hf_nonsharded(self, policy_version: int) -> None:
        """Push weights to torchstore in HF format, non-sharded."""
        if "model" not in self.engine.checkpointer.states:
            raise RuntimeError("Model state not found in checkpointer state")

        sd = self.engine.checkpointer.states["model"].state_dict()
        flattened_state_dict, _ = flatten_state_dict(sd)
        if self.engine.checkpointer.sd_adapter is None:
            raise RuntimeError(
                "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
            )
        hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
        for name, param in hf_state_dict.items():
            key = get_param_key(policy_version, name)
            await ts.put(key, param)

    @endpoint
    async def cleanup(self) -> None:
        if self.engine.checkpointer: