diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 7f31c26c9..ae31af5e1 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -249,7 +249,7 @@ async def main(cfg: DictConfig):
     )
 
     # ---- Setup services ---- #
-    await ts.initialize()
+    await ts.initialize(strategy=ts.ControllerStorageVolumes())
     (
         dataloader,
         policy,
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 070c00798..28febad25 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -15,6 +15,7 @@
 from dataclasses import asdict, dataclass, field, fields
 
 import torch
+import torch.distributed.checkpoint as dcp
 import torchstore as ts
 from monarch.actor import current_rank, endpoint, ProcMesh
 from torchstore.state_dict_utils import DELIM
@@ -399,6 +400,7 @@ async def stop(self):
 class PolicyWorker(ForgeActor):
     vllm_config: VllmConfig
     state_dict_key: str = "model_state_dict"
+    use_dcp: bool = True
 
     @endpoint
     async def setup(self):
@@ -420,14 +422,26 @@ async def _load_tensor_parallel_state_dict(
             self.vllm_config.parallel_config.tensor_parallel_size, self.rank
         )
 
+        checkpoint_id = f"{self.state_dict_key}{DELIM}{version}"
+        dcp_metadata = None
+        if self.use_dcp:
+            dcp_metadata = await ts.get(checkpoint_id)
+
         for param_name in current_state_dict.keys():
             current_tensor = current_state_dict[param_name]
 
             # Load the full tensor from torchstore
             # TODO: only get the part of the tensor that is needed
-            stored_tensor = await ts.get(
-                f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}"
-            )
+            if self.use_dcp:
+                tensor_meta = dcp_metadata.state_dict_metadata[param_name]
+                stored_tensor = torch.empty(
+                    size=tensor_meta.size, dtype=tensor_meta.properties.dtype
+                )
+                dcp.load(
+                    checkpoint_id=checkpoint_id, state_dict={param_name: stored_tensor}
+                )
+            else:
+                stored_tensor = await ts.get(f"{checkpoint_id}{DELIM}{param_name}")
             sharding.load_from_source_to_target(
                 param_name,
                 stored_tensor,
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 0368fac4d..5dffbac1e 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -12,6 +12,7 @@
 from typing import Callable
 
 import torch
+import torch.distributed.checkpoint as dcp
 import torchstore as ts
 from monarch.actor import current_rank, current_size, endpoint
@@ -53,6 +54,7 @@ class RLTrainer(ForgeActor):
     comm: Comm = field(default_factory=Comm)
     loss: Callable = lambda logits, **targets: logits
     state_dict_key: str = "model_state_dict"
+    use_dcp: bool = True
 
     def __post_init__(self):
         """Initializes config types and env variables.
@@ -95,7 +97,7 @@ async def setup(self):
         # TODO: update ForgeEngine to not use ForgeJobConfig
         engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
-        for key in {"loss", "state_dict_key"}:
+        for key in {"loss", "state_dict_key", "use_dcp"}:
             engine_config.pop(key)  # Not part of job config
         self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
         self.engine.checkpointer.load(step=self.current_step)
@@ -207,6 +209,7 @@ async def push_weights(self, policy_version: int) -> None:
         # 2. Unify CheckpointManager and TorchStore weights save control path.
         if "model" not in self.engine.checkpointer.states:
             raise RuntimeError("Model state not found in checkpointer state")
+
         sd = self.engine.checkpointer.states["model"].state_dict()
         flattened_state_dict, _ = flatten_state_dict(sd)
         if self.engine.checkpointer.sd_adapter is None:
@@ -216,10 +219,16 @@ async def push_weights(self, policy_version: int) -> None:
         hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
         # TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
         vllm_ready_hf_sd = _qwen3_hf_to_vllm(sd=hf_state_dict, num_layers=28)
+
         key = f"{self.state_dict_key}{DELIM}{policy_version}"
         start_time = time.time()
-        await ts.put_state_dict(state_dict=vllm_ready_hf_sd, key=key)
+        if self.use_dcp:
+            metadata = dcp.save(checkpoint_id=key, state_dict=vllm_ready_hf_sd)
+            await ts.put(key, metadata)
+        else:
+            await ts.put_state_dict(vllm_ready_hf_sd, key)
         end_time = time.time()
+
         self.logger.debug(
             f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
         )