Skip to content
31 changes: 6 additions & 25 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

import asyncio
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable
Expand All @@ -16,10 +15,8 @@
import torch.nn.functional as F
import torchstore as ts
from datasets import load_dataset
from forge.actors._torchstore_utils import (
get_dcp_whole_state_dict_key,
get_param_prefix,
)
from forge.actors._torchstore_utils import WeightCleaner

from forge.actors.policy import Policy
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
Expand Down Expand Up @@ -290,23 +287,6 @@ async def pad_token(self):
return self._tokenizer.pad_token_id


async def drop_weights(version: int):
    """Delete every torchstore entry stored under the given policy version.

    All keys matching the version's parameter prefix are removed. When the
    version was stored as a whole DCP state dict, its handle is fetched and
    dropped first, before the keys themselves are deleted.

    Args:
        version (int): Policy version whose weights should be removed.
    """
    print(f"Dropping weights @ version {version}")
    start_time = time.perf_counter()
    prefix = get_param_prefix(version)
    matching_keys = await ts.keys(prefix)
    # TODO: once we have something like `get_meta()` in torchstore, we can just
    # query the type of the object instead of relying on keys.
    dcp_key = get_dcp_whole_state_dict_key(version)
    if dcp_key in matching_keys:
        dcp_handle = await ts.get(dcp_key)
        # `drop` is a plain synchronous call; run it in a worker thread so it
        # cannot stall the event loop (and every other coroutine) while it runs.
        await asyncio.to_thread(dcp_handle.drop)
    for key in matching_keys:
        await ts.delete(key)
    elapsed = time.perf_counter() - start_time
    print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")


async def main(cfg: DictConfig):
"""Main GRPO training loop with rollout and training processes."""
group_size = cfg.group_size
Expand Down Expand Up @@ -435,6 +415,7 @@ async def continuous_rollouts():
async def continuous_training():
training_step = 0
restart_tracer = True # Flag to control when to restart tracer
weight_cleaner = WeightCleaner()

while True:
# Restart tracer when needed (initial start or after completing a training step)
Expand Down Expand Up @@ -463,9 +444,9 @@ async def continuous_training():
await policy.update_weights.fanout(training_step)
t.step("update_weights")

if training_step >= 2:
await drop_weights(training_step - 1)
t.step("drop_weights")
# weight cleanup is non-blocking, the task is executed in the background
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are you confirming this all finishes before adding more weights?

Also in typical async form this step would just be an async method that you'd await now or later. Why is there an extra method called "wait"?

Copy link
Contributor Author

@casteryh casteryh Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are you confirming this all finishes before adding more weights?

I thought the point is that you don't, if you just need the weights to be eventually deleted. When you call step(), the task is scheduled in the background and everything else proceeds as normal.

Also in typical async form this step would just be an async method that you'd await now or later.

Yes, but in that case, if we want to schedule the task in the background and not await it, we would need to manage the task in main.py, which we supposedly don't want to do. This essentially hides the task-scheduling logic in the WeightCleaner class.

Why is there an extra method called "wait"?

If you want to make sure all the scheduled tasks have indeed completed (i.e. all old weights are deleted, as you mentioned earlier), you can await weight_cleaner.wait(). Presumably this can be named better; let me know what you think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also in typical async form this step would just be an async method that you'd await now or later. Why is there an extra method called "wait"?

My understanding is that, in typical async code, if you don't explicitly create a task, the coroutine will never get executed unless you await it. I think we could also always schedule the task and return a join handle.

weight_cleaner.step(training_step)
t.step("weight_cleaner step")

t.stop()
restart_tracer = True
Expand Down
55 changes: 55 additions & 0 deletions src/forge/actors/_torchstore_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import asyncio
import logging
import shutil
import time
from dataclasses import dataclass

import torch
import torch.distributed.checkpoint as dcp

import torchstore as ts
from torch.distributed.checkpoint.metadata import Metadata as DcpMeta

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -69,3 +73,54 @@ def extract_param_name(key: str) -> str:

def get_dcp_whole_state_dict_key(policy_version: int) -> str:
    """Return the store key marking a whole-state-dict DCP entry for a version.

    The key is the version's parameter prefix, followed by the key delimiter
    and the whole-state tag.
    """
    version_prefix = get_param_prefix(policy_version)
    return f"{version_prefix}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}"


class WeightCleaner:
    """Manages asynchronous cleanup of model weights across policy versions.

    ``step`` schedules one background asyncio task per policy version that has
    not been scheduled yet and returns immediately; ``wait`` blocks until the
    scheduled deletions have finished. The highest version already scheduled
    is remembered so that no version is scheduled twice.
    """

    def __init__(self):
        # we need to keep the tasks around to make sure they are not garbage
        # collected before they finish
        self._tasks = []
        # highest version whose deletion has been scheduled; -1 means "none yet"
        self._last_deleted_version = -1

    def _remove_done_tasks(self):
        """Drop finished tasks from the task list, reporting any that failed.

        A finished task's exception is retrieved and logged before the task is
        discarded, so that a failed background deletion is reported here with
        context instead of being lost together with the task reference.
        """
        still_pending = []
        for task in self._tasks:
            if not task.done():
                still_pending.append(task)
                continue
            # `exception()` raises CancelledError on a cancelled task, so rule
            # cancellation out before asking for the exception.
            if task.cancelled():
                continue
            error = task.exception()
            if error is not None:
                # same logger object as the module-level `logger`
                logging.getLogger(__name__).error(
                    "Weight cleanup task failed", exc_info=error
                )
        self._tasks = still_pending

    def step(self, delete_up_to_version: int):
        """Schedule deletion of weights for all versions up to the specified version.

        The deletions run as background tasks; this method does not wait for
        them. It must be called while an event loop is running, because it
        uses ``asyncio.create_task``.

        Args:
            delete_up_to_version (int): The highest policy version to delete (inclusive).
                All versions from last_deleted_version + 1 to this version will be deleted.
        """
        self._remove_done_tasks()
        if delete_up_to_version <= self._last_deleted_version:
            # everything up to this version has already been scheduled
            return
        for version in range(self._last_deleted_version + 1, delete_up_to_version + 1):
            self._tasks.append(asyncio.create_task(drop_weights(version)))
        # records what has been *scheduled*; the tasks may still be running
        self._last_deleted_version = delete_up_to_version

    async def wait(self):
        """Wait for all scheduled deletion tasks to complete.

        If any task fails, the first exception is propagated to the caller
        (``asyncio.gather`` semantics).
        """
        await asyncio.gather(*self._tasks)
        # everything gathered above finished successfully; release the references
        self._remove_done_tasks()


async def drop_weights(version: int):
    """Remove all torchstore entries belonging to one policy version.

    Keys are looked up by the version's parameter prefix. If the version's
    whole-state-dict DCP key is among them, its handle is fetched and dropped
    in a worker thread before the keys are deleted one by one. The elapsed
    time is logged at INFO level.

    Args:
        version (int): Policy version whose weights are deleted.
    """
    began = time.perf_counter()
    keys_for_version = await ts.keys(get_param_prefix(version))
    # TODO: once we have something like `get_meta()` in torchstore, we can just
    # query the type of the object instead of relying on keys.
    whole_state_key = get_dcp_whole_state_dict_key(version)
    if whole_state_key in keys_for_version:
        handle = await ts.get(whole_state_key)
        # `drop` is synchronous, so it is pushed off the event loop
        await asyncio.to_thread(handle.drop)
    for stale_key in keys_for_version:
        await ts.delete(stale_key)
    elapsed = time.perf_counter() - began
    logger.info(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
1 change: 1 addition & 0 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
import torch.distributed.checkpoint as dcp
import torchstore as ts

from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore.state_dict_utils import DELIM
from vllm.config import VllmConfig
Expand Down
151 changes: 149 additions & 2 deletions tests/unit_tests/test_torchstore_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,27 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import os
import tempfile
import unittest

from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

import torch
import torch.distributed.checkpoint as dcp
from forge.actors._torchstore_utils import DcpHandle
from forge.actors._torchstore_utils import DcpHandle, WeightCleaner

ignore_torch_distributed_unitialized_warning = pytest.mark.filterwarnings(
r"ignore:.*torch.distributed"
)

ignore_coroutine_not_awaited = pytest.mark.filterwarnings(
"ignore:.*coroutine.*was never awaited.*"
)


class TestDcpHandle(unittest.TestCase):
def _prepare_dcp_handle(self, test_dir: str) -> tuple[str, DcpHandle]:
Expand Down Expand Up @@ -59,3 +64,145 @@ def test_dcp_handle_drop_sets_none_for_manifold(self):
self.assertEqual(handle.checkpoint_id, None)
self.assertEqual(handle.metadata, None)
self.assertEqual(handle.param_names, None)


class TestWeightCleaner(unittest.IsolatedAsyncioTestCase):
    """Unit tests for WeightCleaner's task pruning and version scheduling."""

    def setUp(self):
        """Give every test its own fresh WeightCleaner."""
        self.cleaner = WeightCleaner()

    @staticmethod
    def _task_stub(finished):
        """Return a mock task whose ``done()`` reports ``finished``."""
        stub = MagicMock()
        stub.done.return_value = finished
        return stub

    @staticmethod
    def _resolved_future():
        """Return a future that has already completed with ``None``."""
        resolved = asyncio.Future()
        resolved.set_result(None)
        return resolved

    @ignore_coroutine_not_awaited
    def test_remove_done_tasks_with_completed_tasks(self):
        """Finished tasks are pruned while the unfinished one is kept."""
        first_done = self._task_stub(True)
        second_done = self._task_stub(True)
        still_running = self._task_stub(False)

        self.cleaner._tasks = [first_done, still_running, second_done]
        self.cleaner._remove_done_tasks()

        # Exactly one task survives, and it is the unfinished one
        self.assertEqual(self.cleaner._tasks, [still_running])

    @ignore_coroutine_not_awaited
    def test_remove_done_tasks_with_all_pending(self):
        """Nothing is pruned when no task has finished."""
        running_a = self._task_stub(False)
        running_b = self._task_stub(False)

        self.cleaner._tasks = [running_a, running_b]
        self.cleaner._remove_done_tasks()

        # Both tasks are still tracked, in their original order
        self.assertEqual(self.cleaner._tasks, [running_a, running_b])

    @ignore_coroutine_not_awaited
    @pytest.mark.asyncio
    @patch("forge.actors._torchstore_utils.drop_weights", new_callable=AsyncMock)
    @patch("asyncio.create_task")
    async def test_step_no_cleanup_needed_equal_version(
        self, mock_create_task, mock_drop_weights
    ):
        """A repeated step() for the same version schedules nothing new."""
        mock_create_task.return_value = self._resolved_future()

        # Advance the cleaner to version 5
        self.cleaner.step(delete_up_to_version=5)
        mock_drop_weights.assert_called()
        mock_create_task.assert_called()

        # Forget the calls made so far
        mock_drop_weights.reset_mock()
        mock_create_task.reset_mock()

        # Version 5 has been handled already, so this must be a no-op
        self.cleaner.step(delete_up_to_version=5)

        mock_create_task.assert_not_called()
        mock_drop_weights.assert_not_called()

    @ignore_coroutine_not_awaited
    @pytest.mark.asyncio
    @patch("forge.actors._torchstore_utils.drop_weights", new_callable=AsyncMock)
    @patch("asyncio.create_task")
    async def test_step_no_cleanup_needed_lower_version(
        self, mock_create_task, mock_drop_weights
    ):
        """A step() for an older version than already handled schedules nothing."""
        mock_create_task.return_value = self._resolved_future()

        # Advance the cleaner to version 10
        self.cleaner.step(delete_up_to_version=10)

        # Forget the calls made so far
        mock_drop_weights.reset_mock()
        mock_create_task.reset_mock()

        # Version 5 lies below what was handled, so this must be a no-op
        self.cleaner.step(delete_up_to_version=5)

        mock_create_task.assert_not_called()
        mock_drop_weights.assert_not_called()

    @ignore_coroutine_not_awaited
    @pytest.mark.asyncio
    @patch("forge.actors._torchstore_utils.drop_weights", new_callable=AsyncMock)
    @patch("asyncio.create_task")
    async def test_step_creates_tasks_initial_call(
        self, mock_create_task, mock_drop_weights
    ):
        """The first step() schedules one task per version from 0 upward."""
        mock_create_task.return_value = self._resolved_future()

        # From the initial state, ask for everything up to version 5
        self.cleaner.step(delete_up_to_version=5)

        # Versions 0, 1, 2, 3, 4, 5 -> six tasks
        self.assertEqual(mock_create_task.call_count, 6)

    @ignore_coroutine_not_awaited
    @pytest.mark.asyncio
    @patch("forge.actors._torchstore_utils.drop_weights", new_callable=AsyncMock)
    @patch("asyncio.create_task")
    async def test_step_creates_only_new_version_tasks(
        self, mock_create_task, mock_drop_weights
    ):
        """A later step() schedules tasks only for versions not yet handled."""
        mock_create_task.return_value = self._resolved_future()

        # Handle versions up to 3 first
        self.cleaner.step(delete_up_to_version=3)

        # Count only the calls made from here on
        mock_create_task.reset_mock()

        # Now extend the range up to version 7
        self.cleaner.step(delete_up_to_version=7)

        # Only versions 4, 5, 6, 7 are new -> four tasks
        self.assertEqual(mock_create_task.call_count, 4)
Loading