Skip to content
4 changes: 4 additions & 0 deletions .github/workflows/unit_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ jobs:
eval "$(ssh-agent -s)"
ssh-add - <<< '${{ secrets.FORGE_GITHUB_CI_FOR_TORCHSTORE }}'
python -m pip install git+ssh://[email protected]/meta-pytorch/torchstore.git
- name: Install torchtitan
run: |
pip install --pre torchtitan==0.1.0.dev20250826+cpu --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install tyro
- name: Install dependencies
run: python -m pip install --no-build-isolation -e ".[dev]"
- name: Run unit tests with coverage
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,6 @@ cover/
wandb/

assets/wheels/vllm*.whl

# DCP artifacts
model_state_dict/
50 changes: 50 additions & 0 deletions src/forge/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import os
import shutil
import time
from collections.abc import Mapping
from dataclasses import dataclass, field, fields
Expand Down Expand Up @@ -37,6 +39,42 @@
from forge.controller import ForgeActor
from forge.data.utils import batch_to_device

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def cleanup_old_weight_versions(
state_dict_key: str,
delim: str,
current_policy_version: int,
) -> None:
"""Delete all old weight versions except the current one.

Args:
state_dict_key: The base key for state dict storage
delim: The delimiter used between key and version
current_policy_version: The current policy version to keep
logger_func: Function to use for logging debug messages
"""
prefix = f"{state_dict_key}{delim}"
current_weights = f"{prefix}{current_policy_version}"

# Find all weight directories that match our pattern
parent_dir = os.path.dirname(prefix) or "."
if os.path.exists(parent_dir):
for item in os.listdir(parent_dir):
item_path = os.path.join(parent_dir, item)
if (
item.startswith(os.path.basename(prefix))
and item != os.path.basename(current_weights)
and os.path.isdir(item_path)
):
try:
shutil.rmtree(item_path, ignore_errors=True)
logger.debug(f"Removed old weights at {item_path}")
except OSError as e:
logger.debug(f"Error deleting {item_path}: {e}")


@dataclass
class RLTrainer(ForgeActor):
Expand All @@ -63,6 +101,7 @@ def __post_init__(self):
in monarch for now.

"""
super().__init__()
# Instantiate dict fields
for f in fields(self):
attr = getattr(self, f.name)
Expand Down Expand Up @@ -223,8 +262,19 @@ async def push_weights(self, policy_version: int) -> None:
key = f"{self.state_dict_key}{DELIM}{policy_version}"
start_time = time.time()
if self.use_dcp:

# TODO - DCP should probably be being saved to NFS explicitly?
# Right now it will only save everything locally
metadata = dcp.save(checkpoint_id=key, state_dict=vllm_ready_hf_sd)
await ts.put(key, metadata)

# Delete old weight versions if they exist
if self.rank == 0:
cleanup_old_weight_versions(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we know the policy isn't currently reading from these weights?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think cleanup belongs to Policy since the Policy actor will know if an earlier version is still needed.
We can implement cleanup as an endpoint of Policy and we call it in the main loop to make sure everything is in sync.
(Later) once we figure out how to make the policy actors talk to each other after update_weights we can potentially move it to the inside of policy and do it automatically.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HMM yeah good point. I think this requires more consideration. I don't want to introduce some tracking at this moment, maybe a fragile heuristic we can go with is "don't delete the last 2" since we're not going off policy more than 1? Can follow up more with #194 wdyt

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HMM yeah good point. I think this requires more consideration. I don't want to introduce some tracking at this moment, maybe a fragile heuristic we can go with is "don't delete the last 2" since we're not going off policy more than 1? Can follow up more with #194 wdyt

Sounds good. add the heuristics now just to make sure nothing breaks and we can add proper evict logic later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok done!

state_dict_key=self.state_dict_key,
delim=DELIM,
current_policy_version=policy_version,
)
else:
await ts.put_state_dict(vllm_ready_hf_sd, key)
end_time = time.time()
Expand Down
78 changes: 78 additions & 0 deletions tests/unit_tests/test_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import tempfile
import unittest

from forge.actors.trainer import cleanup_old_weight_versions


class TestTrainerUtilities(unittest.TestCase):
def setUp(self):
"""Set up test environment with temporary directory."""
self.test_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, self.test_dir)

def test_cleanup_old_weight_versions_basic(self):
"""Test basic cleanup functionality."""
# Create test directory structure
state_dict_key = os.path.join(self.test_dir, "model")
delim = "__"

# Create some mock weight directories
old_version_1 = f"{state_dict_key}{delim}1"
old_version_2 = f"{state_dict_key}{delim}2"
current_version = f"{state_dict_key}{delim}3"
unrelated_dir = os.path.join(self.test_dir, "other_model__1")

for dir_path in [old_version_1, old_version_2, current_version, unrelated_dir]:
os.makedirs(dir_path)

# Run cleanup for version 3
cleanup_old_weight_versions(
state_dict_key=state_dict_key,
delim=delim,
current_policy_version=3,
)

# Check that old versions were deleted
self.assertFalse(os.path.exists(old_version_1))
self.assertFalse(os.path.exists(old_version_2))

# Check that current version and unrelated directories still exist
self.assertTrue(os.path.exists(current_version))
self.assertTrue(os.path.exists(unrelated_dir))

def test_cleanup_old_weight_versions_os_error(self):
"""Test error handling when deletion fails."""
# Create test directory structure
state_dict_key = os.path.join(self.test_dir, "model")
delim = "__"

old_version = f"{state_dict_key}{delim}1"
current_version = f"{state_dict_key}{delim}2"

os.makedirs(old_version)
os.makedirs(current_version)

# Make the old version directory read-only to simulate deletion failure
os.chmod(old_version, 0o444)

# Run cleanup
cleanup_old_weight_versions(
state_dict_key=state_dict_key,
delim=delim,
current_policy_version=2,
)
# Clean up by restoring permissions
if os.path.exists(old_version):
os.chmod(old_version, 0o755)


if __name__ == "__main__":
unittest.main()
Loading