factor out weight cleanup to separate file, also non-blocking now #292
base: main
Changes from all commits
7643064
e66efee
c8b2b42
73d233c
cee8d29
0b5edb7
19e5100
cf5dab5
5fe537f
e91f66c
d517a9a
90a2dba
Review comment: Move to core app/
Reply: Everything or only the WeightCleaner? trainer and policy both need functions in this file.
@@ -3,12 +3,16 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import asyncio
import logging
import shutil
import time
from dataclasses import dataclass

import torch
import torch.distributed.checkpoint as dcp

import torchstore as ts
from torch.distributed.checkpoint.metadata import Metadata as DcpMeta

logger = logging.getLogger(__name__)
@@ -69,3 +73,54 @@ def extract_param_name(key: str) -> str:

def get_dcp_whole_state_dict_key(policy_version: int) -> str:
    return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}"


class WeightCleaner:
    """Manages asynchronous cleanup of model weights across different policy versions.

    This class handles the deletion of old model weights by maintaining a list of
    cleanup tasks and tracking the last deleted version to avoid redundant operations.
    """

    def __init__(self):
        # we need to keep the task around to make sure it's not garbage collected
        self._tasks = []
        self._last_deleted_version = -1

    def _remove_done_tasks(self):
        """Remove completed tasks from the task list to prevent memory leaks."""
        self._tasks = [task for task in self._tasks if not task.done()]

    def step(self, delete_up_to_version: int):
        """Schedule deletion of weights for all versions up to the specified version.

        Args:
            delete_up_to_version (int): The highest policy version to delete (inclusive).
                All versions from last_deleted_version + 1 to this version will be deleted.
        """
        self._remove_done_tasks()
        if delete_up_to_version <= self._last_deleted_version:
            return
        for version in range(self._last_deleted_version + 1, delete_up_to_version + 1):
            self._tasks.append(asyncio.create_task(drop_weights(version)))
        self._last_deleted_version = delete_up_to_version

    async def wait(self):
        """Wait for all scheduled deletion tasks to complete."""
        await asyncio.gather(*self._tasks)


async def drop_weights(version: int):
    start_time = time.perf_counter()
    prefix = get_param_prefix(version)
    matching_keys = await ts.keys(prefix)
    # TODO: once we have something like `get_meta()` in torchstore, we can just
Review comment: We do have a 'get_meta' in torchstore (although it's lacking a proper object).
    # query the type of the object instead of relying on keys.
    dcp_key = get_dcp_whole_state_dict_key(version)
Review comment: Is this implementation specific to DCP? Do we need something like
Reply: Yes. It would be good if we can have it, although currently it is not a bottleneck to simply call delete on every key.
    if dcp_key in matching_keys:
        dcp_handle = await ts.get(dcp_key)
        await asyncio.to_thread(dcp_handle.drop)
    for key in matching_keys:
        await ts.delete(key)
    elapsed = time.perf_counter() - start_time
    logger.info(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
Review comment: How are you confirming this all finishes before adding more weights?
Also, in typical async form this step would just be an async method that you'd await now or later. Why is there an extra method called "wait"?
Reply: I thought the point is that you don't, if you just need the weights to be eventually deleted. When you call step(), the task is scheduled in the background and everything else proceeds as normal.

Yes, but in that case, if we want to schedule the task in the background and not await it, we need to manage the task in main.py, which we supposedly don't want to do. This essentially hides the task scheduling logic in the WeightCleaner class.

If you want to make sure all the scheduled tasks have indeed completed (i.e. all old weights are deleted, as you mentioned earlier), you can await weight_cleaner.wait(). Presumably this can be named better; let me know what you think.
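To make the step()/wait() split discussed above concrete, here is a minimal sketch of the same scheduling pattern. It is not the PR's code: `fake_drop_weights`, `SketchCleaner`, `dropped`, and `demo` are hypothetical stand-ins, and the real `drop_weights` talks to torchstore rather than appending to a list. The sketch only illustrates that step() returns before any deletion has run, and that wait() drains whatever is still pending.

```python
import asyncio

# Hypothetical stand-in for drop_weights(): records the version instead of
# deleting anything from torchstore.
dropped: list[int] = []


async def fake_drop_weights(version: int) -> None:
    await asyncio.sleep(0)  # yield to the event loop once, as real I/O would
    dropped.append(version)


class SketchCleaner:
    """Mirrors the scheduling logic of WeightCleaner from the diff (assumption: simplified)."""

    def __init__(self) -> None:
        # Keep references so pending tasks are not garbage collected.
        self._tasks: list[asyncio.Task] = []
        self._last_deleted_version = -1

    def step(self, delete_up_to_version: int) -> None:
        # Forget finished tasks, then schedule any versions not yet covered.
        self._tasks = [t for t in self._tasks if not t.done()]
        if delete_up_to_version <= self._last_deleted_version:
            return
        for version in range(self._last_deleted_version + 1, delete_up_to_version + 1):
            self._tasks.append(asyncio.create_task(fake_drop_weights(version)))
        self._last_deleted_version = delete_up_to_version

    async def wait(self) -> None:
        await asyncio.gather(*self._tasks)


async def demo() -> tuple[list[int], list[int]]:
    dropped.clear()
    cleaner = SketchCleaner()
    cleaner.step(2)          # schedules versions 0, 1, 2 and returns immediately
    before = list(dropped)   # no await has happened yet, so nothing has run
    cleaner.step(1)          # no-op: versions up to 2 are already scheduled
    await cleaner.wait()     # block until every scheduled deletion finishes
    return before, sorted(dropped)
```

Calling `asyncio.run(demo())` exercises both halves: the caller keeps going after step(), and only the explicit wait() guarantees that the deletions are done.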
Review comment: My understanding is that, in typical async code, if you don't explicitly create a task, it will never get executed unless you await it. I think we can also always schedule the task and return a join handle.
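The distinction raised in this comment can be sketched as follows. This is an illustration only, under the assumption of the default asyncio task factory: `work`, `schedule`, `ran`, and `demo` are hypothetical names, and `schedule` stands in for a variant of step() that returns the task as a join handle instead of storing it internally.

```python
import asyncio

ran: list[str] = []


async def work(tag: str) -> str:
    ran.append(tag)
    return tag


def schedule(tag: str) -> asyncio.Task:
    """Hypothetical variant of step(): schedule the work and hand back the task."""
    return asyncio.create_task(work(tag))


async def demo() -> tuple[list[str], list[str]]:
    ran.clear()
    coro = work("bare-coroutine")  # only creates a coroutine object; the body has not started
    handle = schedule("task")      # wrapped in a task, so the event loop will run it
    await asyncio.sleep(0)         # yield once so the scheduled task gets a turn
    snapshot = list(ran)           # at this point only the task has executed
    await coro                     # the bare coroutine runs only when awaited
    await handle                   # joining an already finished task returns immediately
    return snapshot, list(ran)
```

Returning the handle lets the caller decide whether to join, at the cost of the caller having to keep the reference alive; the WeightCleaner in the diff keeps that bookkeeping inside the class instead.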