Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
save_ckpt_format (str): Distributed checkpoint format to use for checkpoint saving. Should be one of
'torch_dist' or 'zarr'. Defaults to 'torch_dist'.
ckpt_async_save (bool): Whether to save checkpoints asynchronously to reduce checkpointing overhead.
Defaults to True.
Defaults to False.
ckpt_torch_dist_multiproc (int): Number of extra processes per rank used during ckpt save
with PyTorch distributed format. Defaults to None.
ckpt_assume_constant_structure (bool): Allows caching some computation across checkpoint saves.
Expand Down Expand Up @@ -266,7 +266,7 @@ def __init__(
use_te_rng_tracker: bool = False,
use_sharp: bool = False,
save_ckpt_format: str = "torch_dist",
ckpt_async_save: bool = True,
ckpt_async_save: bool = False,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why False by default?

Copy link
Copy Markdown
Collaborator Author

@ananthsub ananthsub Aug 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one concern now is persistent workers are consuming more memory, risking OOMs if async save is the default. that being said, all the recipes explicitly set this to True so until the memory regression is resolved, making sync save the default is the safer option

ckpt_torch_dist_multiproc: int = None, ## TODO(ashors): put elsewhere?
ckpt_assume_constant_structure: bool = False,
ckpt_parallel_save: bool = True,
Expand Down
6 changes: 4 additions & 2 deletions nemo/lightning/pytorch/strategies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,10 @@ def create_checkpoint_io(wrapping_ckpt_io=None, **kwargs):

if wrapping_ckpt_io:
checkpoint_io = wrapping_ckpt_io(checkpoint_io)
if kwargs.get("async_save", False):
checkpoint_io = AsyncFinalizableCheckpointIO(checkpoint_io)

async_save = kwargs.get("async_save", False)
if async_save:
checkpoint_io = AsyncFinalizableCheckpointIO(checkpoint_io, persistent_workers=True)

return checkpoint_io

Expand Down
5 changes: 3 additions & 2 deletions nemo/utils/callbacks/dist_ckpt_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,18 @@ class AsyncFinalizableCheckpointIO(_WrappingCheckpointIO):
Args:
checkpoint_io (CheckpointIO): wrapped checkpoint_io object. Must be
of type AsyncCompatibleCheckpointIO.
persistent_workers (bool): whether to use persistent workers for checkpoint writing. Defaults to False.
Requires the underlying checkpoint_io.save_checkpoint to return save_fn, save_args, finalize_fn.
"""

def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO) -> None:
def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO, persistent_workers: bool = False) -> None:
Copy link
Copy Markdown
Contributor

@pramodk pramodk Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to double check in a clean installation but I was trying to quickly apply this patch/changeset to my branch. When using persistent_workers=True with save_last=True in ModelCheckpoint, my test was stuck with the below stack trace:

Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/megatron-lm/megatron/core/dist_checkpointing/strategies/async_utils.py", line 448, in async_loop
    item = queue.get()
           ^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/queues.py", line 103, in get
    res = self._recv_bytes()
          ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 430, in _recv_bytes
    buf = self._recv(4)
          ^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 395, in _recv
    chunk = read(handle, remaining)
            ^^^^^^^^^^^^^^^^^^^^^^^

Possible that my envirorment has some issues but just mentioning here so that you can double check.

Copy link
Copy Markdown
Contributor

@pramodk pramodk Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm here: I tested this branch/PR with the nemo 25.09 and it hangs at the end of training when using async checkpointing + persistent workers. cc: @maanug-nv

if not HAVE_MEGATRON_CORE:
raise ImportError(IMPORT_ERROR)
if not isinstance(checkpoint_io, AsyncCompatibleCheckpointIO):
raise ValueError(f'Incompatible wrapped checkpoint_io type: {type(checkpoint_io)}')

super().__init__(checkpoint_io)
self.async_calls_queue = AsyncCallsQueue()
self.async_calls_queue = AsyncCallsQueue(persistent=persistent_workers)

def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
"""Executes async request returned from the underlying checkpoint_io asynchronously.
Expand Down
Loading