Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions megatron/core/dist_checkpointing/strategies/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ class TemporalAsyncCaller(AsyncCaller):
def __init__(self):
    """Set up an idle caller: no worker process, no timing, no staged data."""
    # Wall-clock timestamp of the most recent fork; used for duration logging.
    self.start_time: Optional[float] = None
    # Handle to the forked async-checkpoint worker process, if one is running.
    self.process: Optional[mp.Process] = None
    # Holds the result of async_req.preload_fn() so the staged (host-side)
    # tensors stay referenced until close() releases them.
    self.preloaded_holder = None

@_disable_gc()
def schedule_async_call(self, async_req: AsyncRequest) -> None:
Expand All @@ -264,6 +265,7 @@ def schedule_async_call(self, async_req: AsyncRequest) -> None:
# to do the defined action in `async_req.preload_fn` to
# stage GPU tensors to its defined destination
async_fn_args[1] = async_req.preload_fn()
self.preloaded_holder = async_fn_args[1]

rank = torch.distributed.get_rank()
start_sync = time()
Expand Down Expand Up @@ -339,6 +341,7 @@ def close(self, abort=False):
f"after {time() - self.start_time:.2f}s from forking"
)
self.start_time = None
self.preloaded_holder = None

def __del__(self):
    """Intentionally a no-op finalizer; explicit teardown happens elsewhere."""
Expand Down
Loading