Skip to content
30 changes: 5 additions & 25 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import asyncio

import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable
Expand All @@ -17,10 +16,7 @@
import torch.nn.functional as F
import torchstore as ts
from datasets import load_dataset
from forge.actors._torchstore_utils import (
get_dcp_whole_state_dict_key,
get_param_prefix,
)

from forge.actors.policy import Policy
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
Expand All @@ -33,6 +29,7 @@
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.util.ops import compute_logprobs
from forge.util.weight_sync import drop_weights
from monarch.actor import endpoint
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer
Expand Down Expand Up @@ -289,23 +286,6 @@ async def pad_token(self):
return self._tokenizer.pad_token_id


async def drop_weights(version: int):
    """Remove the torchstore entries published for one weight version.

    Lists every key under the parameter prefix for ``version``. If the
    whole-state-dict DCP key is among them, fetches that entry and calls
    ``drop()`` on it. Then deletes each listed key from torchstore one at
    a time. A message is printed before the work starts and another, with
    the elapsed wall-clock time, after it finishes.

    Args:
        version: Weight version whose stored entries should be removed.
    """
    print(f"Dropping weights @ version {version}")
    began = time.perf_counter()

    stored_keys = await ts.keys(get_param_prefix(version))

    # TODO: once we have something like `get_meta()` in torchstore, we can just
    # query the type of the object instead of relying on keys.
    whole_state_key = get_dcp_whole_state_dict_key(version)
    if whole_state_key in stored_keys:
        # The handle must be fetched before its key is deleted below.
        handle = await ts.get(whole_state_key)
        handle.drop()

    # Deletions are awaited sequentially, in the order torchstore listed them.
    for stored_key in stored_keys:
        await ts.delete(stored_key)

    took = time.perf_counter() - began
    print(f"Dropped weights @ version {version}, took {took:.2f} seconds")


async def main(cfg: DictConfig):
"""Main GRPO training loop with rollout and training processes."""
group_size = cfg.group_size
Expand Down Expand Up @@ -455,9 +435,9 @@ async def continuous_training():
await policy.update_weights.fanout(training_step)
t.step("update_weights")

# if training_step >= 2:
# await drop_weights(training_step - 1)
# t.step("drop_weights")
if training_step >= 2:
await drop_weights(training_step - 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be truly async or does it have to be blocking like this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be truly async if we just create a task and don't await it.

t.step("drop_weights")

t.stop()
restart_tracer = True
Expand Down
31 changes: 31 additions & 0 deletions src/forge/util/weight_sync.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not put this in _torchstore_utils?

It seems that's where all the other weight sync information is. If that's not supposed to be the end location, I'd almost rather all _torchstore_utils be moved out into a weight_sync.py file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather all _torchstore_utils be moved out into a weight_sync.py

Maybe I will do this. Let me know what you think. Do you want me to make it `_weight_sync.py` instead?

Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import time

import torchstore as ts

from forge.actors._torchstore_utils import (
get_dcp_whole_state_dict_key,
get_param_prefix,
)


async def drop_weights(version: int):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we're here, I'd prefer a name like "delete_old_weights"

And instead of version, something like "oldest_version_to_keep"

print(f"Dropping weights @ version {version}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this

start_time = time.perf_counter()
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
# TODO: once we have something like `get_meta()` in torchstore, we can just
# query the type of the object instead of relying on keys.
dcp_key = get_dcp_whole_state_dict_key(version)
if dcp_key in matching_keys:
dcp_handle = await ts.get(dcp_key)
dcp_handle.drop()
for key in matching_keys:
await ts.delete(key)
elapsed = time.perf_counter() - start_time
print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Log instead of print

Loading