factor out weight cleanup to separate file, also non-blocking now #292
base: main
Conversation
src/forge/util/weight_sync.py
Outdated
Why not put this in _torchstore_utils?
It seems that's where all the other weight sync information is. If that's not supposed to be the end location, I'd almost rather all _torchstore_utils be moved out into a weight_sync.py file.
> rather all _torchstore_utils be moved out into a weight_sync.py

Maybe I will do this. Let me know what you think. Do you want me to make it _weight_sync.py instead?
src/forge/util/weight_sync.py
Outdated
```python
async def drop_weights(version: int):
```
Now that we're here, I'd prefer a name like "delete_old_weights". And instead of version, something like "oldest_version_to_keep".
src/forge/util/weight_sync.py
Outdated
```python
async def drop_weights(version: int):
    print(f"Dropping weights @ version {version}")
```
Remove this
src/forge/util/weight_sync.py
Outdated
```python
for key in matching_keys:
    await ts.delete(key)
elapsed = time.perf_counter() - start_time
print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
```
Log instead of print
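A minimal sketch of the suggested change, swapping print for the stdlib logging module (the logger name and level here are assumptions, not the PR's actual code):

```python
import logging
import time

logger = logging.getLogger("forge.weight_sync")  # hypothetical logger name


def log_drop(version: int, start_time: float) -> None:
    # Emit the timing line through the logging framework instead of stdout,
    # so production runs can control or silence it via logging config.
    elapsed = time.perf_counter() - start_time
    logger.debug("Dropped weights @ version %d, took %.2f seconds", version, elapsed)
```

Using %-style lazy formatting also avoids building the message string when the level is disabled.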
apps/grpo/main.py
Outdated
```python
# await drop_weights(training_step - 1)
# t.step("drop_weights")
if training_step >= 2:
    await drop_weights(training_step - 1)
```
Can this be truly async or does it have to be blocking like this?
This can be truly async, if we just create a task and don't await on it.

Ended up doing a whole refactor. It does the weight cleanup in the background in a non-blocking way now.
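The fire-and-forget pattern described above can be sketched as follows (function and variable names are illustrative, not the PR's actual code):

```python
import asyncio

async def drop_weights(version: int) -> None:
    # Placeholder for the real torchstore deletion work.
    await asyncio.sleep(0)

_background_tasks: set[asyncio.Task] = set()

def schedule_drop(version: int) -> asyncio.Task:
    # Schedule the cleanup without awaiting it; the training loop keeps going.
    task = asyncio.create_task(drop_weights(version))
    # Keep a strong reference so the task isn't garbage-collected mid-flight.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task
```

The reference-holding set matters: `asyncio.create_task` only keeps a weak reference to the task, so an unreferenced background task can be collected before it runs.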
Asyncio looks reasonable to me. A good way to measure this would be to run this code for a while, always synchronize (finish your tasks) on the weight sync step, and measure the delta between doing that and waiting until the next step.
```python
start_time = time.perf_counter()
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
# TODO: once we have something like `get_meta()` in torchstore, we can just
```
We do have a `get_meta` in torchstore (although it's lacking a proper object).
```python
matching_keys = await ts.keys(prefix)
# TODO: once we have something like `get_meta()` in torchstore, we can just
# query the type of the object instead of relying on keys.
dcp_key = get_dcp_whole_state_dict_key(version)
```
Is this implementation specific to DCP?

Do we need something like `ts.delete(r"key.*")` support in torchstore?
> Is this implementation specific to DCP?

Yes.

> Do we need something like `ts.delete(r"key.*")` support in torchstore?

It would be good if we can have it, although currently it is not a bottleneck to simply call delete on every key.
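Without pattern-based deletion, the per-key loop looks roughly like this. The store below is a fake stand-in (with assumed `keys`/`delete` signatures and an assumed key layout) so the sketch is self-contained:

```python
import asyncio

class FakeStore:
    # Minimal stand-in for the torchstore client used in the PR.
    def __init__(self, data: dict):
        self._data = data

    async def keys(self, prefix: str) -> list[str]:
        return [k for k in self._data if k.startswith(prefix)]

    async def delete(self, key: str) -> None:
        self._data.pop(key, None)

async def drop_weights(ts: FakeStore, version: int) -> None:
    # No regex delete in the store: enumerate matching keys, delete each one.
    prefix = f"weights/v{version}/"  # assumed key layout, for illustration
    for key in await ts.keys(prefix):
        await ts.delete(key)
```

If deletion ever does become a bottleneck, the per-key awaits could be batched with `asyncio.gather`.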
```python
if training_step >= 2:
    await drop_weights(training_step - 1)
    t.step("drop_weights")
# weight cleanup is non-blocking, the task is executed in the background
```
How are you confirming this all finishes before adding more weights?
Also in typical async form this step would just be an async method that you'd await now or later. Why is there an extra method called "wait"?
> How are you confirming this all finishes before adding more weights?

I thought the point is you don't, if you just need the weights to be eventually deleted. When you do step(), the task is scheduled in the background and everything else proceeds as normal.

> Also in typical async form this step would just be an async method that you'd await now or later.

Yes, but in that case, if we want to schedule the task in the background and not await it, we need to manage the task in main.py, which we supposedly don't want to do. This essentially hides the task scheduling logic in the WeightCleaner class.

> Why is there an extra method called "wait"?

If you want to make sure all the scheduled tasks are indeed completed (i.e. all old weights are deleted, like you mentioned earlier), you can await weight_cleaner.wait(). Presumably this can be named better, let me know what you think.
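A sketch of the WeightCleaner pattern being discussed: step() schedules the deletion in the background, wait() joins all outstanding tasks. The class and method names follow the discussion; the internals are assumptions:

```python
import asyncio

class WeightCleaner:
    """Schedules weight deletions in the background; wait() joins them all."""

    def __init__(self, drop_fn):
        self._drop_fn = drop_fn  # async callable: version -> None
        self._tasks: set[asyncio.Task] = set()

    def step(self, version: int) -> None:
        # Fire-and-forget: schedule the deletion and return immediately.
        task = asyncio.create_task(self._drop_fn(version))
        # Hold a reference so the task isn't garbage-collected before it runs.
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def wait(self) -> None:
        # Block until every scheduled deletion has finished.
        if self._tasks:
            await asyncio.gather(*self._tasks)
```

This keeps main.py free of task bookkeeping: the loop calls step() each iteration and can optionally await wait() at a synchronization point.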
> Also in typical async form this step would just be an async method that you'd await now or later. Why is there an extra method called "wait"?

My understanding is that in typical async code, if you don't explicitly create a task, it will never get executed unless you await on it. I think we can also always schedule the task and return a join handle.
Move to core app/
Everything or only the WeightCleaner? Trainer and policy both need functions in this file.
ptal @joecummings
#252
test run: https://meta.wandb.io/torchforge/grpo-training/runs/9epdrv7m