
Commit 5756fd8 (parent 280c57e)

Replace tag_activation_checkpoint with actual checkpointer

File tree: 3 files changed (+13, -58 lines)


thunder/dynamo/compiler.py
Lines changed: 3 additions & 4 deletions

@@ -11,7 +11,6 @@
 from torch.utils import _pytree as torch_pytree

 from thunder.dynamo.utils import (
-    recompile_graph,
     remove_empty_autocast,
     CompilerType,
     get_split_reasons_string,
@@ -21,6 +20,7 @@
     default_filter,
     default_optimizer,
     input_to_example_input_meta,
+    convert_checkpoint_tags,
 )
 from thunder.dynamo.splitter import _splitter
 from thunder.dynamo.benchmark_utils import ThunderCompileSpecification
@@ -135,9 +135,8 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor

         remove_empty_autocast(gm)

-        # Dynamo uses lazy generation of the underlying Python code, so we need to
-        # force recompilation of the GraphModule before passing it to Thunder.
-        recompile_graph(gm)
+        # Convert tag_activation_checkpoint operators, which is merely a tagger for torch.compile stack, to actual checkpoint calls
+        convert_checkpoint_tags(gm)

         # The whole graph may not be supported by `thunder`, so we split it in `thunder` supported sections
         # and unsupported sections which are passed to `torch.compile(backend='inductor')`
thunder/dynamo/splitter.py
Lines changed: 0 additions & 3 deletions

@@ -16,7 +16,6 @@
     get_nodes_in_unsupported_ctx_regions,
     update_node_and_submodule,
     recompile_graph,
-    checkpoint_converter,
     _get_example_inputs_from_placeholder,
     _ThunderSplitGraphModule,
 )
@@ -186,8 +185,6 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
                 partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders
             )
             example_input_metadatas.append(list(example_input_metadata))
-            # Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
-            checkpoint_converter(split_gm, graph_module)

             jit_fn = thunder_jit(graph_module, is_differentiable_outputs=is_differentiable_outputs)
             # Update the node name from "submod_*" to "thunder_*" for more user-friendly names
thunder/dynamo/utils.py
Lines changed: 10 additions & 51 deletions

@@ -12,6 +12,8 @@
 import torch
 from torch.nn.modules.module import _addindent
 from torch.utils.weak import TensorWeakRef
+import torch.utils.checkpoint
+

 if torch.distributed.is_available():
     from torch.distributed.tensor import DTensor
@@ -119,9 +121,8 @@ class SubgraphInfo:

     Attributes:
         original_graph_module: The original graph module.
-        original_split_graph_module: The original split graph module before any transformations are applied.
-            Specifically, before the :func:`checkpoint_converter` replaces the Torch operators with Thunder symbols,
-            and before any submodules are compiled by Thunder.
+        original_split_graph_module: The original split graph module before any transformations are applied,
+            before any submodules are compiled by Thunder.
         split_graph_module: The graph module for the split subgraph. It contains the compiled thunder/inductor modules.
         thunder_compiled_fns: List of thunder optimized callables.
             This could be :obj:`None` if there the graph module was not supported by thunder.
@@ -420,7 +421,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
         return False, split_reason

     # The higher order function must be fully supported by Thunder
-    if target in (torch.ops.higher_order.tag_activation_checkpoint, torch.ops.higher_order.autograd_function_apply):
+    if target in (torch.utils.checkpoint.checkpoint, torch.ops.higher_order.autograd_function_apply):
         m = node.graph.owning_module
         for arg_node in node.args:
             if arg_node.op == "get_attr":
@@ -630,57 +631,15 @@ def _get_example_inputs_from_placeholder(
     return example_input_meta_to_input(example_value)


-def _checkpoint_function_converter(gm: torch.fx.GraphModule):
+def convert_checkpoint_tags(gm: torch.fx.GraphModule):
     """
-    Replace PyTorch operators in ``gm`` representing a checkpointed function with corresponding Thunder operators. The input ``gm`` is modified inplace.
+    Replaces tag_activation_checkpoint operators with torch.utils.checkpoint.checkpoint calls.

-    Args:
-        gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified inplace.
+    tag_activation_checkpoint only marks nodes for recomputation within torch.compile but does not execute checkpointing itself.
     """
     for n in gm.graph.nodes:
-        # replace the torch operator in "call_function" node
-        if n.op == "call_function":
-            assert isinstance(n.target, Callable)
-            if n.target.__module__ in ("_operator", "builtins"):
-                continue
-            check(
-                n.target in _torch_to_thunder_function_map, lambda: f"Unexpected {n.target}, not registered in Thunder"
-            )
-            with gm.graph.inserting_before(n):
-                thunder_node = gm.graph.call_function(
-                    _torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs
-                )
-                n.replace_all_uses_with(thunder_node)
-                gm.graph.erase_node(n)
-        else:
-            if n.op == "call_module":
-                raise RuntimeError(
-                    "Unexpected call_module detected inside a checkpoint. This should have been inlined in dynamo graphs"
-                )
-    gm.graph.lint()
-    recompile_graph(gm)
-
-
-def checkpoint_converter(gm: torch.fx.GraphModule, sub_gm: torch.fx.GraphModule):
-    """
-    Utility function to convert the GraphModule that uses activation checkpointing into a Thunder-traceable GraphModule.
-
-    Args:
-        gm: The parent GraphModule containing the submodule(sub_gm), as well as the GraphModule of the checkpointed function.
-        sub_gm: the GraphModule containing the checkpoint operator
-
-    Note:
-        The GraphModule of the checkpointed function is updated inplace
-    """
-    for n in sub_gm.graph.nodes:
-        if n.op == "call_function":
-            if n.target in (torch.ops.higher_order.tag_activation_checkpoint,):
-                checkpoint_target_node = n.args[0]
-                if checkpoint_target_node.op == "get_attr":
-                    function_module = getattr(checkpoint_target_node.graph.owning_module, checkpoint_target_node.target)
-                else:
-                    function_module = getattr(gm, n.args[0].name)
-                _checkpoint_function_converter(function_module)
+        if n.op == "call_function" and n.target is torch.ops.higher_order.tag_activation_checkpoint:
+            n.target = torch.utils.checkpoint.checkpoint


 def remove_empty_autocast(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
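
Editorial note, not part of the commit: the one-line retarget above appears to suffice because the captured graph already calls the tag with the checkpointed submodule followed by its inputs, which matches the torch.utils.checkpoint.checkpoint(function, *args) calling convention, and is_node_supported_by_thunder is updated to look for the retargeted callable. The sketch below only illustrates what an actual checkpointer buys over a tag: identical forward values and gradients, with the region's activations recomputed during backward instead of being stored. The body helper is hypothetical.

# Standalone sketch; only uses public torch.utils.checkpoint APIs.
import torch
import torch.utils.checkpoint


def body(x, w):
    # Hypothetical checkpointed region: one matmul + relu, reduced to a scalar
    return torch.relu(x @ w).sum()


x = torch.randn(16, 32)
w = torch.randn(32, 32, requires_grad=True)

plain = body(x, w)           # activations of the region are kept for backward
plain.backward()
grad_plain, w.grad = w.grad.clone(), None

ckpt = torch.utils.checkpoint.checkpoint(body, x, w, use_reentrant=False)
ckpt.backward()              # the region is re-run here to rebuild activations

print(torch.allclose(plain, ckpt))         # True
print(torch.allclose(grad_plain, w.grad))  # True
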
