Convert activation checkpointing tag to eager checkpointing function #2538
base: main
@@ -16,7 +16,6 @@
     get_nodes_in_unsupported_ctx_regions,
     update_node_and_submodule,
     recompile_graph,
-    checkpoint_converter,
     _get_example_inputs_from_placeholder,
     _ThunderSplitGraphModule,
 )
@@ -186,8 +185,6 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
             partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders
         )
         example_input_metadatas.append(list(example_input_metadata))
-        # Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
-        checkpoint_converter(split_gm, graph_module)

         jit_fn = thunder_jit(graph_module, is_differentiable_outputs=is_differentiable_outputs)
         # Update the node name from "submod_*" to "thunder_*" for more user-friendly names
@@ -12,6 +12,8 @@
 import torch
 from torch.nn.modules.module import _addindent
 from torch.utils.weak import TensorWeakRef
+import torch.utils.checkpoint
+

 if torch.distributed.is_available():
     from torch.distributed.tensor import DTensor
@@ -119,9 +121,8 @@ class SubgraphInfo:

     Attributes:
         original_graph_module: The original graph module.
-        original_split_graph_module: The original split graph module before any transformations are applied.
-            Specifically, before the :func:`checkpoint_converter` replaces the Torch operators with Thunder symbols,
-            and before any submodules are compiled by Thunder.
+        original_split_graph_module: The original split graph module before any transformations are applied,
+            i.e., before any submodules are compiled by Thunder.
         split_graph_module: The graph module for the split subgraph. It contains the compiled thunder/inductor modules.
         thunder_compiled_fns: List of thunder optimized callables.
             This could be :obj:`None` if the graph module was not supported by thunder.
@@ -420,7 +421,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
         return False, split_reason

     # The higher order function must be fully supported by Thunder
-    if target in (torch.ops.higher_order.tag_activation_checkpoint, torch.ops.higher_order.autograd_function_apply):
+    if target in (torch.utils.checkpoint.checkpoint, torch.ops.higher_order.autograd_function_apply):
         m = node.graph.owning_module
         for arg_node in node.args:
             if arg_node.op == "get_attr":
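The hunk above only shows the start of that check. As a rough, hypothetical sketch of the idea (the helper name body_is_fully_supported and the is_supported predicate are made up here, not Thunder's API): the wrapped body of the higher-order call arrives as a get_attr argument, so it is looked up on the owning module and every node inside it must itself be supported.

    import torch.fx


    def body_is_fully_supported(node: torch.fx.Node, is_supported) -> bool:
        # The checkpointed/wrapped body is passed as a get_attr argument that
        # points at a submodule stored on the graph's owning module.
        owner = node.graph.owning_module
        for arg_node in node.args:
            if isinstance(arg_node, torch.fx.Node) and arg_node.op == "get_attr":
                body = getattr(owner, arg_node.target)
                if not all(is_supported(n) for n in body.graph.nodes):
                    return False
        return True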
@@ -630,57 +631,15 @@ def _get_example_inputs_from_placeholder(
     return example_input_meta_to_input(example_value)


-def _checkpoint_function_converter(gm: torch.fx.GraphModule):
+def convert_checkpoint_tags(gm: torch.fx.GraphModule):
     """
-    Replace PyTorch operators in ``gm`` representing a checkpointed function with corresponding Thunder operators. The input ``gm`` is modified inplace.
+    Replaces tag_activation_checkpoint operators with torch.utils.checkpoint.checkpoint calls.

-    Args:
-        gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified inplace.
+    tag_activation_checkpoint only marks nodes for the torch.compile stack but does not execute actual checkpointing in eager mode.

[Review comment] It would be nice to mention that this function mutates the input ``gm`` in place.

     """
     for n in gm.graph.nodes:
-        # replace the torch operator in "call_function" node
-        if n.op == "call_function":
-            assert isinstance(n.target, Callable)
-            if n.target.__module__ in ("_operator", "builtins"):
-                continue
-            check(
-                n.target in _torch_to_thunder_function_map, lambda: f"Unexpected {n.target}, not registered in Thunder"
-            )
-            with gm.graph.inserting_before(n):
-                thunder_node = gm.graph.call_function(
-                    _torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs
-                )
-                n.replace_all_uses_with(thunder_node)
-                gm.graph.erase_node(n)
-        else:
-            if n.op == "call_module":
-                raise RuntimeError(
-                    "Unexpected call_module detected inside a checkpoint. This should have been inlined in dynamo graphs"
-                )
-    gm.graph.lint()
-    recompile_graph(gm)
-
-
-def checkpoint_converter(gm: torch.fx.GraphModule, sub_gm: torch.fx.GraphModule):
-    """
-    Utility function to convert the GraphModule that uses activation checkpointing into a Thunder-traceable GraphModule.
-
-    Args:
-        gm: The parent GraphModule containing the submodule (sub_gm), as well as the GraphModule of the checkpointed function.
-        sub_gm: the GraphModule containing the checkpoint operator
-
-    Note:
-        The GraphModule of the checkpointed function is updated inplace
-    """
-    for n in sub_gm.graph.nodes:
-        if n.op == "call_function":
-            if n.target in (torch.ops.higher_order.tag_activation_checkpoint,):
-                checkpoint_target_node = n.args[0]
-                if checkpoint_target_node.op == "get_attr":
-                    function_module = getattr(checkpoint_target_node.graph.owning_module, checkpoint_target_node.target)
-                else:
-                    function_module = getattr(gm, n.args[0].name)
-                _checkpoint_function_converter(function_module)
+        if n.op == "call_function" and n.target is torch.ops.higher_order.tag_activation_checkpoint:
+            n.target = torch.utils.checkpoint.checkpoint


 def remove_empty_autocast(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
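To see what the new pass does end to end, here is a minimal, hypothetical sketch (not part of this diff; demo_backend and f are made-up names) that applies the same retargeting inside a standalone TorchDynamo backend. Unlike the function added in this PR, it has to call recompile() itself, because nothing downstream regenerates the forward for it.

    import torch
    import torch.utils.checkpoint


    def demo_backend(gm: torch.fx.GraphModule, example_inputs):
        # Same retargeting as convert_checkpoint_tags: the tag node, which only
        # marks the region under torch.compile, becomes a call to the eager
        # checkpoint API, so the wrapped body is actually recomputed in backward.
        for n in gm.graph.nodes:
            if n.op == "call_function" and n.target is torch.ops.higher_order.tag_activation_checkpoint:
                n.target = torch.utils.checkpoint.checkpoint
        # Needed here because we run gm directly; the PR can skip it since the
        # graph is recompiled when it is split afterwards.
        gm.recompile()
        return gm


    def f(x):
        return torch.utils.checkpoint.checkpoint(torch.sin, x, use_reentrant=False)


    x = torch.randn(4, requires_grad=True)
    out = torch.compile(f, backend=demo_backend)(x)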
@@ -844,6 +844,37 @@ def find_target_module(model, target_module_name):
         assert isinstance(n.target, Symbol) or callable(n.target)


+@requiresCUDA
+@pytest.mark.parametrize("op", [torch.sin, torch.sinc])
+def test_checkpoint_memory_use(op):
+    import torch.utils.checkpoint as checkpoint
+
+    def fn(x):
+        return op(op(op(op(x))))
+
+    def checkpoint_fn(x):
+        return checkpoint.checkpoint(fn, x, use_reentrant=False)
+
+    initial_mem = torch.cuda.memory_allocated()
+
+    x = torch.randn((1024 // 4, 1024, 1024), device="cuda", requires_grad=True)
+
+    jfn = thunderfx(checkpoint_fn)
+    y = jfn(x)
+
+    peak_mem_usage = torch.cuda.max_memory_allocated() - initial_mem
+
+    y_ref = fn(x)
+    torch.testing.assert_close(y, y_ref)
+
+    if op == torch.sin:
+        assert peak_mem_usage == x.nbytes * 2
+    else:
+        assert peak_mem_usage == x.nbytes * 3
+        # Make sure the checkpointed region fell back to PyTorch
+        sinfo = jfn._backend.subgraph_infos[-1]
+        assert any(n.name.startswith("inductor") for n in sinfo.split_graph_module.graph.nodes)


 @instantiate(
     dtypes=NOTHING,
     executors=[DynamoThunderExecutor],
[Review comment] Recompiling here was added in commit 0338afe when we did not have the graph splitting logic. Now we break the graph down in the subsequent code, so there is no need for recompile.
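For readers less familiar with FX, a small self-contained illustration of that point (toy module, not from this PR): mutating a node's target does not change the already generated forward until recompile() runs, which here happens as part of the later graph splitting.

    import torch
    import torch.fx


    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.sin(x)


    gm = torch.fx.symbolic_trace(Toy())
    for n in gm.graph.nodes:
        if n.op == "call_function" and n.target is torch.sin:
            n.target = torch.cos  # retarget, as convert_checkpoint_tags does for the checkpoint tag
    print(gm.code)  # the generated code still calls torch.sin
    gm.recompile()
    print(gm.code)  # now it calls torch.cos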