# Use persistent workers if async saving is enabled #14465
```diff
@@ -98,17 +98,18 @@ class AsyncFinalizableCheckpointIO(_WrappingCheckpointIO):
     Args:
         checkpoint_io (CheckpointIO): wrapped checkpoint_io object. Must be
             of type AsyncCompatibleCheckpointIO.
+        persistent_workers (bool): whether to use persistent workers for checkpoint writing. Defaults to False.
+            Requires the underlying checkpoint_io.save_checkpoint to return save_fn, save_args, finalize_fn.
     """

-    def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO) -> None:
+    def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO, persistent_workers: bool = False) -> None:
```
**Contributor:** I need to double-check in a clean installation, but I was trying to quickly apply this patch/changeset to my branch. When using it, I got:

```
Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/megatron-lm/megatron/core/dist_checkpointing/strategies/async_utils.py", line 448, in async_loop
    item = queue.get()
           ^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/queues.py", line 103, in get
    res = self._recv_bytes()
          ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 430, in _recv_bytes
    buf = self._recv(4)
          ^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 395, in _recv
    chunk = read(handle, remaining)
            ^^^^^^^^^^^^^^^^^^^^^^^
```

It is possible that my environment has some issues, but I am mentioning it here so that you can double-check.
**Contributor:** Just to confirm here: I tested this branch/PR with NeMo 25.09 and it hangs at the end of training when using async checkpointing + persistent workers. cc: @maanug-nv
```diff
     if not HAVE_MEGATRON_CORE:
         raise ImportError(IMPORT_ERROR)
     if not isinstance(checkpoint_io, AsyncCompatibleCheckpointIO):
         raise ValueError(f'Incompatible wrapped checkpoint_io type: {type(checkpoint_io)}')

     super().__init__(checkpoint_io)
-    self.async_calls_queue = AsyncCallsQueue()
+    self.async_calls_queue = AsyncCallsQueue(persistent=persistent_workers)

 def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
     """Executes async request returned from the underlying checkpoint_io asynchronously.
```
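The diff above routes the new flag into `AsyncCallsQueue(persistent=...)`. A rough, self-contained sketch of the trade-off that flag controls — this is not the Megatron implementation (it uses threads rather than processes, and every name besides `AsyncCallsQueue`/`persistent` is invented for illustration): persistent mode keeps one long-lived worker fed by a queue, while the default spawns a short-lived worker per save.

```python
import queue
import threading

_SENTINEL = object()  # tells the persistent worker to shut down

class AsyncCallsQueue:
    """Toy stand-in for Megatron's AsyncCallsQueue (threads, not processes)."""

    def __init__(self, persistent: bool = False):
        self.persistent = persistent
        self._queue: queue.Queue = queue.Queue()
        self._transient: list = []
        self._worker = None
        if persistent:
            # One long-lived worker: avoids per-save worker startup cost.
            self._worker = threading.Thread(target=self._loop, daemon=True)
            self._worker.start()

    def _loop(self):
        # Persistent worker: block on the queue between save requests.
        while True:
            item = self._queue.get()
            if item is _SENTINEL:
                break
            fn, args = item
            fn(*args)

    def schedule_async_call(self, fn, *args):
        if self.persistent:
            self._queue.put((fn, args))
        else:
            # Default mode: a fresh short-lived worker per call.
            t = threading.Thread(target=fn, args=args)
            t.start()
            self._transient.append(t)

    def close(self):
        # Forgetting this sentinel is how a persistent worker can hang
        # at the end of training, blocked inside queue.get().
        if self.persistent:
            self._queue.put(_SENTINEL)
            self._worker.join()
        for t in self._transient:
            t.join()

saved = []
q = AsyncCallsQueue(persistent=True)
q.schedule_async_call(saved.append, "ckpt-step-100")
q.schedule_async_call(saved.append, "ckpt-step-200")
q.close()
```

The single persistent worker processes requests in FIFO order, so after `close()` both saves have completed; a real process-based worker also keeps its memory allocated between saves, which is the regression discussed below.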
**Reviewer:** Why `False` by default?
**Contributor:** One concern right now is that persistent workers consume more memory, risking OOMs if they were enabled whenever async save is on. That said, all the recipes explicitly set this to `True`, so until the memory regression is resolved, keeping the non-persistent path as the default is the safer option.