Summary

This PR replaces a full `copy.deepcopy` of the split `GraphModule` with a lightweight clone that shares `Parameter` and buffer tensors with the original module. This avoids duplicating model weights during graph splitting and significantly reduces peak memory usage.

Motivation

`copy.deepcopy(split_gm)` duplicates every `Parameter` and registered buffer. For large models this can momentarily double memory usage and trigger OOMs. The splitter only needs a structural snapshot of the module; it does not mutate weights. Sharing tensor storage is therefore safe and much more memory-efficient.
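As a minimal illustration of the problem (a toy `nn.Linear`, not code from this PR), a plain deepcopy allocates fresh storage for every weight:

```python
import copy

import torch.nn as nn

original = nn.Linear(4096, 4096)
duplicate = copy.deepcopy(original)

# deepcopy allocated a second 4096x4096 weight tensor: the two
# Parameters live at different data pointers, so memory is doubled
# for as long as both modules are alive.
assert original.weight.data_ptr() != duplicate.weight.data_ptr()
```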

What’s changed

- Added `_copy_without_tensors(module: nn.Module) -> nn.Module`, which uses `copy.deepcopy` with a pre-populated memo so that all Parameters and registered buffers are reused rather than copied.
- Replaced `copy.deepcopy(split_gm)` with `_copy_without_tensors(split_gm)` when capturing `original_split_gm`.
- Minor: imported `itertools` (for `chain`) and improved inline comments.

Implementation notes

The helper clones the Python/FX structure while mapping each tensor id in `parameters()` and `buffers()` back to the original object:

```python
memo = {id(t): t for t in itertools.chain(module.parameters(), module.buffers())}
clone = copy.deepcopy(module, memo)
```
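For reference, here is a self-contained sketch of the helper end-to-end. The toy traced module and the identity check are illustrative, not part of the PR; the key property is that `copy.deepcopy` returns any object already present in the memo instead of copying it:

```python
import copy
import itertools

import torch.fx as fx
import torch.nn as nn


def _copy_without_tensors(module: nn.Module) -> nn.Module:
    """Deep-copy the module's structure while aliasing its Parameters and buffers."""
    memo = {id(t): t for t in itertools.chain(module.parameters(), module.buffers())}
    return copy.deepcopy(module, memo)


# Illustrative input: any traced GraphModule works the same way.
gm = fx.symbolic_trace(nn.Sequential(nn.Linear(8, 8), nn.ReLU()))
clone = _copy_without_tensors(gm)

for p_orig, p_clone in zip(gm.parameters(), clone.parameters()):
    # Same object, not an equal copy: no weight memory was duplicated.
    assert p_orig is p_clone
```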

Behavior of the splitter is unchanged; only memory characteristics improve.
