Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ async def main(cfg: DictConfig):
)

# ---- Setup services ---- #
- await ts.initialize()
+ await ts.initialize(strategy=ts.ControllerStorageVolumes())
(
dataloader,
policy,
Expand Down
20 changes: 17 additions & 3 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dataclasses import asdict, dataclass, field, fields

import torch
+ import torch.distributed.checkpoint as dcp
import torchstore as ts
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore.state_dict_utils import DELIM
Expand Down Expand Up @@ -399,6 +400,7 @@ async def stop(self):
class PolicyWorker(ForgeActor):
vllm_config: VllmConfig
state_dict_key: str = "model_state_dict"
+ use_dcp: bool = True

@endpoint
async def setup(self):
Expand All @@ -420,14 +422,26 @@ async def _load_tensor_parallel_state_dict(
self.vllm_config.parallel_config.tensor_parallel_size, self.rank
)

+ checkpoint_id = f"{self.state_dict_key}{DELIM}{version}"
+ dcp_metadata = None
+ if self.use_dcp:
+ dcp_metadata = await ts.get(checkpoint_id)
+
for param_name in current_state_dict.keys():
current_tensor = current_state_dict[param_name]

# Load the full tensor from torchstore
# TODO: only get the part of the tensor that is needed
- stored_tensor = await ts.get(
- f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}"
- )
+ if self.use_dcp:
+ tensor_meta = dcp_metadata.state_dict_metadata[param_name]
+ stored_tensor = torch.empty(
+ size=tensor_meta.size, dtype=tensor_meta.properties.dtype
+ )
+ dcp.load(
+ checkpoint_id=checkpoint_id, state_dict={param_name: stored_tensor}
+ )
+ else:
+ stored_tensor = await ts.get(f"{checkpoint_id}{DELIM}{param_name}")
sharding.load_from_source_to_target(
param_name,
stored_tensor,
Expand Down
13 changes: 11 additions & 2 deletions src/forge/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Callable

import torch
+ import torch.distributed.checkpoint as dcp
import torchstore as ts

from monarch.actor import current_rank, current_size, endpoint
Expand Down Expand Up @@ -53,6 +54,7 @@ class RLTrainer(ForgeActor):
comm: Comm = field(default_factory=Comm)
loss: Callable = lambda logits, **targets: logits
state_dict_key: str = "model_state_dict"
+ use_dcp: bool = True

def __post_init__(self):
"""Initializes config types and env variables.
Expand Down Expand Up @@ -95,7 +97,7 @@ def __post_init__(self):
async def setup(self):
# TODO: update ForgeEngine to not use ForgeJobConfig
engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
- for key in {"loss", "state_dict_key"}:
+ for key in {"loss", "state_dict_key", "use_dcp"}:
engine_config.pop(key) # Not part of job config
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
self.engine.checkpointer.load(step=self.current_step)
Expand Down Expand Up @@ -207,6 +209,7 @@ async def push_weights(self, policy_version: int) -> None:
# 2. Unify CheckpointManager and TorchStore weights save control path.
if "model" not in self.engine.checkpointer.states:
raise RuntimeError("Model state not found in checkpointer state")
+
sd = self.engine.checkpointer.states["model"].state_dict()
flattened_state_dict, _ = flatten_state_dict(sd)
if self.engine.checkpointer.sd_adapter is None:
Expand All @@ -216,10 +219,16 @@ async def push_weights(self, policy_version: int) -> None:
hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
# TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
vllm_ready_hf_sd = _qwen3_hf_to_vllm(sd=hf_state_dict, num_layers=28)
+
key = f"{self.state_dict_key}{DELIM}{policy_version}"
start_time = time.time()
- await ts.put_state_dict(state_dict=vllm_ready_hf_sd, key=key)
+ if self.use_dcp:
+ metadata = dcp.save(checkpoint_id=key, state_dict=vllm_ready_hf_sd)
+ await ts.put(key, metadata)
+ else:
+ await ts.put_state_dict(vllm_ready_hf_sd, key)
end_time = time.time()
+
self.logger.debug(
f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
)
Expand Down
Loading