diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py
index 2ee582bba0..b01bffe07f 100644
--- a/thunder/dynamo/compiler.py
+++ b/thunder/dynamo/compiler.py
@@ -11,7 +11,6 @@
 from torch.utils import _pytree as torch_pytree
 
 from thunder.dynamo.utils import (
-    recompile_graph,
     remove_empty_autocast,
     CompilerType,
     get_split_reasons_string,
@@ -21,6 +20,7 @@
     default_filter,
     default_optimizer,
     input_to_example_input_meta,
+    convert_checkpoint_tags,
 )
 from thunder.dynamo.splitter import _splitter
 from thunder.dynamo.benchmark_utils import ThunderCompileSpecification
@@ -135,9 +135,10 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
 
         remove_empty_autocast(gm)
 
-        # Dynamo uses lazy generation of the underlying Python code, so we need to
-        # force recompilation of the GraphModule before passing it to Thunder.
-        recompile_graph(gm)
+        # Convert tag_activation_checkpoint operators, which are meaningless in eager mode, to actual checkpoint calls.
+        # This will not be needed once we find a way to make tag_activation_checkpoint fall back to PyTorch's backend.
+        # See https://github.com/Lightning-AI/lightning-thunder/issues/2539
+        convert_checkpoint_tags(gm)
 
         # The whole graph may not be supported by `thunder`, so we split it in `thunder` supported sections
         # and unsupported sections which are passed to `torch.compile(backend='inductor')`
diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index 7949937ca5..d1d09d291a 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -16,7 +16,6 @@
     get_nodes_in_unsupported_ctx_regions,
     update_node_and_submodule,
     recompile_graph,
-    checkpoint_converter,
     _get_example_inputs_from_placeholder,
     _ThunderSplitGraphModule,
 )
@@ -186,8 +185,6 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
                 partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders
             )
             example_input_metadatas.append(list(example_input_metadata))
-            # Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
-            checkpoint_converter(split_gm, graph_module)
             jit_fn = thunder_jit(graph_module, is_differentiable_outputs=is_differentiable_outputs)
 
             # Update the node name from "submod_*" to "thunder_*" for more user-friendly names
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index 3c58bc2ded..abb567cb50 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -12,6 +12,8 @@
 import torch
 from torch.nn.modules.module import _addindent
 from torch.utils.weak import TensorWeakRef
+import torch.utils.checkpoint
+
 
 if torch.distributed.is_available():
     from torch.distributed.tensor import DTensor
@@ -119,9 +121,8 @@ class SubgraphInfo:
 
     Attributes:
         original_graph_module: The original graph module.
-        original_split_graph_module: The original split graph module before any transformations are applied.
-            Specifically, before the :func:`checkpoint_converter` replaces the Torch operators with Thunder symbols,
-            and before any submodules are compiled by Thunder.
+        original_split_graph_module: The original split graph module before any transformations are applied,
+            i.e., before any submodules are compiled by Thunder.
         split_graph_module: The graph module for the split subgraph. It contains the compiled thunder/inductor modules.
         thunder_compiled_fns: List of thunder optimized callables. This could be :obj:`None` if there the graph module was not supported by thunder.
@@ -420,7 +421,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
         return False, split_reason
 
     # The higher order function must be fully supported by Thunder
-    if target in (torch.ops.higher_order.tag_activation_checkpoint, torch.ops.higher_order.autograd_function_apply):
+    if target in (torch.utils.checkpoint.checkpoint, torch.ops.higher_order.autograd_function_apply):
         m = node.graph.owning_module
         for arg_node in node.args:
             if arg_node.op == "get_attr":
@@ -630,57 +631,15 @@
     return example_input_meta_to_input(example_value)
 
 
-def _checkpoint_function_converter(gm: torch.fx.GraphModule):
+def convert_checkpoint_tags(gm: torch.fx.GraphModule):
     """
-    Replace PyTorch operators in ``gm`` representing a checkpointed function with corresponding Thunder operators. The input ``gm`` is modified inplace.
+    Replaces tag_activation_checkpoint operators in-place with torch.utils.checkpoint.checkpoint calls.
 
-    Args:
-        gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified inplace.
+    tag_activation_checkpoint only marks nodes for the torch.compile stack; it does not perform actual checkpointing in eager mode.
     """
     for n in gm.graph.nodes:
-        # replace the torch operator in "call_function" node
-        if n.op == "call_function":
-            assert isinstance(n.target, Callable)
-            if n.target.__module__ in ("_operator", "builtins"):
-                continue
-            check(
-                n.target in _torch_to_thunder_function_map, lambda: f"Unexpected {n.target}, not registered in Thunder"
-            )
-            with gm.graph.inserting_before(n):
-                thunder_node = gm.graph.call_function(
-                    _torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs
-                )
-            n.replace_all_uses_with(thunder_node)
-            gm.graph.erase_node(n)
-        else:
-            if n.op == "call_module":
-                raise RuntimeError(
-                    "Unexpected call_module detected inside a checkpoint. This should have been inlined in dynamo graphs"
-                )
-    gm.graph.lint()
-    recompile_graph(gm)
-
-
-def checkpoint_converter(gm: torch.fx.GraphModule, sub_gm: torch.fx.GraphModule):
-    """
-    Utility function to convert the GraphModule that uses activation checkpointing into a Thunder-traceable GraphModule.
-
-    Args:
-        gm: The parent GraphModule containing the submodule(sub_gm), as well as the GraphModule of the checkpointed function.
-        sub_gm: the GraphModule containing the checkpoint operator
-
-    Note:
-        The GraphModule of the checkpointed function is updated inplace
-    """
-    for n in sub_gm.graph.nodes:
-        if n.op == "call_function":
-            if n.target in (torch.ops.higher_order.tag_activation_checkpoint,):
-                checkpoint_target_node = n.args[0]
-                if checkpoint_target_node.op == "get_attr":
-                    function_module = getattr(checkpoint_target_node.graph.owning_module, checkpoint_target_node.target)
-                else:
-                    function_module = getattr(gm, n.args[0].name)
-                _checkpoint_function_converter(function_module)
+        if n.op == "call_function" and n.target is torch.ops.higher_order.tag_activation_checkpoint:
+            n.target = torch.utils.checkpoint.checkpoint
 
 
 def remove_empty_autocast(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 1a90e2fe89..134648e09c 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -844,6 +844,37 @@ def find_target_module(model, target_module_name):
         assert isinstance(n.target, Symbol) or callable(n.target)
 
 
+@requiresCUDA
+@pytest.mark.parametrize("op", [torch.sin, torch.sinc])
+def test_checkpoint_memory_use(op):
+    import torch.utils.checkpoint as checkpoint
+
+    def fn(x):
+        return op(op(op(op(x))))
+
+    def checkpoint_fn(x):
+        return checkpoint.checkpoint(fn, x, use_reentrant=False)
+
+    initial_mem = torch.cuda.memory_allocated()
+
+    x = torch.randn((128, 128), device="cuda", requires_grad=True)
+    jfn = thunderfx(checkpoint_fn)
+    y = jfn(x)
+
+    peak_mem_usage = torch.cuda.max_memory_allocated() - initial_mem
+
+    y_ref = fn(x)
+    torch.testing.assert_close(y, y_ref)
+
+    if op == torch.sin:
+        assert peak_mem_usage == x.nbytes * 2
+    else:
+        assert peak_mem_usage == x.nbytes * 3
+    # Make sure the checkpointed region fell back to PyTorch
+    sinfo = jfn._backend.subgraph_infos[-1]
+    assert any(n.name.startswith("inductor") for n in sinfo.split_graph_module.graph.nodes)
+
+
 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
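
For reviewers, a minimal sketch (not part of the diff) of what the new conversion does when driven from a custom Dynamo backend. `show_conversion` and `fn` below are made-up names for illustration; `convert_checkpoint_tags` is the helper added to `thunder/dynamo/utils.py` above, and the sketch assumes, as this PR does, that Dynamo captures `torch.utils.checkpoint.checkpoint` as `torch.ops.higher_order.tag_activation_checkpoint`.

```python
import torch
import torch.utils.checkpoint as checkpoint

from thunder.dynamo.utils import convert_checkpoint_tags


def show_conversion(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo captures the checkpointed call as torch.ops.higher_order.tag_activation_checkpoint,
    # which is only a marker for the torch.compile stack.
    print("before:", [str(n.target) for n in gm.graph.nodes if n.op == "call_function"])
    convert_checkpoint_tags(gm)  # retarget the marker to torch.utils.checkpoint.checkpoint in place
    print("after: ", [str(n.target) for n in gm.graph.nodes if n.op == "call_function"])
    gm.recompile()  # regenerate gm.forward so the retargeted call is what actually runs
    return gm


def fn(x):
    return checkpoint.checkpoint(torch.sin, x, use_reentrant=False)


x = torch.randn(8, requires_grad=True)
torch.compile(fn, backend=show_conversion)(x)
```

Retargeting the node, rather than rewriting the checkpointed submodule into Thunder symbols as `checkpoint_converter` used to do, keeps the fallback path honest: `torch.utils.checkpoint.checkpoint` actually recomputes activations when the region runs eagerly or through Inductor, whereas the bare `tag_activation_checkpoint` marker would silently skip checkpointing outside the torch.compile stack.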