30 changes: 5 additions & 25 deletions apps/grpo/main.py
@@ -7,7 +7,6 @@
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

import asyncio
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable
@@ -16,10 +15,6 @@
import torch.nn.functional as F
import torchstore as ts
from datasets import load_dataset
from forge.actors._torchstore_utils import (
get_dcp_whole_state_dict_key,
get_param_prefix,
)
from forge.actors.generator import Generator
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
@@ -34,6 +29,7 @@
from forge.observability.perf_tracker import Tracer

from forge.types import LauncherConfig, ProvisionerConfig
from forge.util._torchstore import WeightCleaner
from forge.util.ops import compute_logprobs
from monarch.actor import endpoint
from omegaconf import DictConfig
@@ -272,23 +268,6 @@ async def pad_token(self):
return self._tokenizer.pad_token_id


async def drop_weights(version: int):
print(f"Dropping weights @ version {version}")
start_time = time.perf_counter()
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
# TODO: once we have something like `get_meta()` in torchstore, we can just
# query the type of the object instead of relying on keys.
dcp_key = get_dcp_whole_state_dict_key(version)
if dcp_key in matching_keys:
dcp_handle = await ts.get(dcp_key)
dcp_handle.drop()
for key in matching_keys:
await ts.delete(key)
elapsed = time.perf_counter() - start_time
print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")


async def main(cfg: DictConfig):
"""Main GRPO training loop with rollout and training processes."""
group_size = cfg.group_size
@@ -422,6 +401,7 @@ async def continuous_rollouts():
async def continuous_training():
training_step = 0
restart_tracer = True # Flag to control when to restart tracer
weight_cleaner = WeightCleaner()

while max_steps == -1 or training_step < max_steps:
# Restart tracer when needed (initial start or after completing a training step)
@@ -450,9 +430,9 @@ async def continuous_training():
await policy.update_weights.fanout(training_step)
t.step("update_weights")

if training_step >= 2:
await drop_weights(training_step - 1)
t.step("drop_weights")
# weight cleanup is non-blocking, the task is executed in the background
Member:

How are you confirming this all finishes before adding more weights?

Also in typical async form this step would just be an async method that you'd await now or later. Why is there an extra method called "wait"?

casteryh (Contributor Author), Oct 7, 2025:

> How are you confirming this all finishes before adding more weights?

I thought the point is that you don't, if you just need the weights to be eventually deleted. When you call step(), the task is scheduled in the background and everything else proceeds as normal.

> Also in typical async form this step would just be an async method that you'd await now or later.

Yes, but in that case, if we want to schedule the task in the background and not await it, we need to manage the task in main.py, which we supposedly don't want to do. This approach essentially hides the task-scheduling logic inside the WeightCleaner class.

> Why is there an extra method called "wait"?

If you want to make sure all the scheduled tasks are indeed completed (i.e. all old weights are deleted, as you mentioned earlier), you can await weight_cleaner.wait(). Presumably this could be named better; let me know what you think.

casteryh (Contributor Author):

> Also in typical async form this step would just be an async method that you'd await now or later. Why is there an extra method called "wait"?

My understanding is that in typical async code, if you don't explicitly create a task, a coroutine never gets executed unless you await it. I think we can also always schedule the task and return a join handle.

weight_cleaner.step(training_step)
t.step("weight_cleaner step")

t.stop()
restart_tracer = True
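As a side note on the thread above: a minimal, self-contained sketch of the "schedule in the background, optionally join later" pattern being discussed. The coroutine body is a stand-in, not the PR's actual torchstore cleanup, and the helper names are illustrative only.

import asyncio


async def drop_weights(version: int) -> None:
    # Stand-in for the real cleanup coroutine in this PR.
    await asyncio.sleep(0)
    print(f"dropped weights @ version {version}")


def schedule_drop(version: int) -> "asyncio.Task[None]":
    # Schedule the cleanup in the background and return a join handle;
    # callers may await it later or ignore it if eventual deletion is enough.
    return asyncio.create_task(drop_weights(version))


async def train() -> None:
    handles = [schedule_drop(step) for step in range(3)]  # fire-and-forget
    # ... training would continue here without blocking on cleanup ...
    await asyncio.gather(*handles)  # optional join, analogous to weight_cleaner.wait()


asyncio.run(train())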
17 changes: 9 additions & 8 deletions src/forge/actors/generator.py
@@ -16,6 +16,7 @@

import torch
import torchstore as ts

from monarch.actor import current_rank, endpoint, ProcMesh
from vllm.config import VllmConfig

@@ -40,14 +41,6 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.actors._torchstore_utils import (
extract_param_name,
get_dcp_whole_state_dict_key,
get_param_key,
get_param_prefix,
load_tensor_from_dcp,
)

from forge.controller import (
ForgeActor,
get_proc_mesh,
@@ -61,6 +54,14 @@
from forge.observability.perf_tracker import Tracer
from forge.types import ProcessConfig

from forge.util._torchstore import (
extract_param_name,
get_dcp_whole_state_dict_key,
get_param_key,
get_param_prefix,
load_tensor_from_dcp,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

12 changes: 6 additions & 6 deletions src/forge/actors/trainer.py
@@ -37,18 +37,18 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.actors._torchstore_utils import (
DcpHandle,
get_dcp_whole_state_dict_key,
get_param_key,
)

from forge.controller import ForgeActor
from forge.data.utils import batch_to_device
from forge.env import TORCHSTORE_USE_RDMA
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer

from forge.util._torchstore import (
DcpHandle,
get_dcp_whole_state_dict_key,
get_param_key,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

2 changes: 2 additions & 0 deletions src/forge/util/__init__.py
@@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from . import _torchstore
from .distributed import get_world_size_and_rank
from .logging import get_logger, log_once, log_rank_zero
from .metric_logging import get_metric_logger
@@ -13,4 +14,5 @@
"log_once",
"log_rank_zero",
"get_metric_logger",
"_torchstore",
]
Member:

Move to core app/

casteryh (Contributor Author):

Everything, or only the WeightCleaner? Both the trainer and the policy need functions from this file.

src/forge/util/_torchstore.py
@@ -3,12 +3,16 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import asyncio
import logging
import shutil
import time
from dataclasses import dataclass

import torch
import torch.distributed.checkpoint as dcp

import torchstore as ts
from torch.distributed.checkpoint.metadata import Metadata as DcpMeta

logger = logging.getLogger(__name__)
@@ -69,3 +73,54 @@ def extract_param_name(key: str) -> str:

def get_dcp_whole_state_dict_key(policy_version: int) -> str:
return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}"


class WeightCleaner:
"""Manages asynchronous cleanup of model weights across different policy versions.

This class handles the deletion of old model weights by maintaining a list of
cleanup tasks and tracking the last deleted version to avoid redundant operations.
"""

def __init__(self):
# we need to keep the task around to make sure it's not garbage collected
self._tasks = []
self._last_deleted_version = -1

def _remove_done_tasks(self):
"""Remove completed tasks from the task list to prevent memory leaks."""
self._tasks = [task for task in self._tasks if not task.done()]

def step(self, delete_up_to_version: int):
"""Schedule deletion of weights for all versions up to the specified version.

Args:
delete_up_to_version (int): The highest policy version to delete (inclusive).
All versions from last_deleted_version + 1 to this version will be deleted.
"""
self._remove_done_tasks()
if delete_up_to_version <= self._last_deleted_version:
return
for version in range(self._last_deleted_version + 1, delete_up_to_version + 1):
self._tasks.append(asyncio.create_task(drop_weights(version)))
self._last_deleted_version = delete_up_to_version

async def wait(self):
"""Wait for all scheduled deletion tasks to complete."""
await asyncio.gather(*self._tasks)


async def drop_weights(version: int):
start_time = time.perf_counter()
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
# TODO: once we have something like `get_meta()` in torchstore, we can just
Contributor:

we do have a 'get_meta' in torchstore (although it's lacking a proper object).

# query the type of the object instead of relying on keys.
dcp_key = get_dcp_whole_state_dict_key(version)
Contributor:

Is this implementation specific to DCP?

Do we need something like ts.delete(r"key.*") support in torchstore?

casteryh (Contributor Author):

> Is this implementation specific to DCP?

Yes

> Do we need something like ts.delete(r"key.*") support in torchstore?

It would be good if we could have it, although simply calling delete on every key is not currently a bottleneck.

if dcp_key in matching_keys:
dcp_handle = await ts.get(dcp_key)
await asyncio.to_thread(dcp_handle.drop)
for key in matching_keys:
await ts.delete(key)
elapsed = time.perf_counter() - start_time
logger.info(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
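Picking up the ts.delete(r"key.*") question from the review thread above: a rough sketch of a prefix-delete helper built only on the ts.keys()/ts.delete() calls this file already uses. The helper name is illustrative and not an existing torchstore API.

import torchstore as ts


async def delete_by_prefix(prefix: str) -> int:
    """Delete every torchstore entry whose key starts with `prefix`.

    Illustrative sketch; a native pattern/regex delete in torchstore
    would make this loop unnecessary.
    """
    matching_keys = await ts.keys(prefix)
    for key in matching_keys:
        await ts.delete(key)
    return len(matching_keys)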
2 changes: 1 addition & 1 deletion tests/sandbox/toy_rl/sumdigits.py
@@ -15,7 +15,6 @@
import torch
import torch.nn.functional as F
import torchstore as ts
from forge.actors._torchstore_utils import get_param_key
from forge.actors.generator import Generator
from forge.actors.replay_buffer import ReplayBuffer
from forge.cli.config import parse
@@ -25,6 +24,7 @@
from forge.observability.metric_actors import get_or_create_metric_logger

from forge.observability.metrics import record_metric, Reduce
from forge.util._torchstore import get_param_key
from forge.util.ops import selective_log_softmax
from monarch.actor import endpoint
from omegaconf import DictConfig