diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 1dbef0b76..11317b986 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -7,7 +7,6 @@
 # Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
 
 import asyncio
-import time
 import uuid
 from dataclasses import dataclass
 from typing import Any, Callable
@@ -16,10 +15,6 @@
 import torch.nn.functional as F
 import torchstore as ts
 from datasets import load_dataset
-from forge.actors._torchstore_utils import (
-    get_dcp_whole_state_dict_key,
-    get_param_prefix,
-)
 from forge.actors.generator import Generator
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
@@ -34,6 +29,7 @@
 from forge.observability.perf_tracker import Tracer
 
 from forge.types import LauncherConfig, ProvisionerConfig
+from forge.util._torchstore import WeightCleaner
 from forge.util.ops import compute_logprobs
 from monarch.actor import endpoint
 from omegaconf import DictConfig
@@ -272,23 +268,6 @@ async def pad_token(self):
         return self._tokenizer.pad_token_id
 
 
-async def drop_weights(version: int):
-    print(f"Dropping weights @ version {version}")
-    start_time = time.perf_counter()
-    prefix = get_param_prefix(version)
-    matching_keys = await ts.keys(prefix)
-    # TODO: once we have something like `get_meta()` in torchstore, we can just
-    # query the type of the object instead of relying on keys.
-    dcp_key = get_dcp_whole_state_dict_key(version)
-    if dcp_key in matching_keys:
-        dcp_handle = await ts.get(dcp_key)
-        dcp_handle.drop()
-    for key in matching_keys:
-        await ts.delete(key)
-    elapsed = time.perf_counter() - start_time
-    print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
-
-
 async def main(cfg: DictConfig):
     """Main GRPO training loop with rollout and training processes."""
     group_size = cfg.group_size
@@ -422,6 +401,7 @@ async def continuous_rollouts():
     async def continuous_training():
         training_step = 0
         restart_tracer = True  # Flag to control when to restart tracer
+        weight_cleaner = WeightCleaner()
 
         while max_steps == -1 or training_step < max_steps:
             # Restart tracer when needed (initial start or after completing a training step)
@@ -450,9 +430,9 @@ async def continuous_training():
                 await policy.update_weights.fanout(training_step)
                 t.step("update_weights")
 
-                if training_step >= 2:
-                    await drop_weights(training_step - 1)
-                    t.step("drop_weights")
+                # Weight cleanup is non-blocking; the deletion tasks run in the background.
+                weight_cleaner.step(training_step)
+                t.step("weight_cleaner step")
 
                 t.stop()
                 restart_tracer = True
diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py
index e04bed5a8..00986f6a5 100644
--- a/src/forge/actors/generator.py
+++ b/src/forge/actors/generator.py
@@ -16,6 +16,7 @@
 
 import torch
 import torchstore as ts
+
 from monarch.actor import current_rank, endpoint, ProcMesh
 from vllm.config import VllmConfig
 
@@ -40,14 +41,6 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase
 
-from forge.actors._torchstore_utils import (
-    extract_param_name,
-    get_dcp_whole_state_dict_key,
-    get_param_key,
-    get_param_prefix,
-    load_tensor_from_dcp,
-)
-
 from forge.controller import (
     ForgeActor,
     get_proc_mesh,
@@ -61,6 +54,14 @@
 from forge.observability.perf_tracker import Tracer
 from forge.types import ProcessConfig
 
+from forge.util._torchstore import (
+    extract_param_name,
+    get_dcp_whole_state_dict_key,
+    get_param_key,
+    get_param_prefix,
+    load_tensor_from_dcp,
+)
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 71049bc52..9535946fc 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -37,18 +37,18 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
-from forge.actors._torchstore_utils import (
-    DcpHandle,
-    get_dcp_whole_state_dict_key,
-    get_param_key,
-)
-
 from forge.controller import ForgeActor
 from forge.data.utils import batch_to_device
 from forge.env import TORCHSTORE_USE_RDMA
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 
+from forge.util._torchstore import (
+    DcpHandle,
+    get_dcp_whole_state_dict_key,
+    get_param_key,
+)
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
diff --git a/src/forge/util/__init__.py b/src/forge/util/__init__.py
index 5fb03b0f9..c4fcd4ca0 100644
--- a/src/forge/util/__init__.py
+++ b/src/forge/util/__init__.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from . import _torchstore
 from .distributed import get_world_size_and_rank
 from .logging import get_logger, log_once, log_rank_zero
 from .metric_logging import get_metric_logger
@@ -13,4 +14,5 @@
     "log_once",
     "log_rank_zero",
     "get_metric_logger",
+    "_torchstore",
 ]
diff --git a/src/forge/actors/_torchstore_utils.py b/src/forge/util/_torchstore.py
similarity index 51%
rename from src/forge/actors/_torchstore_utils.py
rename to src/forge/util/_torchstore.py
index bc0d55c3b..c77eeb54b 100644
--- a/src/forge/actors/_torchstore_utils.py
+++ b/src/forge/util/_torchstore.py
@@ -3,12 +3,16 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import asyncio
 import logging
 import shutil
+import time
 from dataclasses import dataclass
 
 import torch
 import torch.distributed.checkpoint as dcp
+
+import torchstore as ts
 from torch.distributed.checkpoint.metadata import Metadata as DcpMeta
 
 logger = logging.getLogger(__name__)
@@ -69,3 +73,54 @@ def extract_param_name(key: str) -> str:
 
 def get_dcp_whole_state_dict_key(policy_version: int) -> str:
     return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}"
+
+
+class WeightCleaner:
+    """Manages asynchronous cleanup of model weights across different policy versions.
+
+    This class handles the deletion of old model weights by maintaining a list of
+    cleanup tasks and tracking the last deleted version to avoid redundant operations.
+    """
+
+    def __init__(self):
+        # Keep references to the tasks so they are not garbage collected.
+        self._tasks = []
+        self._last_deleted_version = -1
+
+    def _remove_done_tasks(self):
+        """Remove completed tasks from the task list to prevent memory leaks."""
+        self._tasks = [task for task in self._tasks if not task.done()]
+
+    def step(self, delete_up_to_version: int):
+        """Schedule deletion of weights for all versions up to the specified version.
+
+        Args:
+            delete_up_to_version (int): The highest policy version to delete (inclusive).
+                All versions from last_deleted_version + 1 to this version will be deleted.
+ """ + self._remove_done_tasks() + if delete_up_to_version <= self._last_deleted_version: + return + for version in range(self._last_deleted_version + 1, delete_up_to_version + 1): + self._tasks.append(asyncio.create_task(drop_weights(version))) + self._last_deleted_version = delete_up_to_version + + async def wait(self): + """Wait for all scheduled deletion tasks to complete.""" + await asyncio.gather(*self._tasks) + + +async def drop_weights(version: int): + start_time = time.perf_counter() + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + # TODO: once we have something like `get_meta()` in torchstore, we can just + # query the type of the object instead of relying on keys. + dcp_key = get_dcp_whole_state_dict_key(version) + if dcp_key in matching_keys: + dcp_handle = await ts.get(dcp_key) + await asyncio.to_thread(dcp_handle.drop) + for key in matching_keys: + await ts.delete(key) + elapsed = time.perf_counter() - start_time + logger.info(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds") diff --git a/tests/sandbox/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py index 0668f8eca..54781fab5 100644 --- a/tests/sandbox/toy_rl/sumdigits.py +++ b/tests/sandbox/toy_rl/sumdigits.py @@ -15,7 +15,6 @@ import torch import torch.nn.functional as F import torchstore as ts -from forge.actors._torchstore_utils import get_param_key from forge.actors.generator import Generator from forge.actors.replay_buffer import ReplayBuffer from forge.cli.config import parse @@ -25,6 +24,7 @@ from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import record_metric, Reduce +from forge.util._torchstore import get_param_key from forge.util.ops import selective_log_softmax from monarch.actor import endpoint from omegaconf import DictConfig diff --git a/tests/unit_tests/test_torchstore_utils.py b/tests/unit_tests/test_torchstore_utils.py index 6a2e23fbf..0af642cfe 100644 --- a/tests/unit_tests/test_torchstore_utils.py +++ b/tests/unit_tests/test_torchstore_utils.py @@ -3,23 +3,29 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+"""Test for src/forge/util/_torchstore.py.""" +import asyncio import os import tempfile import unittest from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch import pytest - import torch import torch.distributed.checkpoint as dcp -from forge.actors._torchstore_utils import DcpHandle +from forge.util._torchstore import DcpHandle, WeightCleaner ignore_torch_distributed_unitialized_warning = pytest.mark.filterwarnings( r"ignore:.*torch.distributed" ) +ignore_coroutine_not_awaited = pytest.mark.filterwarnings( + "ignore:.*coroutine.*was never awaited.*" +) + class TestDcpHandle(unittest.TestCase): def _prepare_dcp_handle(self, test_dir: str) -> tuple[str, DcpHandle]: @@ -59,3 +65,145 @@ def test_dcp_handle_drop_sets_none_for_manifold(self): self.assertEqual(handle.checkpoint_id, None) self.assertEqual(handle.metadata, None) self.assertEqual(handle.param_names, None) + + +class TestWeightCleaner(unittest.IsolatedAsyncioTestCase): + """Test suite for WeightCleaner class.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + self.cleaner = WeightCleaner() + + @ignore_coroutine_not_awaited + def test_remove_done_tasks_with_completed_tasks(self): + """Test _remove_done_tasks removes completed tasks.""" + # Create mock tasks - some done, some not + done_task1 = MagicMock() + done_task1.done.return_value = True + + done_task2 = MagicMock() + done_task2.done.return_value = True + + pending_task = MagicMock() + pending_task.done.return_value = False + + self.cleaner._tasks = [done_task1, pending_task, done_task2] + self.cleaner._remove_done_tasks() + + # Only the pending task should remain + self.assertEqual(len(self.cleaner._tasks), 1) + self.assertEqual(self.cleaner._tasks[0], pending_task) + + @ignore_coroutine_not_awaited + def test_remove_done_tasks_with_all_pending(self): + """Test _remove_done_tasks with all tasks pending.""" + pending_task1 = MagicMock() + pending_task1.done.return_value = False + + pending_task2 = MagicMock() + pending_task2.done.return_value = False + + self.cleaner._tasks = [pending_task1, pending_task2] + self.cleaner._remove_done_tasks() + + # All tasks should remain + self.assertEqual(len(self.cleaner._tasks), 2) + self.assertEqual(self.cleaner._tasks, [pending_task1, pending_task2]) + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_no_cleanup_needed_equal_version( + self, mock_create_task, mock_drop_weights + ): + """Test step method when delete_up_to_version equals last_deleted_version.""" + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # Step it to 5 first + self.cleaner.step(delete_up_to_version=5) + mock_drop_weights.assert_called() + mock_create_task.assert_called() + + # Reset mock state to clear call history + mock_drop_weights.reset_mock() + mock_create_task.reset_mock() + + # Request deletion up to version 5 (already deleted) + self.cleaner.step(delete_up_to_version=5) + + # No tasks should be created + mock_create_task.assert_not_called() + mock_drop_weights.assert_not_called() + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_no_cleanup_needed_lower_version( + self, mock_create_task, mock_drop_weights + ): + """Test step method when delete_up_to_version is lower than last_deleted_version.""" + 
+        future = asyncio.Future()
+        future.set_result(None)
+        mock_create_task.return_value = future
+
+        # Step it to 10 first
+        self.cleaner.step(delete_up_to_version=10)
+
+        # Reset mock state to clear call history
+        mock_drop_weights.reset_mock()
+        mock_create_task.reset_mock()
+
+        # Request deletion up to version 5 (lower than already deleted)
+        self.cleaner.step(delete_up_to_version=5)
+
+        # No tasks should be created
+        mock_create_task.assert_not_called()
+        mock_drop_weights.assert_not_called()
+
+    @ignore_coroutine_not_awaited
+    @pytest.mark.asyncio
+    @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock)
+    @patch("asyncio.create_task")
+    async def test_step_creates_tasks_initial_call(
+        self, mock_create_task, mock_drop_weights
+    ):
+        """Test step method creates tasks for entire version range."""
+
+        future = asyncio.Future()
+        future.set_result(None)
+        mock_create_task.return_value = future
+
+        # Request deletion up to version 5 from initial state
+        self.cleaner.step(delete_up_to_version=5)
+
+        # Should create 6 tasks (versions 0 through 5)
+        self.assertEqual(mock_create_task.call_count, 6)
+
+    @ignore_coroutine_not_awaited
+    @pytest.mark.asyncio
+    @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock)
+    @patch("asyncio.create_task")
+    async def test_step_creates_only_new_version_tasks(
+        self, mock_create_task, mock_drop_weights
+    ):
+        """Test step method only creates tasks for versions not yet deleted."""
+
+        future = asyncio.Future()
+        future.set_result(None)
+        mock_create_task.return_value = future
+
+        # First deletion up to version 3
+        self.cleaner.step(delete_up_to_version=3)
+
+        # Reset mock to track only new calls
+        mock_create_task.reset_mock()
+
+        # Second deletion up to version 7
+        self.cleaner.step(delete_up_to_version=7)
+
+        # Should only create tasks for versions 4, 5, 6, 7
+        self.assertEqual(mock_create_task.call_count, 4)
diff --git a/tests/unit_tests/util/test_torchstore_utils.py b/tests/unit_tests/util/test_torchstore_utils.py
new file mode 100644
index 000000000..67ffcfe41
--- /dev/null
+++ b/tests/unit_tests/util/test_torchstore_utils.py
@@ -0,0 +1,208 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import asyncio
+import os
+import tempfile
+import unittest
+
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+import torch
+import torch.distributed.checkpoint as dcp
+from forge.util._torchstore import DcpHandle, WeightCleaner
+
+ignore_torch_distributed_unitialized_warning = pytest.mark.filterwarnings(
+    r"ignore:.*torch.distributed"
+)
+
+ignore_coroutine_not_awaited = pytest.mark.filterwarnings(
+    "ignore:.*coroutine.*was never awaited.*"
+)
+
+
+class TestDcpHandle(unittest.TestCase):
+    def _prepare_dcp_handle(self, test_dir: str) -> tuple[str, DcpHandle]:
+        """Returns path to checkpoint and DcpHandle."""
+        checkpoint_id = str(Path(test_dir) / "test_checkpoint_id")
+        state_dict = {"a": torch.rand(1, 1), "b": torch.rand(1, 1)}
+        metadata = dcp.save(checkpoint_id=checkpoint_id, state_dict=state_dict)
+        assert os.path.exists(checkpoint_id), "failed to set up test checkpoint"
+        return checkpoint_id, DcpHandle(
+            checkpoint_id=checkpoint_id,
+            metadata=metadata,
+            param_names=list(state_dict.keys()),
+        )
+
+    @ignore_torch_distributed_unitialized_warning
+    def test_dcp_handle_drop_deletes(self):
+        with tempfile.TemporaryDirectory() as test_dir:
+            ckpt_path, handle = self._prepare_dcp_handle(test_dir)
+            handle.drop()
+            self.assertFalse(os.path.exists(ckpt_path))
+
+    @ignore_torch_distributed_unitialized_warning
+    def test_dcp_handle_drop_sets_none(self):
+        with tempfile.TemporaryDirectory() as test_dir:
+            _, handle = self._prepare_dcp_handle(test_dir)
+            handle.drop()
+            self.assertEqual(handle.checkpoint_id, None)
+            self.assertEqual(handle.metadata, None)
+            self.assertEqual(handle.param_names, None)
+
+    @ignore_torch_distributed_unitialized_warning
+    def test_dcp_handle_drop_sets_none_for_manifold(self):
+        with tempfile.TemporaryDirectory() as test_dir:
+            _, handle = self._prepare_dcp_handle(test_dir)
+            handle.checkpoint_id = "manifold://test_bucket/tree/test_path"
+            handle.drop()
+            self.assertEqual(handle.checkpoint_id, None)
+            self.assertEqual(handle.metadata, None)
+            self.assertEqual(handle.param_names, None)
+
+
+class TestWeightCleaner(unittest.IsolatedAsyncioTestCase):
+    """Test suite for WeightCleaner class."""
+
+    def setUp(self):
+        """Set up test fixtures before each test method."""
+        self.cleaner = WeightCleaner()
+
+    @ignore_coroutine_not_awaited
+    def test_remove_done_tasks_with_completed_tasks(self):
+        """Test _remove_done_tasks removes completed tasks."""
+        # Create mock tasks - some done, some not
+        done_task1 = MagicMock()
+        done_task1.done.return_value = True
+
+        done_task2 = MagicMock()
+        done_task2.done.return_value = True
+
+        pending_task = MagicMock()
+        pending_task.done.return_value = False
+
+        self.cleaner._tasks = [done_task1, pending_task, done_task2]
+        self.cleaner._remove_done_tasks()
+
+        # Only the pending task should remain
+        self.assertEqual(len(self.cleaner._tasks), 1)
+        self.assertEqual(self.cleaner._tasks[0], pending_task)
+
+    @ignore_coroutine_not_awaited
+    def test_remove_done_tasks_with_all_pending(self):
+        """Test _remove_done_tasks with all tasks pending."""
+        pending_task1 = MagicMock()
+        pending_task1.done.return_value = False
+
+        pending_task2 = MagicMock()
+        pending_task2.done.return_value = False
+
+        self.cleaner._tasks = [pending_task1, pending_task2]
+        self.cleaner._remove_done_tasks()
+
+        # All tasks should remain
+        self.assertEqual(len(self.cleaner._tasks), 2)
+        self.assertEqual(self.cleaner._tasks, [pending_task1, pending_task2])
+
+    @ignore_coroutine_not_awaited
+    @pytest.mark.asyncio
@patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_no_cleanup_needed_equal_version( + self, mock_create_task, mock_drop_weights + ): + """Test step method when delete_up_to_version equals last_deleted_version.""" + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # Step it to 5 first + self.cleaner.step(delete_up_to_version=5) + mock_drop_weights.assert_called() + mock_create_task.assert_called() + + # Reset mock state to clear call history + mock_drop_weights.reset_mock() + mock_create_task.reset_mock() + + # Request deletion up to version 5 (already deleted) + self.cleaner.step(delete_up_to_version=5) + + # No tasks should be created + mock_create_task.assert_not_called() + mock_drop_weights.assert_not_called() + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_no_cleanup_needed_lower_version( + self, mock_create_task, mock_drop_weights + ): + """Test step method when delete_up_to_version is lower than last_deleted_version.""" + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # Step it to 10 first + self.cleaner.step(delete_up_to_version=10) + + # Reset mock state to clear call history + mock_drop_weights.reset_mock() + mock_create_task.reset_mock() + + # Request deletion up to version 5 (lower than already deleted) + self.cleaner.step(delete_up_to_version=5) + + # No tasks should be created + mock_create_task.assert_not_called() + mock_drop_weights.assert_not_called() + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_creates_tasks_initial_call( + self, mock_create_task, mock_drop_weights + ): + """Test step method creates tasks for entire version range.""" + + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # Request deletion up to version 5 from initial state + self.cleaner.step(delete_up_to_version=5) + + # Should create 6 tasks (versions 0 through 5) + self.assertEqual(mock_create_task.call_count, 6) + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_creates_only_new_version_tasks( + self, mock_create_task, mock_drop_weights + ): + """Test step method only creates tasks for versions not yet deleted.""" + + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # First deletion up to version 3 + self.cleaner.step(delete_up_to_version=3) + + # Reset mock to track only new calls + mock_create_task.reset_mock() + + # Second deletion up to version 7 + self.cleaner.step(delete_up_to_version=7) + + # Should only create tasks for versions 4, 5, 6, 7 + self.assertEqual(mock_create_task.call_count, 4)