
Conversation

casteryh (Contributor) commented Oct 2, 2025

#252
test run: https://meta.wandb.io/torchforge/grpo-training/runs/9epdrv7m

```
> ls forge_dcp_tmp/
policy_ver_0000000014.dcp_whole_state_dict
```

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 2, 2025
@casteryh casteryh requested a review from joecummings October 2, 2025 20:55
Member:
Why not put this in _torchstore_utils?

It seems that's where all the other weight-sync information is. If that's not supposed to be the end location, I'd almost rather have all of _torchstore_utils moved out into a weight_sync.py file.

casteryh (Contributor Author):

> rather all _torchstore_utils be moved out into a weight_sync.py

Maybe I will do this. Let me know what you think. Do you want me to make it _weight_sync.py instead?

```python
)


async def drop_weights(version: int):
```
Member:

Now that we're here, I'd prefer a name like "delete_old_weights"

And instead of version, something like "oldest_version_to_keep"



```python
async def drop_weights(version: int):
    print(f"Dropping weights @ version {version}")
```
Member:

Remove this

```python
for key in matching_keys:
    await ts.delete(key)
elapsed = time.perf_counter() - start_time
print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
```
Member:

Log instead of print

```python
# await drop_weights(training_step - 1)
# t.step("drop_weights")
if training_step >= 2:
    await drop_weights(training_step - 1)
```
Member:

Can this be truly async or does it have to be blocking like this?

casteryh (Contributor Author):

This can be truly async if we just create a task and don't await it.
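A minimal sketch of the fire-and-forget pattern being described, using plain asyncio (the `drop_weights` body here is a placeholder, and keeping a reference to each task guards against premature garbage collection):

```python
import asyncio

async def drop_weights(version: int, dropped: list) -> None:
    # Placeholder for the real cleanup coroutine.
    await asyncio.sleep(0)
    dropped.append(version)

async def training_loop() -> list:
    dropped: list = []
    background_tasks: set = set()
    for training_step in range(4):
        if training_step >= 2:
            # Schedule cleanup without blocking the training step.
            task = asyncio.create_task(drop_weights(training_step - 1, dropped))
            background_tasks.add(task)
            # Drop the reference once the task completes.
            task.add_done_callback(background_tasks.discard)
    # Before exiting, make sure the scheduled cleanups have finished.
    await asyncio.gather(*background_tasks)
    return dropped

dropped_versions = asyncio.run(training_loop())
print(sorted(dropped_versions))
```

Without the final gather (or a later join), a task scheduled on the last step could still be pending when the loop exits.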

@casteryh changed the title from "factor out weight cleanup to separate file" to "factor out weight cleanup to separate file, also non-blocking now" on Oct 7, 2025
casteryh (Contributor Author) commented Oct 7, 2025

[image]

Ended up doing a whole refactor. It now does the weight cleanup in the background in a non-blocking way.
tested: unit tests and running grpo main (https://meta.wandb.io/torchforge/grpo-training/runs/r8q88uie)
ptal @joecummings

LucasLLC (Contributor) left a comment:

Asyncio looks reasonable to me. A good way to measure this would be to run this code for a while, always synchronizing (finishing your tasks) on the weight-sync step, and measure the delta between doing that and waiting until the next step.
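A rough sketch of that measurement, comparing a step that synchronizes on cleanup against one that defers it to a background task (the 50 ms sleep is a stand-in for the real deletion work):

```python
import asyncio
import time

async def fake_cleanup() -> None:
    await asyncio.sleep(0.05)  # stand-in for deleting old weights

async def blocking_step() -> float:
    # Synchronize: finish the cleanup inside the measured region.
    start = time.perf_counter()
    await fake_cleanup()
    return time.perf_counter() - start

async def nonblocking_step() -> float:
    # Defer: schedule cleanup and measure only the scheduling cost.
    start = time.perf_counter()
    task = asyncio.create_task(fake_cleanup())
    elapsed = time.perf_counter() - start
    await task  # join later, outside the measured region
    return elapsed

async def main() -> tuple:
    return await blocking_step(), await nonblocking_step()

blocking, nonblocking = asyncio.run(main())
print(blocking > nonblocking)
```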

```python
start_time = time.perf_counter()
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
# TODO: once we have something like `get_meta()` in torchstore, we can just
```
Contributor:

We do have a `get_meta` in torchstore (although it's lacking a proper object).

```python
matching_keys = await ts.keys(prefix)
# TODO: once we have something like `get_meta()` in torchstore, we can just
# query the type of the object instead of relying on keys.
dcp_key = get_dcp_whole_state_dict_key(version)
```
Contributor:

Is this implementation specific to DCP?

Do we need something like `ts.delete(r"key.*")` support in torchstore?

casteryh (Contributor Author):

> Is this implementation specific to DCP?

Yes.

> Do we need something like `ts.delete(r"key.*")` support in torchstore?

It would be good if we could have it, although currently it is not a bottleneck to simply call delete on every key.
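A sketch of the per-key deletion being discussed, with an in-memory dict standing in for the store (only the `keys`/`delete` calls mirror the PR's torchstore usage; the mock class and key scheme are assumptions for illustration):

```python
import asyncio

class InMemoryStore:
    """Stand-in for the torchstore client (a mock for illustration)."""

    def __init__(self):
        self._data = {}

    async def put(self, key: str, value) -> None:
        self._data[key] = value

    async def keys(self, prefix: str) -> list:
        return [k for k in self._data if k.startswith(prefix)]

    async def delete(self, key: str) -> None:
        del self._data[key]

async def drop_weights(ts: InMemoryStore, version: int) -> int:
    # Until something like a regex-based ts.delete exists, list the
    # matching keys and delete them one by one.
    prefix = f"policy_ver_{version:010d}"  # hypothetical key scheme
    matching_keys = await ts.keys(prefix)
    for key in matching_keys:
        await ts.delete(key)
    return len(matching_keys)

async def main() -> tuple:
    ts = InMemoryStore()
    await ts.put("policy_ver_0000000001/layer.0", b"...")
    await ts.put("policy_ver_0000000001/layer.1", b"...")
    await ts.put("policy_ver_0000000002/layer.0", b"...")
    deleted = await drop_weights(ts, 1)
    return deleted, await ts.keys("policy_ver")

deleted, remaining = asyncio.run(main())
print(deleted, remaining)
```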

```python
if training_step >= 2:
    await drop_weights(training_step - 1)
    t.step("drop_weights")
# weight cleanup is non-blocking, the task is executed in the background
```
Member:

How are you confirming this all finishes before adding more weights?

Also in typical async form this step would just be an async method that you'd await now or later. Why is there an extra method called "wait"?

casteryh (Contributor Author), Oct 7, 2025:

> How are you confirming this all finishes before adding more weights?

I thought the point is that you don't, if you just need the weights to be eventually deleted. When you call step(), the task is scheduled in the background and everything else proceeds as normal.

> Also in typical async form this step would just be an async method that you'd await now or later.

Yes, but in that case, if we want to schedule the task in the background and not await it, we need to manage the task in main.py, which we supposedly don't want to do. This essentially hides the task-scheduling logic in the WeightCleaner class.

> Why is there an extra method called "wait"?

If you want to make sure all the scheduled tasks have indeed completed (i.e. all old weights are deleted, like you mentioned earlier), you can await weight_cleaner.wait(). Presumably this can be named better; let me know what you think.

casteryh (Contributor Author):

> Also in typical async form this step would just be an async method that you'd await now or later. Why is there an extra method called "wait"?

My understanding is that in typical async code, a coroutine never gets executed unless you either await it or explicitly create a task from it. I think we could also always schedule the task and return a join handle.
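Under those assumptions, a minimal sketch of the WeightCleaner shape being discussed — step() fires off a background task, wait() joins whatever is still outstanding (the method bodies here are guesses for illustration, not the PR's actual code):

```python
import asyncio

async def drop_weights(version: int, log: list) -> None:
    # Placeholder for the real deletion coroutine.
    await asyncio.sleep(0)
    log.append(version)

class WeightCleaner:
    """Hides the task-scheduling logic, as described in the thread."""

    def __init__(self, log: list):
        self._log = log
        self._tasks: set = set()

    def step(self, training_step: int) -> None:
        # Fire-and-forget: schedule cleanup of the previous version
        # without blocking the training loop.
        if training_step >= 2:
            task = asyncio.create_task(drop_weights(training_step - 1, self._log))
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)

    async def wait(self) -> None:
        # Join all outstanding cleanup tasks (e.g. before shutdown).
        await asyncio.gather(*self._tasks)

async def main() -> list:
    log: list = []
    cleaner = WeightCleaner(log)
    for training_step in range(5):
        cleaner.step(training_step)
    await cleaner.wait()
    return sorted(log)

result = asyncio.run(main())
print(result)
```

Callers that don't care about completion never touch wait(); callers that do (tests, shutdown paths) get a single join point.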

Member:

Move to core app/

casteryh (Contributor Author):

Everything, or only the WeightCleaner? Trainer and policy both need functions in this file.

casteryh (Contributor Author):
ptal @joecummings

@casteryh casteryh requested a review from Jack-Khuu October 16, 2025 22:58
4 participants