Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,9 @@ class PolicyWorker(ForgeActor):
state_dict_key: str = "model_state_dict"
use_dcp: bool = True

def __post_init__(self):
    # Dataclass-generated __init__ does not call the base class initializer,
    # so chain to it explicitly here. NOTE(review): assumes the ForgeActor
    # base class has initialization that must run — confirm against the
    # ForgeActor definition (not visible in this diff).
    super().__init__()
@endpoint
async def setup(self):
# TODO: remove ["gpus"] when monarch implements a flat rank
Expand Down Expand Up @@ -501,9 +504,11 @@ async def update(self, version: int):
key = f"{self.state_dict_key}{DELIM}{version}"
model = self.worker.model_runner.model
current_state_dict = model.state_dict()
start = time.time()
start = time.perf_counter()
await self._load_tensor_parallel_state_dict(current_state_dict, version)
logger.debug(f"Loaded state dict from {key} in {time.time() - start} seconds")
logger.info(
f"Loaded state dict from {key} in {time.perf_counter() - start} seconds"
)

@endpoint
async def setup_kv_cache(self):
Expand Down
25 changes: 18 additions & 7 deletions src/forge/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def __post_init__(self):

"""
super().__init__()

if self.use_dcp:
# DCP specific optimization
torch.serialization.set_crc32_options(False)

# Instantiate dict fields
for f in fields(self):
attr = getattr(self, f.name)
Expand Down Expand Up @@ -249,6 +254,7 @@ def train_step(
@endpoint
async def push_weights(self, policy_version: int) -> None:
# Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now.
start_time = time.perf_counter()
# TODO:
# 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
# May need to replicate the same in this code path.
Expand All @@ -267,14 +273,17 @@ async def push_weights(self, policy_version: int) -> None:
vllm_ready_hf_sd = _qwen3_hf_to_vllm(
sd=hf_state_dict, num_layers=self.engine.model_args.n_layers
)

conversion_time = time.perf_counter()
key = f"{self.state_dict_key}{DELIM}{policy_version}"
start_time = time.time()
if self.use_dcp:

# TODO - DCP should probably be being saved to NFS explicitly?
# Right now it will only save everything locally
metadata = dcp.save(checkpoint_id=key, state_dict=vllm_ready_hf_sd)
storage_writer = torch.distributed.checkpoint.FileSystemWriter(
key, single_file_per_rank=False, thread_count=8
)
metadata = dcp.save(
storage_writer=storage_writer, state_dict=vllm_ready_hf_sd
)
await ts.put(key, metadata)

# Delete old weight versions if they exist
Expand All @@ -286,9 +295,11 @@ async def push_weights(self, policy_version: int) -> None:
)
else:
await ts.put_state_dict(vllm_ready_hf_sd, key)
end_time = time.time()

logger.debug(f"Pushed weights to {key} in {end_time - start_time:.2f} seconds")
end_time = time.perf_counter()
logger.info(
f"Completed weights push to {key} in {end_time - start_time:.2f} seconds "
f"(to_vllm: {conversion_time - start_time:.2f}s, transport time: {end_time - conversion_time:.2f}s)"
)

@endpoint
async def cleanup(self) -> None:
Expand Down
Loading