From 7643064ac41254e28740f5f3120e69ad797dfbf5 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Thu, 2 Oct 2025 11:10:28 -0700 Subject: [PATCH 1/8] add back drop_weights, factoring it out --- apps/grpo/main.py | 30 +++++------------------------- src/forge/util/weight_sync.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 25 deletions(-) create mode 100644 src/forge/util/weight_sync.py diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 852989682..21c06aa63 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -8,7 +8,6 @@ import asyncio -import time import uuid from dataclasses import dataclass from typing import Any, Callable @@ -17,10 +16,7 @@ import torch.nn.functional as F import torchstore as ts from datasets import load_dataset -from forge.actors._torchstore_utils import ( - get_dcp_whole_state_dict_key, - get_param_prefix, -) + from forge.actors.policy import Policy from forge.actors.reference_model import ReferenceModel from forge.actors.replay_buffer import ReplayBuffer @@ -33,6 +29,7 @@ from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.util.ops import compute_logprobs +from forge.util.weight_sync import drop_weights from monarch.actor import endpoint from omegaconf import DictConfig from vllm.transformers_utils.tokenizer import get_tokenizer @@ -289,23 +286,6 @@ async def pad_token(self): return self._tokenizer.pad_token_id -async def drop_weights(version: int): - print(f"Dropping weights @ version {version}") - start_time = time.perf_counter() - prefix = get_param_prefix(version) - matching_keys = await ts.keys(prefix) - # TODO: once we have something like `get_meta()` in torchstore, we can just - # query the type of the object instead of relying on keys. - dcp_key = get_dcp_whole_state_dict_key(version) - if dcp_key in matching_keys: - dcp_handle = await ts.get(dcp_key) - dcp_handle.drop() - for key in matching_keys: - await ts.delete(key) - elapsed = time.perf_counter() - start_time - print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds") - - async def main(cfg: DictConfig): """Main GRPO training loop with rollout and training processes.""" group_size = cfg.group_size @@ -455,9 +435,9 @@ async def continuous_training(): await policy.update_weights.fanout(training_step) t.step("update_weights") - # if training_step >= 2: - # await drop_weights(training_step - 1) - # t.step("drop_weights") + if training_step >= 2: + await drop_weights(training_step - 1) + t.step("drop_weights") t.stop() restart_tracer = True diff --git a/src/forge/util/weight_sync.py b/src/forge/util/weight_sync.py new file mode 100644 index 000000000..cdf2d6e83 --- /dev/null +++ b/src/forge/util/weight_sync.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torchstore as ts + +from forge.actors._torchstore_utils import ( + get_dcp_whole_state_dict_key, + get_param_prefix, +) + + +async def drop_weights(version: int): + print(f"Dropping weights @ version {version}") + start_time = time.perf_counter() + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + # TODO: once we have something like `get_meta()` in torchstore, we can just + # query the type of the object instead of relying on keys. + dcp_key = get_dcp_whole_state_dict_key(version) + if dcp_key in matching_keys: + dcp_handle = await ts.get(dcp_key) + dcp_handle.drop() + for key in matching_keys: + await ts.delete(key) + elapsed = time.perf_counter() - start_time + print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds") From e66efee33e8e8a943b9dfde2ec8dc65a781836b2 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Thu, 2 Oct 2025 13:28:42 -0700 Subject: [PATCH 2/8] import time --- src/forge/util/weight_sync.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/forge/util/weight_sync.py b/src/forge/util/weight_sync.py index cdf2d6e83..771c5876b 100644 --- a/src/forge/util/weight_sync.py +++ b/src/forge/util/weight_sync.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import time + import torchstore as ts from forge.actors._torchstore_utils import ( From c8b2b422ddf35748a959c78e282332a51767ffdc Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 6 Oct 2025 15:36:26 -0700 Subject: [PATCH 3/8] refactor --- apps/grpo/main.py | 7 +- src/forge/actors/_torchstore_utils.py | 56 ++++++++ src/forge/actors/policy.py | 34 ++--- src/forge/util/weight_sync.py | 31 ----- tests/unit_tests/test_torchstore_utils.py | 151 +++++++++++++++++++++- 5 files changed, 225 insertions(+), 54 deletions(-) delete mode 100644 src/forge/util/weight_sync.py diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 21c06aa63..84b8fc6ee 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -16,6 +16,7 @@ import torch.nn.functional as F import torchstore as ts from datasets import load_dataset +from forge.actors._torchstore_utils import WeightCleaner from forge.actors.policy import Policy from forge.actors.reference_model import ReferenceModel @@ -29,7 +30,6 @@ from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer from forge.util.ops import compute_logprobs -from forge.util.weight_sync import drop_weights from monarch.actor import endpoint from omegaconf import DictConfig from vllm.transformers_utils.tokenizer import get_tokenizer @@ -407,6 +407,7 @@ async def continuous_rollouts(): async def continuous_training(): training_step = 0 restart_tracer = True # Flag to control when to restart tracer + weight_cleaner = WeightCleaner() while True: # Restart tracer when needed (initial start or after completing a training step) @@ -435,9 +436,7 @@ async def continuous_training(): await policy.update_weights.fanout(training_step) t.step("update_weights") - if training_step >= 2: - await drop_weights(training_step - 1) - t.step("drop_weights") + weight_cleaner.step(training_step - 1) t.stop() restart_tracer = True diff --git a/src/forge/actors/_torchstore_utils.py b/src/forge/actors/_torchstore_utils.py index bc0d55c3b..da074adb7 100644 --- a/src/forge/actors/_torchstore_utils.py +++ b/src/forge/actors/_torchstore_utils.py @@ -3,12 +3,16 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import asyncio import logging import shutil +import time from dataclasses import dataclass import torch import torch.distributed.checkpoint as dcp + +import torchstore as ts from torch.distributed.checkpoint.metadata import Metadata as DcpMeta logger = logging.getLogger(__name__) @@ -69,3 +73,55 @@ def extract_param_name(key: str) -> str: def get_dcp_whole_state_dict_key(policy_version: int) -> str: return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}" + + +class WeightCleaner: + """Manages asynchronous cleanup of model weights across different policy versions. + + This class handles the deletion of old model weights by maintaining a list of + cleanup tasks and tracking the last deleted version to avoid redundant operations. + """ + + def __init__(self): + """Initialize the WeightCleaner with empty task list and reset deletion tracking.""" + # we need to keep the task around to make sure it's not garbage collected + self._tasks = [] + self._last_deleted_version = -1 + + def _remove_done_tasks(self): + """Remove completed tasks from the task list to prevent memory leaks.""" + self._tasks = [task for task in self._tasks if not task.done()] + + def step(self, delete_up_to_version: int): + """Schedule deletion of weights for all versions up to the specified version. + + Args: + delete_up_to_version (int): The highest policy version to delete (inclusive). + All versions from last_deleted_version + 1 to this version will be deleted. + """ + self._remove_done_tasks() + if delete_up_to_version <= self._last_deleted_version: + return + for version in range(self._last_deleted_version + 1, delete_up_to_version + 1): + self._tasks.append(asyncio.create_task(drop_weights(version))) + self._last_deleted_version = delete_up_to_version + + async def wait(self): + """Wait for all scheduled deletion tasks to complete.""" + await asyncio.gather(*self._tasks) + + +async def drop_weights(version: int): + start_time = time.perf_counter() + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + # TODO: once we have something like `get_meta()` in torchstore, we can just + # query the type of the object instead of relying on keys. + dcp_key = get_dcp_whole_state_dict_key(version) + if dcp_key in matching_keys: + dcp_handle = await ts.get(dcp_key) + dcp_handle.drop() + for key in matching_keys: + await ts.delete(key) + elapsed = time.perf_counter() - start_time + logger.info(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds") diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 464674f2c..dcf4b1bd9 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -19,6 +19,23 @@ import torch.distributed.checkpoint as dcp import torchstore as ts +from forge.actors._torchstore_utils import ( + extract_param_name, + get_dcp_whole_state_dict_key, + get_param_key, + get_param_prefix, + load_tensor_from_dcp, +) + +from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh +from forge.data.sharding import VLLMSharding +from forge.data_models.completion import Completion +from forge.data_models.prompt import to_prompt +from forge.interfaces import Policy as PolicyInterface +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer +from forge.types import ProcessConfig + from monarch.actor import current_rank, endpoint, ProcMesh from torchstore.state_dict_utils import DELIM from vllm.config import VllmConfig @@ -43,23 +60,6 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase -from forge.actors._torchstore_utils import ( - extract_param_name, - get_dcp_whole_state_dict_key, - get_param_key, - get_param_prefix, - load_tensor_from_dcp, -) - -from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh -from forge.data.sharding import VLLMSharding -from forge.data_models.completion import Completion -from forge.data_models.prompt import to_prompt -from forge.interfaces import Policy as PolicyInterface -from forge.observability.metrics import record_metric, Reduce -from forge.observability.perf_tracker import Tracer -from forge.types import ProcessConfig - logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) diff --git a/src/forge/util/weight_sync.py b/src/forge/util/weight_sync.py deleted file mode 100644 index 771c5876b..000000000 --- a/src/forge/util/weight_sync.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import time - -import torchstore as ts - -from forge.actors._torchstore_utils import ( - get_dcp_whole_state_dict_key, - get_param_prefix, -) - - -async def drop_weights(version: int): - print(f"Dropping weights @ version {version}") - start_time = time.perf_counter() - prefix = get_param_prefix(version) - matching_keys = await ts.keys(prefix) - # TODO: once we have something like `get_meta()` in torchstore, we can just - # query the type of the object instead of relying on keys. - dcp_key = get_dcp_whole_state_dict_key(version) - if dcp_key in matching_keys: - dcp_handle = await ts.get(dcp_key) - dcp_handle.drop() - for key in matching_keys: - await ts.delete(key) - elapsed = time.perf_counter() - start_time - print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds") diff --git a/tests/unit_tests/test_torchstore_utils.py b/tests/unit_tests/test_torchstore_utils.py index 6a2e23fbf..9f2abb039 100644 --- a/tests/unit_tests/test_torchstore_utils.py +++ b/tests/unit_tests/test_torchstore_utils.py @@ -4,22 +4,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import asyncio import os import tempfile import unittest from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch import pytest - import torch import torch.distributed.checkpoint as dcp -from forge.actors._torchstore_utils import DcpHandle +from forge.actors._torchstore_utils import DcpHandle, WeightCleaner ignore_torch_distributed_unitialized_warning = pytest.mark.filterwarnings( r"ignore:.*torch.distributed" ) +ignore_coroutine_not_awaited = pytest.mark.filterwarnings( + "ignore:.*coroutine.*was never awaited.*" +) + class TestDcpHandle(unittest.TestCase): def _prepare_dcp_handle(self, test_dir: str) -> tuple[str, DcpHandle]: @@ -59,3 +64,145 @@ def test_dcp_handle_drop_sets_none_for_manifold(self): self.assertEqual(handle.checkpoint_id, None) self.assertEqual(handle.metadata, None) self.assertEqual(handle.param_names, None) + + +class TestWeightCleaner(unittest.IsolatedAsyncioTestCase): + """Test suite for WeightCleaner class.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + self.cleaner = WeightCleaner() + + @ignore_coroutine_not_awaited + def test_remove_done_tasks_with_completed_tasks(self): + """Test _remove_done_tasks removes completed tasks.""" + # Create mock tasks - some done, some not + done_task1 = MagicMock() + done_task1.done.return_value = True + + done_task2 = MagicMock() + done_task2.done.return_value = True + + pending_task = MagicMock() + pending_task.done.return_value = False + + self.cleaner._tasks = [done_task1, pending_task, done_task2] + self.cleaner._remove_done_tasks() + + # Only the pending task should remain + self.assertEqual(len(self.cleaner._tasks), 1) + self.assertEqual(self.cleaner._tasks[0], pending_task) + + @ignore_coroutine_not_awaited + def test_remove_done_tasks_with_all_pending(self): + """Test _remove_done_tasks with all tasks pending.""" + pending_task1 = MagicMock() + pending_task1.done.return_value = False + + pending_task2 = MagicMock() + pending_task2.done.return_value = False + + self.cleaner._tasks = [pending_task1, pending_task2] + self.cleaner._remove_done_tasks() + + # All tasks should remain + self.assertEqual(len(self.cleaner._tasks), 2) + self.assertEqual(self.cleaner._tasks, [pending_task1, pending_task2]) + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.actors._torchstore_utils.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_no_cleanup_needed_equal_version( + self, mock_create_task, mock_drop_weights + ): + """Test step method when delete_up_to_version equals last_deleted_version.""" + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # Step it to 5 first + self.cleaner.step(delete_up_to_version=5) + mock_drop_weights.assert_called() + mock_create_task.assert_called() + + # Reset mock state to clear call history + mock_drop_weights.reset_mock() + mock_create_task.reset_mock() + + # Request deletion up to version 5 (already deleted) + self.cleaner.step(delete_up_to_version=5) + + # No tasks should be created + mock_create_task.assert_not_called() + mock_drop_weights.assert_not_called() + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.actors._torchstore_utils.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_no_cleanup_needed_lower_version( + self, mock_create_task, mock_drop_weights + ): + """Test step method when delete_up_to_version is lower than last_deleted_version.""" + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # Step it to 10 first + self.cleaner.step(delete_up_to_version=10) + + # Reset mock state to clear call history + mock_drop_weights.reset_mock() + mock_create_task.reset_mock() + + # Request deletion up to version 5 (lower than already deleted) + self.cleaner.step(delete_up_to_version=5) + + # No tasks should be created + mock_create_task.assert_not_called() + mock_drop_weights.assert_not_called() + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.actors._torchstore_utils.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_creates_tasks_initial_call( + self, mock_create_task, mock_drop_weights + ): + """Test step method creates tasks for entire version range.""" + + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # Request deletion up to version 5 from initial state + self.cleaner.step(delete_up_to_version=5) + + # Should create 6 tasks (versions 0 through 5) + self.assertEqual(mock_create_task.call_count, 6) + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.actors._torchstore_utils.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_creates_only_new_version_tasks( + self, mock_create_task, mock_drop_weights + ): + """Test step method only creates tasks for versions not yet deleted.""" + + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # First deletion up to version 3 + self.cleaner.step(delete_up_to_version=3) + + # Reset mock to track only new calls + mock_create_task.reset_mock() + + # Second deletion up to version 7 + self.cleaner.step(delete_up_to_version=7) + + # Should only create tasks for versions 4, 5, 6, 7 + self.assertEqual(mock_create_task.call_count, 4) From cee8d29bdd06ad0c89ffbcfdb25f24143ee442a0 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 7 Oct 2025 08:53:18 -0700 Subject: [PATCH 4/8] only keep one version --- apps/grpo/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index b66841cbd..ecf4f4448 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -444,7 +444,9 @@ async def continuous_training(): await policy.update_weights.fanout(training_step) t.step("update_weights") - weight_cleaner.step(training_step - 1) + # weight cleanup is non-blocking, the task is executed in the background + weight_cleaner.step(training_step) + t.step("weight_cleaner step") t.stop() restart_tracer = True From 0b5edb7293c0ceb57a176af05fc43c3646cd776f Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Tue, 7 Oct 2025 08:55:18 -0700 Subject: [PATCH 5/8] move dcp_handle drop to separate thread --- src/forge/actors/_torchstore_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/forge/actors/_torchstore_utils.py b/src/forge/actors/_torchstore_utils.py index da074adb7..c77eeb54b 100644 --- a/src/forge/actors/_torchstore_utils.py +++ b/src/forge/actors/_torchstore_utils.py @@ -83,7 +83,6 @@ class WeightCleaner: """ def __init__(self): - """Initialize the WeightCleaner with empty task list and reset deletion tracking.""" # we need to keep the task around to make sure it's not garbage collected self._tasks = [] self._last_deleted_version = -1 @@ -120,7 +119,7 @@ async def drop_weights(version: int): dcp_key = get_dcp_whole_state_dict_key(version) if dcp_key in matching_keys: dcp_handle = await ts.get(dcp_key) - dcp_handle.drop() + await asyncio.to_thread(dcp_handle.drop) for key in matching_keys: await ts.delete(key) elapsed = time.perf_counter() - start_time From 19e510048cd821e803ad60dba209353482f3584f Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Wed, 15 Oct 2025 15:23:07 -0700 Subject: [PATCH 6/8] Move torchstore_utils from actors to utils --- apps/grpo/main.py | 5 +---- src/forge/actors/generator.py | 16 ++++++++-------- src/forge/actors/trainer.py | 12 ++++++------ src/forge/util/__init__.py | 2 ++ .../_torchstore_utils.py => util/_torchstore.py} | 0 tests/sandbox/toy_rl/sumdigits.py | 2 +- tests/unit_tests/test_torchstore_utils.py | 2 +- 7 files changed, 19 insertions(+), 20 deletions(-) rename src/forge/{actors/_torchstore_utils.py => util/_torchstore.py} (100%) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 5a6576d7e..0ece279d9 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -16,10 +16,6 @@ import torch.nn.functional as F import torchstore as ts from datasets import load_dataset -from forge.actors._torchstore_utils import ( - get_dcp_whole_state_dict_key, - get_param_prefix, -) from forge.actors.generator import Generator from forge.actors.reference_model import ReferenceModel from forge.actors.replay_buffer import ReplayBuffer @@ -34,6 +30,7 @@ from forge.observability.perf_tracker import Tracer from forge.types import LauncherConfig, ProvisionerConfig +from forge.util._torchstore import get_dcp_whole_state_dict_key, get_param_prefix from forge.util.ops import compute_logprobs from monarch.actor import endpoint from omegaconf import DictConfig diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index ca934127e..2e22b9cef 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -40,14 +40,6 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase -from forge.actors._torchstore_utils import ( - extract_param_name, - get_dcp_whole_state_dict_key, - get_param_key, - get_param_prefix, - load_tensor_from_dcp, -) - from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh from forge.data_models.completion import Completion from forge.data_models.prompt import to_prompt @@ -56,6 +48,14 @@ from forge.observability.perf_tracker import Tracer from forge.types import ProcessConfig +from forge.util._torchstore import ( + extract_param_name, + get_dcp_whole_state_dict_key, + get_param_key, + get_param_prefix, + load_tensor_from_dcp, +) + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index dd85b3c82..0135ce233 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -38,18 +38,18 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig -from forge.actors._torchstore_utils import ( - DcpHandle, - get_dcp_whole_state_dict_key, - get_param_key, -) - from forge.controller import ForgeActor from forge.data.utils import batch_to_device from forge.env import TORCHSTORE_USE_RDMA from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer +from forge.util._torchstore import ( + DcpHandle, + get_dcp_whole_state_dict_key, + get_param_key, +) + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) diff --git a/src/forge/util/__init__.py b/src/forge/util/__init__.py index 5fb03b0f9..c4fcd4ca0 100644 --- a/src/forge/util/__init__.py +++ b/src/forge/util/__init__.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from . import _torchstore from .distributed import get_world_size_and_rank from .logging import get_logger, log_once, log_rank_zero from .metric_logging import get_metric_logger @@ -13,4 +14,5 @@ "log_once", "log_rank_zero", "get_metric_logger", + "_torchstore", ] diff --git a/src/forge/actors/_torchstore_utils.py b/src/forge/util/_torchstore.py similarity index 100% rename from src/forge/actors/_torchstore_utils.py rename to src/forge/util/_torchstore.py diff --git a/tests/sandbox/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py index 0668f8eca..54781fab5 100644 --- a/tests/sandbox/toy_rl/sumdigits.py +++ b/tests/sandbox/toy_rl/sumdigits.py @@ -15,7 +15,6 @@ import torch import torch.nn.functional as F import torchstore as ts -from forge.actors._torchstore_utils import get_param_key from forge.actors.generator import Generator from forge.actors.replay_buffer import ReplayBuffer from forge.cli.config import parse @@ -25,6 +24,7 @@ from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import record_metric, Reduce +from forge.util._torchstore import get_param_key from forge.util.ops import selective_log_softmax from monarch.actor import endpoint from omegaconf import DictConfig diff --git a/tests/unit_tests/test_torchstore_utils.py b/tests/unit_tests/test_torchstore_utils.py index 6a2e23fbf..3ddafdf1b 100644 --- a/tests/unit_tests/test_torchstore_utils.py +++ b/tests/unit_tests/test_torchstore_utils.py @@ -14,7 +14,7 @@ import torch import torch.distributed.checkpoint as dcp -from forge.actors._torchstore_utils import DcpHandle +from forge.util._torchstore import DcpHandle ignore_torch_distributed_unitialized_warning = pytest.mark.filterwarnings( r"ignore:.*torch.distributed" From d517a9a62b9de17d83c13351c54f6f6408e65c89 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Thu, 16 Oct 2025 13:02:26 -0700 Subject: [PATCH 7/8] fix test --- tests/unit_tests/test_torchstore_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit_tests/test_torchstore_utils.py b/tests/unit_tests/test_torchstore_utils.py index 9f0dad12e..67ffcfe41 100644 --- a/tests/unit_tests/test_torchstore_utils.py +++ b/tests/unit_tests/test_torchstore_utils.py @@ -111,7 +111,7 @@ def test_remove_done_tasks_with_all_pending(self): @ignore_coroutine_not_awaited @pytest.mark.asyncio - @patch("forge.actors._torchstore_utils.drop_weights", new_callable=AsyncMock) + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) @patch("asyncio.create_task") async def test_step_no_cleanup_needed_equal_version( self, mock_create_task, mock_drop_weights @@ -139,7 +139,7 @@ async def test_step_no_cleanup_needed_equal_version( @ignore_coroutine_not_awaited @pytest.mark.asyncio - @patch("forge.actors._torchstore_utils.drop_weights", new_callable=AsyncMock) + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) @patch("asyncio.create_task") async def test_step_no_cleanup_needed_lower_version( self, mock_create_task, mock_drop_weights @@ -165,7 +165,7 @@ async def test_step_no_cleanup_needed_lower_version( @ignore_coroutine_not_awaited @pytest.mark.asyncio - @patch("forge.actors._torchstore_utils.drop_weights", new_callable=AsyncMock) + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) @patch("asyncio.create_task") async def test_step_creates_tasks_initial_call( self, mock_create_task, mock_drop_weights @@ -184,7 +184,7 @@ async def test_step_creates_tasks_initial_call( @ignore_coroutine_not_awaited @pytest.mark.asyncio - @patch("forge.actors._torchstore_utils.drop_weights", new_callable=AsyncMock) + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) @patch("asyncio.create_task") async def test_step_creates_only_new_version_tasks( self, mock_create_task, mock_drop_weights From 90a2dba13a2767dcec6c352b66ae5f0ae8f386b2 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Thu, 16 Oct 2025 13:07:49 -0700 Subject: [PATCH 8/8] clean up --- tests/unit_tests/test_torchstore_utils.py | 1 + .../unit_tests/util/test_torchstore_utils.py | 208 ++++++++++++++++++ 2 files changed, 209 insertions(+) create mode 100644 tests/unit_tests/util/test_torchstore_utils.py diff --git a/tests/unit_tests/test_torchstore_utils.py b/tests/unit_tests/test_torchstore_utils.py index 67ffcfe41..0af642cfe 100644 --- a/tests/unit_tests/test_torchstore_utils.py +++ b/tests/unit_tests/test_torchstore_utils.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Test for src/forge/util/_torchstore.py.""" import asyncio import os diff --git a/tests/unit_tests/util/test_torchstore_utils.py b/tests/unit_tests/util/test_torchstore_utils.py new file mode 100644 index 000000000..67ffcfe41 --- /dev/null +++ b/tests/unit_tests/util/test_torchstore_utils.py @@ -0,0 +1,208 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import asyncio +import os +import tempfile +import unittest + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import torch +import torch.distributed.checkpoint as dcp +from forge.util._torchstore import DcpHandle, WeightCleaner + +ignore_torch_distributed_unitialized_warning = pytest.mark.filterwarnings( + r"ignore:.*torch.distributed" +) + +ignore_coroutine_not_awaited = pytest.mark.filterwarnings( + "ignore:.*coroutine.*was never awaited.*" +) + + +class TestDcpHandle(unittest.TestCase): + def _prepare_dcp_handle(self, test_dir: str) -> tuple[str, DcpHandle]: + """Returns path to checkpoint and DcpHandle.""" + checkpoint_id = str(Path(test_dir) / "test_checkpoint_id") + state_dict = {"a": torch.rand(1, 1), "b": torch.rand(1, 1)} + metadata = dcp.save(checkpoint_id=checkpoint_id, state_dict=state_dict) + assert os.path.exists(checkpoint_id), "failed to set up test checkpoint" + return checkpoint_id, DcpHandle( + checkpoint_id=checkpoint_id, + metadata=metadata, + param_names=list(state_dict.keys()), + ) + + @ignore_torch_distributed_unitialized_warning + def test_dcp_handle_drop_deletes(self): + with tempfile.TemporaryDirectory() as test_dir: + ckpt_path, handle = self._prepare_dcp_handle(test_dir) + handle.drop() + self.assertFalse(os.path.exists(ckpt_path)) + + @ignore_torch_distributed_unitialized_warning + def test_dcp_handle_drop_sets_none(self): + with tempfile.TemporaryDirectory() as test_dir: + _, handle = self._prepare_dcp_handle(test_dir) + handle.drop() + self.assertEqual(handle.checkpoint_id, None) + self.assertEqual(handle.metadata, None) + self.assertEqual(handle.param_names, None) + + @ignore_torch_distributed_unitialized_warning + def test_dcp_handle_drop_sets_none_for_manifold(self): + with tempfile.TemporaryDirectory() as test_dir: + _, handle = self._prepare_dcp_handle(test_dir) + handle.checkpoint_id = "manifold://test_bucket/tree/test_path" + handle.drop() + self.assertEqual(handle.checkpoint_id, None) + self.assertEqual(handle.metadata, None) + self.assertEqual(handle.param_names, None) + + +class TestWeightCleaner(unittest.IsolatedAsyncioTestCase): + """Test suite for WeightCleaner class.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + self.cleaner = WeightCleaner() + + @ignore_coroutine_not_awaited + def test_remove_done_tasks_with_completed_tasks(self): + """Test _remove_done_tasks removes completed tasks.""" + # Create mock tasks - some done, some not + done_task1 = MagicMock() + done_task1.done.return_value = True + + done_task2 = MagicMock() + done_task2.done.return_value = True + + pending_task = MagicMock() + pending_task.done.return_value = False + + self.cleaner._tasks = [done_task1, pending_task, done_task2] + self.cleaner._remove_done_tasks() + + # Only the pending task should remain + self.assertEqual(len(self.cleaner._tasks), 1) + self.assertEqual(self.cleaner._tasks[0], pending_task) + + @ignore_coroutine_not_awaited + def test_remove_done_tasks_with_all_pending(self): + """Test _remove_done_tasks with all tasks pending.""" + pending_task1 = MagicMock() + pending_task1.done.return_value = False + + pending_task2 = MagicMock() + pending_task2.done.return_value = False + + self.cleaner._tasks = [pending_task1, pending_task2] + self.cleaner._remove_done_tasks() + + # All tasks should remain + self.assertEqual(len(self.cleaner._tasks), 2) + self.assertEqual(self.cleaner._tasks, [pending_task1, pending_task2]) + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_no_cleanup_needed_equal_version( + self, mock_create_task, mock_drop_weights + ): + """Test step method when delete_up_to_version equals last_deleted_version.""" + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # Step it to 5 first + self.cleaner.step(delete_up_to_version=5) + mock_drop_weights.assert_called() + mock_create_task.assert_called() + + # Reset mock state to clear call history + mock_drop_weights.reset_mock() + mock_create_task.reset_mock() + + # Request deletion up to version 5 (already deleted) + self.cleaner.step(delete_up_to_version=5) + + # No tasks should be created + mock_create_task.assert_not_called() + mock_drop_weights.assert_not_called() + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_no_cleanup_needed_lower_version( + self, mock_create_task, mock_drop_weights + ): + """Test step method when delete_up_to_version is lower than last_deleted_version.""" + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # Step it to 10 first + self.cleaner.step(delete_up_to_version=10) + + # Reset mock state to clear call history + mock_drop_weights.reset_mock() + mock_create_task.reset_mock() + + # Request deletion up to version 5 (lower than already deleted) + self.cleaner.step(delete_up_to_version=5) + + # No tasks should be created + mock_create_task.assert_not_called() + mock_drop_weights.assert_not_called() + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_creates_tasks_initial_call( + self, mock_create_task, mock_drop_weights + ): + """Test step method creates tasks for entire version range.""" + + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # Request deletion up to version 5 from initial state + self.cleaner.step(delete_up_to_version=5) + + # Should create 6 tasks (versions 0 through 5) + self.assertEqual(mock_create_task.call_count, 6) + + @ignore_coroutine_not_awaited + @pytest.mark.asyncio + @patch("forge.util._torchstore.drop_weights", new_callable=AsyncMock) + @patch("asyncio.create_task") + async def test_step_creates_only_new_version_tasks( + self, mock_create_task, mock_drop_weights + ): + """Test step method only creates tasks for versions not yet deleted.""" + + future = asyncio.Future() + future.set_result(None) + mock_create_task.return_value = future + + # First deletion up to version 3 + self.cleaner.step(delete_up_to_version=3) + + # Reset mock to track only new calls + mock_create_task.reset_mock() + + # Second deletion up to version 7 + self.cleaner.step(delete_up_to_version=7) + + # Should only create tasks for versions 4, 5, 6, 7 + self.assertEqual(mock_create_task.call_count, 4)