Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions apps/grpo/qwen3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ dataset:

# Policy configuration
policy:
use_vllm_builtin_load: true
engine_config:
model: ${model}
tensor_parallel_size: 2
Expand All @@ -43,7 +42,6 @@ policy:
# Trainer configuration
trainer:
use_dcp: true
use_vllm_builtin_load: true
model:
name: qwen3
flavor: 8B
Expand Down
32 changes: 1 addition & 31 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import logging
import os
import sys
import time
from collections.abc import Mapping
from copy import copy
from dataclasses import asdict, dataclass, field, fields
Expand Down Expand Up @@ -140,7 +139,6 @@ def create_vllm_config(self) -> VllmConfig:
class Policy(PolicyInterface):
engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig)
sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig)
use_vllm_builtin_load: bool = True
available_devices: str | None = None
use_dcp: bool = True
# Gets set up by setup
Expand Down Expand Up @@ -484,10 +482,7 @@ async def update_weights(self, policy_version: int):
record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM)

logger.debug(f"Starting weight update on {self.__class__.__name__}")
if self.use_vllm_builtin_load:
await self.policy_worker.update.call(version=policy_version)
else:
await self.policy_worker.update_DEPRECATED.call(version=policy_version)
await self.policy_worker.update.call(version=policy_version)
self.policy_version = policy_version

# After updating the weights, we need to reset the KV cache
Expand All @@ -504,18 +499,6 @@ async def update_weights(self, policy_version: int):
async def _reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()

@endpoint
async def update_weights_DEPRECATED(self, policy_version: int): # noqa: N802
# TODO: If generating long sequences, this might be long and will block policy weight updates
curr_requests = [fut for _, fut in self.requests.values()]
if curr_requests:
logger.debug(f"Waiting for {len(curr_requests)} pending requests")
await asyncio.gather(*curr_requests)

await self.policy_worker.update_DEPRECATED.call(version=policy_version)
self.policy_version = policy_version
logger.info(f"Weight update completed (now v{self.policy_version})")

@endpoint
async def get_version(self) -> int:
"""Get the current policy version."""
Expand Down Expand Up @@ -636,19 +619,6 @@ async def _load_tensor_parallel_state_dict(
current_tensor,
)

@endpoint
async def update_DEPRECATED(self, version: int): # noqa: N802
"""Update model weights by reading state dict from torchstore.
Deprecated. This uses manual sharding logic which is buggy."""
key = f"{self.state_dict_key}{DELIM}{version}"
model = self.worker.model_runner.model
current_state_dict = model.state_dict()
start = time.perf_counter()
await self._load_tensor_parallel_state_dict(current_state_dict, version)
logger.info(
f"Loaded state dict from {key} in {time.perf_counter() - start} seconds"
)

@endpoint
async def update(self, version: int):
"""Update model weights by reading state dict from torchstore"""
Expand Down
71 changes: 0 additions & 71 deletions src/forge/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from monarch.actor import current_rank, current_size, endpoint
from torch import Tensor
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torchstore.state_dict_utils import DELIM
from torchtitan.config.job_config import (
ActivationCheckpoint,
Checkpoint,
Expand Down Expand Up @@ -114,8 +113,6 @@ class RLTrainer(ForgeActor):
state_dict_key: str = "model_state_dict"
use_dcp: bool = True
dcp_path: str = "forge_dcp_tmp"
vllm_tp_DEPRECATED: int = 1 # noqa: N815
use_vllm_builtin_load: bool = True

def __post_init__(self):
"""Initializes config types and env variables.
Expand Down Expand Up @@ -159,8 +156,6 @@ def __post_init__(self):
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
}
os.environ.update(env)

# compile loss
logger.info("Compiling loss")
self.loss = torch.compile(self.loss)

Expand All @@ -172,9 +167,7 @@ async def setup(self):
"loss",
"state_dict_key",
"use_dcp",
"use_vllm_builtin_load",
"dcp_path",
"vllm_tp_DEPRECATED",
}:
engine_config.pop(key) # Not part of job config
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
Expand Down Expand Up @@ -306,76 +299,12 @@ async def train_step(
t.stop()
return loss

@endpoint
async def push_weights_DEPRECATED( # noqa: N802
self, policy_version: int, vllm_tp_DEPRECATED: int = 1
) -> None: # noqa: N802
"""[Deprecated] This method pushes weights to torchstore in the vllm format,
which is buggy and not scalable to other models.
Deprecated in favor of push_weights."""
return await self._push_weights_DEPRECATED(policy_version, vllm_tp_DEPRECATED)

async def _push_weights_DEPRECATED( # noqa: N802
self, policy_version: int, vllm_tp_DEPRECATED: int
) -> None: # noqa: N802
# Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now.
# TODO:
# 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
# May need to replicate the same in this code path.
# 2. Unify CheckpointManager and TorchStore weights save control path.
if "model" not in self.engine.checkpointer.states:
raise RuntimeError("Model state not found in checkpointer state")

sd = self.engine.checkpointer.states["model"].state_dict()
flattened_state_dict, _ = flatten_state_dict(sd)

if self.engine.checkpointer.sd_adapter is None:
raise RuntimeError(
"Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
)
hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)

# TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
vllm_ready_hf_sd = _qwen3_hf_to_vllm(
sd=hf_state_dict,
num_layers=self.engine.model_args.n_layers,
vllm_tp=vllm_tp_DEPRECATED,
)

key = f"{self.state_dict_key}{DELIM}{policy_version}"
if self.use_dcp:
# TODO - DCP should probably be being saved to NFS explicitly?
# Right now it will only save everything locally
storage_writer = torch.distributed.checkpoint.FileSystemWriter(
key, single_file_per_rank=False, thread_count=8
)
metadata = dcp.save(
storage_writer=storage_writer, state_dict=vllm_ready_hf_sd
)
await ts.put(key, metadata)

# Delete old weight versions if they exist
if self.rank == 0:
cleanup_old_weight_versions(
state_dict_key=self.state_dict_key,
delim=DELIM,
current_policy_version=policy_version,
)
else:
await ts.put_state_dict(vllm_ready_hf_sd, key)

@endpoint
async def push_weights(self, policy_version: int) -> None:
"""Push weights to torchstore in HF format."""
t = Tracer("rl_trainer_perf/push_weights", timer="gpu", track_memory=True)
t.start()
logger.info(f"Pushing weights for policy version {policy_version}")
if not self.use_vllm_builtin_load:
result = await self._push_weights_DEPRECATED(
policy_version, self.vllm_tp_DEPRECATED
)
t.step("push_weights_DEPRECATED")
return result

start_time = time.perf_counter()
if "model" not in self.engine.checkpointer.states:
Expand Down
2 changes: 0 additions & 2 deletions tests/sandbox/toy_rl/sumdigits-tp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ policy:
max_tokens: ${max_res_tokens}
temperature: 1.0
top_p: 1.0
use_vllm_builtin_load: true

# Trainer configuration
trainer:
model_name: ${model}
learning_rate: 1e-5
use_vllm_builtin_load: true

# Reference model configuration
ref_model:
Expand Down
33 changes: 0 additions & 33 deletions tests/sandbox/toy_rl/sumdigits.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import asyncio
import random
import time
import uuid
from dataclasses import dataclass
from typing import Any
Expand All @@ -19,7 +18,6 @@
from forge.actors._torchstore_utils import get_param_key
from forge.actors.policy import Policy
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import _qwen3_hf_to_vllm
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
Expand All @@ -31,7 +29,6 @@
from monarch.actor import endpoint
from omegaconf import DictConfig

from torchstore.state_dict_utils import DELIM
from transformers import AutoModelForCausalLM
from vllm.transformers_utils.tokenizer import get_tokenizer

Expand Down Expand Up @@ -255,7 +252,6 @@ class Trainer(ForgeActor):
learning_rate: float = 1e-5
device: torch.device | None = None
state_dict_key: str = "model_state_dict"
use_vllm_builtin_load: bool = True

def __post_init__(self):
super().__init__()
Expand Down Expand Up @@ -341,38 +337,9 @@ def train_step(self, episodes: list[Episode]) -> float:
self.optimizer.zero_grad(set_to_none=True)
return loss.item()

@endpoint
async def push_weights_DEPRECATED( # noqa: N802
self, policy_version: int, vllm_tp_DEPRECATED: int = 1
):
"""Update policy model weights with trainer's current weights.
This method pushes weights to torchstore in the vllm format,
which is buggy and not scalable to other models. Deprecated.
"""
return await self._push_weights_DEPRECATED(policy_version, vllm_tp_DEPRECATED)

async def _push_weights_DEPRECATED( # noqa: N802
self, version: int, vllm_tp_DEPRECATED: int
) -> None:
"""Update policy model weights with trainer's current weights."""
key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
new_sd = _qwen3_hf_to_vllm(
self.model.state_dict(),
num_layers=self.model.config.num_hidden_layers,
vllm_tp=vllm_tp_DEPRECATED,
)
start_time = time.time()
await ts.put_state_dict(new_sd, key)
end_time = time.time()
self.logger.debug(
f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
)

@endpoint
async def push_weights(self, policy_version: int) -> None:
"""Push weights to torchstore in HF format."""
if not self.use_vllm_builtin_load:
return await self._push_weights_DEPRECATED(policy_version)
hf_state_dict = self.model.state_dict()
for name, param in hf_state_dict.items():
key = get_param_key(policy_version, name)
Expand Down
Loading