9 changes: 5 additions & 4 deletions thunder/dynamo/compiler.py
@@ -11,7 +11,6 @@
from torch.utils import _pytree as torch_pytree

from thunder.dynamo.utils import (
recompile_graph,
remove_empty_autocast,
CompilerType,
get_split_reasons_string,
@@ -21,6 +20,7 @@
default_filter,
default_optimizer,
input_to_example_input_meta,
convert_checkpoint_tags,
)
from thunder.dynamo.splitter import _splitter
from thunder.dynamo.benchmark_utils import ThunderCompileSpecification
@@ -135,9 +135,10 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor

remove_empty_autocast(gm)

# Dynamo uses lazy generation of the underlying Python code, so we need to
# force recompilation of the GraphModule before passing it to Thunder.
recompile_graph(gm)
Collaborator Author: Recompiling here was added in commit 0338afe, when we did not have the graph-splitting logic. Now we break the graph down in the subsequent code, so the recompile is no longer needed.

# Convert tag_activation_checkpoint operators, which are meaningless in eager mode, to actual checkpoint calls
# This will not be needed once we have found a way to make tag_activation_checkpoint fall back to PyTorch's backend
# See https://github.com/Lightning-AI/lightning-thunder/issues/2539
convert_checkpoint_tags(gm)

# The whole graph may not be supported by `thunder`, so we split it in `thunder` supported sections
# and unsupported sections which are passed to `torch.compile(backend='inductor')`
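To illustrate the idea behind the new `convert_checkpoint_tags(gm)` call, here is a minimal, self-contained sketch (my illustration, not code from this PR) of retargeting `torch.ops.higher_order.tag_activation_checkpoint` nodes to `torch.utils.checkpoint.checkpoint` in a custom Dynamo backend. The helper names `retag_checkpoints` and `inspect_backend` are made up for the example.

```python
import torch
import torch.utils.checkpoint


def retag_checkpoints(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Retarget the Dynamo-only tag to the real eager checkpoint call,
    # mirroring what convert_checkpoint_tags does above.
    for n in gm.graph.nodes:
        if n.op == "call_function" and n.target is torch.ops.higher_order.tag_activation_checkpoint:
            n.target = torch.utils.checkpoint.checkpoint
    gm.recompile()  # regenerate the module's Python code so the new targets are used
    return gm


def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # A toy backend: rewrite the tags, then run the graph eagerly.
    return retag_checkpoints(gm)


def fn(x):
    return torch.utils.checkpoint.checkpoint(torch.sin, x, use_reentrant=False)


cfn = torch.compile(fn, backend=inspect_backend)
out = cfn(torch.randn(4, requires_grad=True))
```

Without the retargeting, the tag is meaningless outside the torch.compile stack, so the wrapped body would run without any actual checkpointing.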
3 changes: 0 additions & 3 deletions thunder/dynamo/splitter.py
@@ -16,7 +16,6 @@
get_nodes_in_unsupported_ctx_regions,
update_node_and_submodule,
recompile_graph,
checkpoint_converter,
_get_example_inputs_from_placeholder,
_ThunderSplitGraphModule,
)
@@ -186,8 +185,6 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders
)
example_input_metadatas.append(list(example_input_metadata))
# Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
checkpoint_converter(split_gm, graph_module)
Collaborator Author: torch.utils.checkpoint.checkpoint is Thunder-traceable.


jit_fn = thunder_jit(graph_module, is_differentiable_outputs=is_differentiable_outputs)
# Update the node name from "submod_*" to "thunder_*" for more user-friendly names
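The comment above notes that torch.utils.checkpoint.checkpoint is Thunder-traceable, which is why the separate checkpoint_converter pass can be dropped here. A minimal sketch of what that claim implies (my example, not part of this PR), assuming `thunder.jit` can be applied directly to a function that calls `torch.utils.checkpoint.checkpoint`:

```python
import torch
import torch.utils.checkpoint as checkpoint

import thunder


def fn(x):
    # The checkpointed region saves fewer activations and recomputes them
    # during the backward pass.
    return checkpoint.checkpoint(torch.sin, x, use_reentrant=False)


jfn = thunder.jit(fn)
x = torch.randn(4, requires_grad=True)
y = jfn(x)
y.sum().backward()
```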
61 changes: 10 additions & 51 deletions thunder/dynamo/utils.py
@@ -12,6 +12,8 @@
import torch
from torch.nn.modules.module import _addindent
from torch.utils.weak import TensorWeakRef
import torch.utils.checkpoint


if torch.distributed.is_available():
from torch.distributed.tensor import DTensor
@@ -119,9 +121,8 @@ class SubgraphInfo:

Attributes:
original_graph_module: The original graph module.
original_split_graph_module: The original split graph module before any transformations are applied.
Specifically, before the :func:`checkpoint_converter` replaces the Torch operators with Thunder symbols,
and before any submodules are compiled by Thunder.
original_split_graph_module: The original split graph module before any transformations are applied,
before any submodules are compiled by Thunder.
split_graph_module: The graph module for the split subgraph. It contains the compiled thunder/inductor modules.
thunder_compiled_fns: List of thunder optimized callables.
This could be :obj:`None` if the graph module was not supported by thunder.
@@ -420,7 +421,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
return False, split_reason

# The higher order function must be fully supported by Thunder
if target in (torch.ops.higher_order.tag_activation_checkpoint, torch.ops.higher_order.autograd_function_apply):
if target in (torch.utils.checkpoint.checkpoint, torch.ops.higher_order.autograd_function_apply):
m = node.graph.owning_module
for arg_node in node.args:
if arg_node.op == "get_attr":
@@ -630,57 +631,15 @@ def _get_example_inputs_from_placeholder(
return example_input_meta_to_input(example_value)


def _checkpoint_function_converter(gm: torch.fx.GraphModule):
def convert_checkpoint_tags(gm: torch.fx.GraphModule):
"""
Replace PyTorch operators in ``gm`` representing a checkpointed function with corresponding Thunder operators. The input ``gm`` is modified inplace.
Replaces tag_activation_checkpoint operators in-place with torch.utils.checkpoint.checkpoint calls.

Args:
gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified inplace.
tag_activation_checkpoint only marks nodes for the torch.compile stack but does not execute actual checkpointing in eager mode.
Collaborator: It would be nice to mention that this function mutates the gm.

"""
for n in gm.graph.nodes:
# replace the torch operator in "call_function" node
if n.op == "call_function":
assert isinstance(n.target, Callable)
if n.target.__module__ in ("_operator", "builtins"):
continue
check(
n.target in _torch_to_thunder_function_map, lambda: f"Unexpected {n.target}, not registered in Thunder"
)
with gm.graph.inserting_before(n):
thunder_node = gm.graph.call_function(
_torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs
)
n.replace_all_uses_with(thunder_node)
gm.graph.erase_node(n)
else:
if n.op == "call_module":
raise RuntimeError(
"Unexpected call_module detected inside a checkpoint. This should have been inlined in dynamo graphs"
)
gm.graph.lint()
recompile_graph(gm)


def checkpoint_converter(gm: torch.fx.GraphModule, sub_gm: torch.fx.GraphModule):
"""
Utility function to convert the GraphModule that uses activation checkpointing into a Thunder-traceable GraphModule.

Args:
gm: The parent GraphModule containing the submodule(sub_gm), as well as the GraphModule of the checkpointed function.
sub_gm: the GraphModule containing the checkpoint operator

Note:
The GraphModule of the checkpointed function is updated inplace
"""
for n in sub_gm.graph.nodes:
if n.op == "call_function":
if n.target in (torch.ops.higher_order.tag_activation_checkpoint,):
checkpoint_target_node = n.args[0]
if checkpoint_target_node.op == "get_attr":
function_module = getattr(checkpoint_target_node.graph.owning_module, checkpoint_target_node.target)
else:
function_module = getattr(gm, n.args[0].name)
_checkpoint_function_converter(function_module)
if n.op == "call_function" and n.target is torch.ops.higher_order.tag_activation_checkpoint:
n.target = torch.utils.checkpoint.checkpoint


def remove_empty_autocast(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
31 changes: 31 additions & 0 deletions thunder/tests/test_dynamo.py
@@ -844,6 +844,37 @@ def find_target_module(model, target_module_name):
assert isinstance(n.target, Symbol) or callable(n.target)


@requiresCUDA
@pytest.mark.parametrize("op", [torch.sin, torch.sinc])
def test_checkpoint_memory_use(op):
import torch.utils.checkpoint as checkpoint

def fn(x):
return op(op(op(op(x))))

def checkpoint_fn(x):
return checkpoint.checkpoint(fn, x, use_reentrant=False)

initial_mem = torch.cuda.memory_allocated()

x = torch.randn((128, 128), device="cuda", requires_grad=True)
jfn = thunderfx(checkpoint_fn)
y = jfn(x)

peak_mem_usage = torch.cuda.max_memory_allocated() - initial_mem

y_ref = fn(x)
torch.testing.assert_close(y, y_ref)

if op == torch.sin:
assert peak_mem_usage == x.nbytes * 2
else:
assert peak_mem_usage == x.nbytes * 3
# Make sure the checkpointed region fell back to PyTorch
sinfo = jfn._backend.subgraph_infos[-1]
assert any(n.name.startswith("inductor") for n in sinfo.split_graph_module.graph.nodes)


@instantiate(
dtypes=NOTHING,
executors=[DynamoThunderExecutor],