1 change: 0 additions & 1 deletion thunder/dynamo/compiler.py
@@ -29,7 +29,6 @@
 if TYPE_CHECKING:
     from typing import Any
     from thunder.dynamo.utils import SubgraphInfo
-    from thunder.core.transform_common import Transform
     from thunder.core.trace import TraceCtx as Trace
     from os import PathLike
     from collections.abc import Callable
18 changes: 16 additions & 2 deletions thunder/dynamo/splitter.py
@@ -2,6 +2,7 @@
 from typing import TYPE_CHECKING
 import copy
 from functools import partial
+import itertools
 
 import torch
 from torch.fx.passes.split_module import split_module
@@ -25,6 +26,19 @@
     from collections.abc import Callable
 
 
+def _copy_without_tensors(module: torch.nn.Module) -> torch.nn.Module:
+    """Clone a ``torch.nn.Module`` while sharing parameter and buffer tensors.
+
+    ``copy.deepcopy`` on a ``GraphModule`` duplicates all parameters and buffers,
+    which can be extremely costly for large models. By populating the ``memo``
+    argument with those tensors, we ensure the cloned module reuses the original
+    storage, dramatically reducing memory overhead during splitting.
+    """
+
+    memo = {id(t): t for t in itertools.chain(module.parameters(), module.buffers())}
+    return copy.deepcopy(module, memo)
+
+
 def _splitter(
     gm: torch.fx.GraphModule,
     thunder_jit: Callable,
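
For readers unfamiliar with the ``memo`` argument of ``copy.deepcopy``, here is a minimal sketch (not part of the diff) of the sharing behavior the new helper relies on; a plain ``torch.nn.Linear`` stands in for a split ``GraphModule``:

import copy
import itertools

import torch

module = torch.nn.Linear(4, 4)
# Pre-seed the memo so deepcopy treats every parameter/buffer as "already copied".
memo = {id(t): t for t in itertools.chain(module.parameters(), module.buffers())}
clone = copy.deepcopy(module, memo)

assert clone is not module            # a new module object is created...
assert clone.weight is module.weight  # ...but its tensors are shared, not duplicated
assert clone.bias is module.bias
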
@@ -156,8 +170,8 @@ def callback(node) -> int:
         split_gm.graph.output(())
 
     # If split_gm contains Parameters or Tensors then deepcopy would also create their copies.
-    # TODO: Eliminate deepcopy
-    original_split_gm = copy.deepcopy(split_gm)
+    # To avoid duplicating model weights, perform a deepcopy that shares tensors.
+    original_split_gm = _copy_without_tensors(split_gm)
 
     def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
         return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions
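
A design note, as we read it (the diff does not spell this out): seeding ``memo`` pins only the tensors, so the FX graph itself is still genuinely copied. That is what lets the splitter keep ``original_split_gm`` pristine while ``split_gm`` is subsequently mutated. A rough sketch of both properties, assuming the helper lands in ``thunder.dynamo.splitter`` as shown above; ``Tiny`` is a made-up example module:

import torch

from thunder.dynamo.splitter import _copy_without_tensors  # helper added in this diff

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.lin(x)

gm = torch.fx.symbolic_trace(Tiny())
pristine = _copy_without_tensors(gm)

assert pristine.lin.weight is gm.lin.weight  # parameters are shared: no extra memory
assert pristine.graph is not gm.graph        # the graph is an independent copy
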
1 change: 0 additions & 1 deletion thunder/tests/test_transforms.py
@@ -588,7 +588,6 @@ def test_disable_params_and_buffer_check():

 @pytest.mark.parametrize("dynamic", (False, True))
 def test_disable_params_check_thunderfx(dynamic: bool):
-    from thunder.transforms.extraction_only_prologue_transform import ExtractionOnlyPrologueTransform
     from thunder.dynamo import thunderfx
 
     class Model(torch.nn.Module):