14 changes: 12 additions & 2 deletions apps/grpo/main.py
@@ -16,6 +16,7 @@
import torch.nn.functional as F
import torchstore as ts
from datasets import load_dataset
+import torch.distributed.checkpoint as dcp
from forge.actors.policy import Policy
from forge.actors.reference_model import ReferenceModel # noqa: F401
from forge.actors.replay_buffer import ReplayBuffer
@@ -127,6 +128,8 @@ class Trainer(ForgeActor):
    device: torch.device | None = None
    state_dict_key: str = "model_state_dict"
    dp_rank: int = 0 # TODO: support data parallelism, hard code it for now
+    use_dcp: bool = True
+

    @endpoint
    async def setup(self):
@@ -188,7 +191,12 @@ async def push_weights(self, version: int):
        key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
        new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28)
        start_time = time.time()
-        await ts.put_state_dict(new_sd, key)
+        if self.use_dcp:
+            metadata = dcp.save(checkpoint_id=key, state_dict=new_sd)
+            await ts.put(key, metadata)
+        else:
+            await ts.put_state_dict(new_sd, key)
+
        end_time = time.time()
        self.logger.debug(
            f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
@@ -322,7 +330,9 @@ async def main(cfg: DictConfig):
    )

    # ---- Setup services ---- #
-    await ts.initialize()
+    await ts.initialize(
+        strategy=ts.ControllerStorageVolumes()
+    )
    (
        dataloader,
        policy,
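Note on the save path (not part of the diff): with use_dcp enabled, the trainer no longer copies every tensor into torchstore. dcp.save writes the state dict to a checkpoint identified by the versioned key and returns a small Metadata object, and only that metadata is published with ts.put. Below is a minimal single-process sketch of that round trip; the toy state dict and the local checkpoint path are illustrative stand-ins, not the PR's actual keys.

```python
import torch
import torch.distributed.checkpoint as dcp

# Toy state dict standing in for the converted model state dict.
state_dict = {"layer.weight": torch.randn(4, 4), "layer.bias": torch.zeros(4)}

# Illustrative checkpoint location; in the PR the id is the versioned
# torchstore key f"{state_dict_key}{DELIM}{version}".
checkpoint_id = "/tmp/model_state_dict_v0"

# dcp.save shards the tensors to storage and returns a Metadata object that
# describes each entry (size, dtype, chunk layout) without the tensor data;
# that lightweight object is what the trainer stores via ts.put above.
metadata = dcp.save(state_dict=state_dict, checkpoint_id=checkpoint_id)
print(list(metadata.state_dict_metadata.keys()))
```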
21 changes: 18 additions & 3 deletions src/forge/actors/policy.py
@@ -18,6 +18,7 @@
import torchstore as ts
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore.state_dict_utils import DELIM
+import torch.distributed.checkpoint as dcp
from vllm.config import VllmConfig

from vllm.engine.arg_utils import EngineArgs
@@ -122,6 +123,7 @@ class Policy(PolicyInterface):
    lora_request: LoRARequest | None = None
    tokenization_kwargs: dict = field(default_factory=dict)
    policy_worker: "PolicyWorker" = None
+    use_dcp: bool = True

    def __post_init__(self):
        self._run_task: asyncio.Task | None = None
@@ -416,14 +418,27 @@ async def _load_tensor_parallel_state_dict(
            self.vllm_config.parallel_config.tensor_parallel_size, self.rank
        )

+        checkpoint_id = f"{self.state_dict_key}{DELIM}{version}"
+        dcp_metadata = None
+        if self.use_dcp:
+            dcp_metadata = await ts.get(checkpoint_id)
+
        for param_name in current_state_dict.keys():
            current_tensor = current_state_dict[param_name]

            # Load the full tensor from torchstore
            # TODO: only get the part of the tensor that is needed
+            if self.use_dcp:
+                tensor_meta = dcp_metadata.state_dict_metadata[param_name]
+                stored_tensor = torch.empty(size=tensor_meta.size, dtype=tensor_meta.properties.dtype)
+                dcp.load(
+                    checkpoint_id=checkpoint_id,
+                    state_dict={param_name: stored_tensor}
+                )
+            else:
+                stored_tensor = await ts.get(
+                    f"{checkpoint_id}{DELIM}{param_name}"
+                )
            sharding.load_from_source_to_target(
                param_name,
                stored_tensor,
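Note on the load path (not part of the diff): the worker only needs the Metadata it fetched from torchstore to allocate an empty destination tensor of the right size and dtype; dcp.load then materializes just that one entry from the checkpoint in place. A minimal sketch continuing the toy example from the save-path note; the parameter name and local path are illustrative, and the metadata is re-read from the checkpoint here instead of from torchstore.

```python
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemReader

checkpoint_id = "/tmp/model_state_dict_v0"  # same illustrative path as the save sketch

# Stand-in for the Metadata object the worker would fetch with ts.get(checkpoint_id).
metadata = FileSystemReader(checkpoint_id).read_metadata()

param_name = "layer.weight"
tensor_meta = metadata.state_dict_metadata[param_name]

# Allocate an empty destination tensor from the recorded size/dtype, then let
# dcp.load fill it in place; a single-entry state_dict loads only this key.
buf = torch.empty(size=tensor_meta.size, dtype=tensor_meta.properties.dtype)
dcp.load(state_dict={param_name: buf}, checkpoint_id=checkpoint_id)
print(buf.shape, buf.dtype)
```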