diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml
index 7f183870d..de855d1cb 100644
--- a/apps/grpo/qwen3_8b.yaml
+++ b/apps/grpo/qwen3_8b.yaml
@@ -28,7 +28,6 @@ dataset:
 
 # Policy configuration
 policy:
-  use_vllm_builtin_load: true
   engine_config:
     model: ${model}
     tensor_parallel_size: 2
@@ -43,7 +42,6 @@ policy:
 # Trainer configuration
 trainer:
   use_dcp: true
-  use_vllm_builtin_load: true
   model:
     name: qwen3
     flavor: 8B
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 80fced0dd..30a1d39ee 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -10,7 +10,6 @@
 import logging
 import os
 import sys
-import time
 from collections.abc import Mapping
 from copy import copy
 from dataclasses import asdict, dataclass, field, fields
@@ -140,7 +139,6 @@ def create_vllm_config(self) -> VllmConfig:
 class Policy(PolicyInterface):
     engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig)
     sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig)
-    use_vllm_builtin_load: bool = True
     available_devices: str | None = None
     use_dcp: bool = True
     # Gets set up by setup
@@ -484,10 +482,7 @@ async def update_weights(self, policy_version: int):
         record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM)
         logger.debug(f"Starting weight update on {self.__class__.__name__}")
 
-        if self.use_vllm_builtin_load:
-            await self.policy_worker.update.call(version=policy_version)
-        else:
-            await self.policy_worker.update_DEPRECATED.call(version=policy_version)
+        await self.policy_worker.update.call(version=policy_version)
         self.policy_version = policy_version
 
         # After updating the weights, we need to reset the KV cache
@@ -504,18 +499,6 @@
     async def _reset_prefix_cache(self):
         self.scheduler.reset_prefix_cache()
 
-    @endpoint
-    async def update_weights_DEPRECATED(self, policy_version: int):  # noqa: N802
-        # TODO: If generating long sequences, this might be long and will block policy weight updates
-        curr_requests = [fut for _, fut in self.requests.values()]
-        if curr_requests:
-            logger.debug(f"Waiting for {len(curr_requests)} pending requests")
-            await asyncio.gather(*curr_requests)
-
-        await self.policy_worker.update_DEPRECATED.call(version=policy_version)
-        self.policy_version = policy_version
-        logger.info(f"Weight update completed (now v{self.policy_version})")
-
     @endpoint
     async def get_version(self) -> int:
         """Get the current policy version."""
@@ -636,19 +619,6 @@ async def _load_tensor_parallel_state_dict(
             current_tensor,
         )
 
-    @endpoint
-    async def update_DEPRECATED(self, version: int):  # noqa: N802
-        """Update model weights by reading state dict from torchstore.
-        Deprecated. This uses manual sharding logic which is buggy."""
-        key = f"{self.state_dict_key}{DELIM}{version}"
-        model = self.worker.model_runner.model
-        current_state_dict = model.state_dict()
-        start = time.perf_counter()
-        await self._load_tensor_parallel_state_dict(current_state_dict, version)
-        logger.info(
-            f"Loaded state dict from {key} in {time.perf_counter() - start} seconds"
-        )
-
     @endpoint
     async def update(self, version: int):
         """Update model weights by reading state dict from torchstore"""
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 7a399e4f8..f4199db71 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -21,7 +21,6 @@
 from monarch.actor import current_rank, current_size, endpoint
 from torch import Tensor
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
-from torchstore.state_dict_utils import DELIM
 from torchtitan.config.job_config import (
     ActivationCheckpoint,
     Checkpoint,
@@ -114,8 +113,6 @@ class RLTrainer(ForgeActor):
     state_dict_key: str = "model_state_dict"
     use_dcp: bool = True
     dcp_path: str = "forge_dcp_tmp"
-    vllm_tp_DEPRECATED: int = 1  # noqa: N815
-    use_vllm_builtin_load: bool = True
 
     def __post_init__(self):
         """Initializes config types and env variables.
@@ -159,8 +156,6 @@ def __post_init__(self):
             "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
         }
         os.environ.update(env)
-
-        # compile loss
         logger.info("Compiling loss")
         self.loss = torch.compile(self.loss)
 
@@ -172,9 +167,6 @@ async def setup(self):
            "loss",
            "state_dict_key",
            "use_dcp",
-           "use_vllm_builtin_load",
            "dcp_path",
-           "vllm_tp_DEPRECATED",
        }:
            engine_config.pop(key)  # Not part of job config
        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
@@ -306,76 +299,12 @@ async def train_step(
         t.stop()
         return loss
 
-    @endpoint
-    async def push_weights_DEPRECATED(  # noqa: N802
-        self, policy_version: int, vllm_tp_DEPRECATED: int = 1
-    ) -> None:  # noqa: N802
-        """[Deprecated] This method pushes weights to torchstore in the vllm format,
-        which is buggy and not scalable to other models.
-        Deprecated in favor of push_weights."""
-        return await self._push_weights_DEPRECATED(policy_version, vllm_tp_DEPRECATED)
-
-    async def _push_weights_DEPRECATED(  # noqa: N802
-        self, policy_version: int, vllm_tp_DEPRECATED: int
-    ) -> None:  # noqa: N802
-        # Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now.
-        # TODO:
-        # 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
-        #    May need to replicate the same in this code path.
-        # 2. Unify CheckpointManager and TorchStore weights save control path.
-        if "model" not in self.engine.checkpointer.states:
-            raise RuntimeError("Model state not found in checkpointer state")
-
-        sd = self.engine.checkpointer.states["model"].state_dict()
-        flattened_state_dict, _ = flatten_state_dict(sd)
-
-        if self.engine.checkpointer.sd_adapter is None:
-            raise RuntimeError(
-                "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
-            )
-        hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
-
-        # TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
-        vllm_ready_hf_sd = _qwen3_hf_to_vllm(
-            sd=hf_state_dict,
-            num_layers=self.engine.model_args.n_layers,
-            vllm_tp=vllm_tp_DEPRECATED,
-        )
-
-        key = f"{self.state_dict_key}{DELIM}{policy_version}"
-        if self.use_dcp:
-            # TODO - DCP should probably be being saved to NFS explicitly?
-            # Right now it will only save everything locally
-            storage_writer = torch.distributed.checkpoint.FileSystemWriter(
-                key, single_file_per_rank=False, thread_count=8
-            )
-            metadata = dcp.save(
-                storage_writer=storage_writer, state_dict=vllm_ready_hf_sd
-            )
-            await ts.put(key, metadata)
-
-            # Delete old weight versions if they exist
-            if self.rank == 0:
-                cleanup_old_weight_versions(
-                    state_dict_key=self.state_dict_key,
-                    delim=DELIM,
-                    current_policy_version=policy_version,
-                )
-        else:
-            await ts.put_state_dict(vllm_ready_hf_sd, key)
-
     @endpoint
     async def push_weights(self, policy_version: int) -> None:
         """Push weights to torchstore in HF format."""
         t = Tracer("rl_trainer_perf/push_weights", timer="gpu", track_memory=True)
         t.start()
         logger.info(f"Pushing weights for policy version {policy_version}")
-        if not self.use_vllm_builtin_load:
-            result = await self._push_weights_DEPRECATED(
-                policy_version, self.vllm_tp_DEPRECATED
-            )
-            t.step("push_weights_DEPRECATED")
-            return result
 
         start_time = time.perf_counter()
         if "model" not in self.engine.checkpointer.states:
diff --git a/tests/sandbox/toy_rl/sumdigits-tp.yaml b/tests/sandbox/toy_rl/sumdigits-tp.yaml
index 87f58d5ea..f859b1e7c 100644
--- a/tests/sandbox/toy_rl/sumdigits-tp.yaml
+++ b/tests/sandbox/toy_rl/sumdigits-tp.yaml
@@ -23,13 +23,11 @@ policy:
     max_tokens: ${max_res_tokens}
     temperature: 1.0
     top_p: 1.0
-  use_vllm_builtin_load: true
 
# Trainer configuration
trainer:
  model_name: ${model}
  learning_rate: 1e-5
-  use_vllm_builtin_load: true
 
# Reference model configuration
ref_model:
diff --git a/tests/sandbox/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py
index 14b5f6ebe..f6e66a141 100644
--- a/tests/sandbox/toy_rl/sumdigits.py
+++ b/tests/sandbox/toy_rl/sumdigits.py
@@ -8,7 +8,6 @@
 import asyncio
 import random
-import time
 import uuid
 from dataclasses import dataclass
 from typing import Any
 
@@ -19,7 +18,6 @@
 from forge.actors._torchstore_utils import get_param_key
 from forge.actors.policy import Policy
 from forge.actors.replay_buffer import ReplayBuffer
-from forge.actors.trainer import _qwen3_hf_to_vllm
 from forge.cli.config import parse
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import shutdown
@@ -31,7 +29,6 @@
 
 from monarch.actor import endpoint
 from omegaconf import DictConfig
-from torchstore.state_dict_utils import DELIM
 from transformers import AutoModelForCausalLM
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
@@ -255,7 +252,6 @@ class Trainer(ForgeActor):
     learning_rate: float = 1e-5
     device: torch.device | None = None
     state_dict_key: str = "model_state_dict"
-    use_vllm_builtin_load: bool = True
 
     def __post_init__(self):
         super().__init__()
@@ -341,38 +337,9 @@ def train_step(self, episodes: list[Episode]) -> float:
         self.optimizer.zero_grad(set_to_none=True)
         return loss.item()
 
-    @endpoint
-    async def push_weights_DEPRECATED(  # noqa: N802
-        self, policy_version: int, vllm_tp_DEPRECATED: int = 1
-    ):
-        """Update policy model weights with trainer's current weights.
-        This method pushes weights to torchstore in the vllm format,
-        which is buggy and not scalable to other models. Deprecated.
-        """
-        return await self._push_weights_DEPRECATED(policy_version, vllm_tp_DEPRECATED)
-
-    async def _push_weights_DEPRECATED(  # noqa: N802
-        self, version: int, vllm_tp_DEPRECATED: int
-    ) -> None:
-        """Update policy model weights with trainer's current weights."""
-        key = f"{self.state_dict_key}{DELIM}{version}"  # Use version as unique id
-        new_sd = _qwen3_hf_to_vllm(
-            self.model.state_dict(),
-            num_layers=self.model.config.num_hidden_layers,
-            vllm_tp=vllm_tp_DEPRECATED,
-        )
-        start_time = time.time()
-        await ts.put_state_dict(new_sd, key)
-        end_time = time.time()
-        self.logger.debug(
-            f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
-        )
-
     @endpoint
     async def push_weights(self, policy_version: int) -> None:
         """Push weights to torchstore in HF format."""
-        if not self.use_vllm_builtin_load:
-            return await self._push_weights_DEPRECATED(policy_version)
         hf_state_dict = self.model.state_dict()
         for name, param in hf_state_dict.items():
             key = get_param_key(policy_version, name)