diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py
index fd4c194d91..2ee582bba0 100644
--- a/thunder/dynamo/compiler.py
+++ b/thunder/dynamo/compiler.py
@@ -29,7 +29,6 @@ if TYPE_CHECKING:
     from typing import Any
 
     from thunder.dynamo.utils import SubgraphInfo
-    from thunder.core.transform_common import Transform
     from thunder.core.trace import TraceCtx as Trace
     from os import PathLike
     from collections.abc import Callable
diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index 7949937ca5..7e1e39533d 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING
 import copy
 from functools import partial
+import itertools
 
 import torch
 from torch.fx.passes.split_module import split_module
 
@@ -25,6 +26,19 @@
     from collections.abc import Callable
 
 
+def _copy_without_tensors(module: torch.nn.Module) -> torch.nn.Module:
+    """Clone a ``torch.nn.Module`` while sharing parameter and buffer tensors.
+
+    ``copy.deepcopy`` on a ``GraphModule`` duplicates all parameters and buffers,
+    which can be extremely costly for large models. By populating the ``memo``
+    argument with those tensors, we ensure the cloned module reuses the original
+    storage, dramatically reducing memory overhead during splitting.
+    """
+
+    memo = {id(t): t for t in itertools.chain(module.parameters(), module.buffers())}
+    return copy.deepcopy(module, memo)
+
+
 def _splitter(
     gm: torch.fx.GraphModule,
     thunder_jit: Callable,
@@ -156,8 +170,8 @@ def callback(node) -> int:
         split_gm.graph.output(())
 
     # If split_gm contains Parameters or Tensors then deepcopy would also create their copies.
-    # TODO: Eliminate deepcopy
-    original_split_gm = copy.deepcopy(split_gm)
+    # To avoid duplicating model weights, perform a deepcopy that shares tensors.
+    original_split_gm = _copy_without_tensors(split_gm)
 
     def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
         return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions
diff --git a/thunder/tests/test_transforms.py b/thunder/tests/test_transforms.py
index b2dd84fd0c..d76d0a484b 100644
--- a/thunder/tests/test_transforms.py
+++ b/thunder/tests/test_transforms.py
@@ -588,7 +588,6 @@ def test_disable_params_and_buffer_check():
 
 @pytest.mark.parametrize("dynamic", (False, True))
 def test_disable_params_check_thunderfx(dynamic: bool):
-    from thunder.transforms.extraction_only_prologue_transform import ExtractionOnlyPrologueTransform
     from thunder.dynamo import thunderfx
 
     class Model(torch.nn.Module):