Skip to content
Merged
4 changes: 4 additions & 0 deletions .github/workflows/unit_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ jobs:
eval "$(ssh-agent -s)"
ssh-add - <<< '${{ secrets.FORGE_GITHUB_CI_FOR_TORCHSTORE }}'
python -m pip install git+ssh://[email protected]/meta-pytorch/torchstore.git
- name: Install torchtitan
run: |
pip install --pre torchtitan==0.1.0.dev20250826+cpu --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install tyro
- name: Install dependencies
run: python -m pip install --no-build-isolation -e ".[dev]"
- name: Run unit tests with coverage
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,6 @@ cover/
wandb/

assets/wheels/vllm*.whl

# DCP artifacts
model_state_dict/
54 changes: 53 additions & 1 deletion src/forge/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import math
import os
import shutil
import time
from collections.abc import Mapping
from dataclasses import dataclass, field, fields
Expand Down Expand Up @@ -39,7 +40,46 @@
from forge.data.utils import batch_to_device

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.setLevel(logging.DEBUG)


def cleanup_old_weight_versions(
    state_dict_key: str,
    delim: str,
    current_policy_version: int,
) -> None:
    """Delete old weight versions, keeping only current and N-1 versions.

    Weight directories are expected to be named
    ``f"{state_dict_key}{delim}{version}"``. Only directories whose suffix
    after the prefix is purely numeric are treated as weight versions, so a
    sibling directory that merely shares the prefix (for example
    ``model__backup``) is never removed.

    Deletion is best-effort: a failure to remove one directory is logged at
    DEBUG level and does not stop the remaining directories from being
    processed.

    TODO - issues/194: provide a more robust way to handle eviction.

    Args:
        state_dict_key: The base key for state dict storage
        delim: The delimiter used between key and version
        current_policy_version: The current policy version to keep
    """
    if current_policy_version <= 1:
        return  # No cleanup needed for versions 0 or 1

    prefix = f"{state_dict_key}{delim}"
    parent_dir = os.path.dirname(prefix) or "."
    name_prefix = os.path.basename(prefix)

    # Directory names (not full paths) that must survive the cleanup.
    keep = {
        f"{name_prefix}{current_policy_version}",
        f"{name_prefix}{current_policy_version - 1}",
    }

    try:
        entries = os.listdir(parent_dir)
    except FileNotFoundError:
        return  # Nothing has been saved yet, so there is nothing to clean.

    for item in entries:
        if not item.startswith(name_prefix) or item in keep:
            continue
        # Require "<prefix><digits>" so unrelated directories that happen to
        # share the prefix are left alone.
        if not item[len(name_prefix):].isdecimal():
            continue
        item_path = os.path.join(parent_dir, item)
        if not os.path.isdir(item_path):
            continue
        try:
            # No ignore_errors here: failures must reach the except branch so
            # the success message is only logged when removal really worked.
            shutil.rmtree(item_path)
        except OSError as e:
            logger.debug(f"Error deleting {item_path}: {e}")
        else:
            logger.debug(f"Removed old weights at {item_path}")


@dataclass
Expand Down Expand Up @@ -67,6 +107,7 @@ def __post_init__(self):
in monarch for now.

"""
super().__init__()
# Instantiate dict fields
for f in fields(self):
attr = getattr(self, f.name)
Expand Down Expand Up @@ -228,8 +269,19 @@ async def push_weights(self, policy_version: int) -> None:
key = f"{self.state_dict_key}{DELIM}{policy_version}"
start_time = time.time()
if self.use_dcp:

# TODO - DCP should probably be being saved to NFS explicitly?
# Right now it will only save everything locally
metadata = dcp.save(checkpoint_id=key, state_dict=vllm_ready_hf_sd)
await ts.put(key, metadata)

# Delete old weight versions if they exist
if self.rank == 0:
cleanup_old_weight_versions(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we know the policy isn't currently reading from these weights?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think cleanup belongs to Policy since the Policy actor will know if an earlier version is still needed.
We can implement cleanup as an endpoint of Policy and we call it in the main loop to make sure everything is in sync.
(Later) once we figure out how to make the policy actors talk to each other after update_weights we can potentially move it to the inside of policy and do it automatically.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HMM yeah good point. I think this requires more consideration. I don't want to introduce some tracking at this moment, maybe a fragile heuristic we can go with is "don't delete the last 2" since we're not going off policy more than 1? Can follow up more with #194 wdyt

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HMM yeah good point. I think this requires more consideration. I don't want to introduce some tracking at this moment, maybe a fragile heuristic we can go with is "don't delete the last 2" since we're not going off policy more than 1? Can follow up more with #194 wdyt

Sounds good. Add the heuristic now, just to make sure nothing breaks, and we can add proper eviction logic later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok done!

state_dict_key=self.state_dict_key,
delim=DELIM,
current_policy_version=policy_version,
)
else:
await ts.put_state_dict(vllm_ready_hf_sd, key)
end_time = time.time()
Expand Down
137 changes: 0 additions & 137 deletions tests/unit_tests/test_reference_actor.py

This file was deleted.

102 changes: 102 additions & 0 deletions tests/unit_tests/test_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import tempfile
import unittest

from forge.actors.trainer import cleanup_old_weight_versions


class TestTrainerUtilities(unittest.TestCase):
    """Filesystem-level tests for ``cleanup_old_weight_versions``."""

    # Delimiter placed between the state-dict key and the version number.
    _DELIM = "__"

    def setUp(self):
        """Create a scratch directory that is removed after each test."""
        scratch = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, scratch)
        self.test_dir = scratch

    def _base_key(self):
        """Return the state-dict key rooted in the scratch directory."""
        return os.path.join(self.test_dir, "model")

    def _run_cleanup(self, base, version):
        """Invoke the function under test for ``base`` at ``version``."""
        cleanup_old_weight_versions(
            state_dict_key=base,
            delim=self._DELIM,
            current_policy_version=version,
        )

    def test_cleanup_old_weight_versions_basic(self):
        """Versions older than N-1 are removed; N, N-1 and unrelated dirs stay."""
        base = self._base_key()
        stale, previous, current = (
            f"{base}{self._DELIM}{version}" for version in (1, 2, 3)
        )
        unrelated = os.path.join(self.test_dir, "other_model__1")
        for path in (stale, previous, current, unrelated):
            os.makedirs(path)

        self._run_cleanup(base, 3)

        # Only the stale version (1) may disappear.
        self.assertFalse(os.path.exists(stale))
        # N-1, N and the directory with a different prefix must survive.
        self.assertTrue(os.path.exists(previous))
        self.assertTrue(os.path.exists(current))
        self.assertTrue(os.path.exists(unrelated))

    def test_cleanup_old_weight_versions_no_cleanup_version_1(self):
        """Nothing is deleted when current_policy_version <= 1."""
        base = self._base_key()
        only_version = f"{base}{self._DELIM}1"
        os.makedirs(only_version)

        self._run_cleanup(base, 1)

        self.assertTrue(os.path.exists(only_version))

    def test_cleanup_old_weight_versions_version_2(self):
        """With version 2 current, both version 1 (N-1) and version 2 stay."""
        base = self._base_key()
        previous, current = (
            f"{base}{self._DELIM}{version}" for version in (1, 2)
        )
        for path in (previous, current):
            os.makedirs(path)

        self._run_cleanup(base, 2)

        self.assertTrue(os.path.exists(previous))
        self.assertTrue(os.path.exists(current))


if __name__ == "__main__":
unittest.main()
Loading