-
Notifications
You must be signed in to change notification settings - Fork 107
Convert activation checkpointing tag with eager checkpointing function #2538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
shino16
wants to merge
7
commits into
main
Choose a base branch
from
inductor-checkpoint
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
5fc411f
Replace tag_activation_checkpoint with actual checkpointer
shino16 f3e80b3
Add test
shino16 49f813b
Improved comments
shino16 5bb50af
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
shino16 35297a7
Resolve review: better comments
shino16 a1db659
Resolve review: small tensor for tests
shino16 71ee95e
Merge branch 'main' into inductor-checkpoint
t-vi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,6 @@ | |
get_nodes_in_unsupported_ctx_regions, | ||
update_node_and_submodule, | ||
recompile_graph, | ||
checkpoint_converter, | ||
_get_example_inputs_from_placeholder, | ||
_ThunderSplitGraphModule, | ||
) | ||
|
@@ -186,8 +185,6 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool: | |
partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders | ||
) | ||
example_input_metadatas.append(list(example_input_metadata)) | ||
# Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators | ||
checkpoint_converter(split_gm, graph_module) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
jit_fn = thunder_jit(graph_module, is_differentiable_outputs=is_differentiable_outputs) | ||
# Update the node name from "submod_*" to "thunder_*" for more user-friendly names | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,8 @@ | |
import torch | ||
from torch.nn.modules.module import _addindent | ||
from torch.utils.weak import TensorWeakRef | ||
import torch.utils.checkpoint | ||
|
||
|
||
if torch.distributed.is_available(): | ||
from torch.distributed.tensor import DTensor | ||
|
@@ -119,9 +121,8 @@ class SubgraphInfo: | |
|
||
Attributes: | ||
original_graph_module: The original graph module. | ||
original_split_graph_module: The original split graph module before any transformations are applied. | ||
Specifically, before the :func:`checkpoint_converter` replaces the Torch operators with Thunder symbols, | ||
and before any submodules are compiled by Thunder. | ||
original_split_graph_module: The original split graph module before any transformations are applied, | ||
before any submodules are compiled by Thunder. | ||
split_graph_module: The graph module for the split subgraph. It contains the compiled thunder/inductor modules. | ||
thunder_compiled_fns: List of thunder optimized callables. | ||
This could be :obj:`None` if there the graph module was not supported by thunder. | ||
|
@@ -420,7 +421,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | |
return False, split_reason | ||
|
||
# The higher order function must be fully supported by Thunder | ||
if target in (torch.ops.higher_order.tag_activation_checkpoint, torch.ops.higher_order.autograd_function_apply): | ||
if target in (torch.utils.checkpoint.checkpoint, torch.ops.higher_order.autograd_function_apply): | ||
m = node.graph.owning_module | ||
for arg_node in node.args: | ||
if arg_node.op == "get_attr": | ||
|
@@ -630,57 +631,15 @@ def _get_example_inputs_from_placeholder( | |
return example_input_meta_to_input(example_value) | ||
|
||
|
||
def _checkpoint_function_converter(gm: torch.fx.GraphModule): | ||
def convert_checkpoint_tags(gm: torch.fx.GraphModule): | ||
""" | ||
Replace PyTorch operators in ``gm`` representing a checkpointed function with corresponding Thunder operators. The input ``gm`` is modified inplace. | ||
Replaces tag_activation_checkpoint operators in-place with torch.utils.checkpoint.checkpoint calls. | ||
|
||
Args: | ||
gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified inplace. | ||
tag_activation_checkpoint only marks nodes for torch.compile stack but does not execute actual checkpointing in eager mode. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to mention that this function mutates the |
||
""" | ||
for n in gm.graph.nodes: | ||
# replace the torch operator in "call_function" node | ||
if n.op == "call_function": | ||
assert isinstance(n.target, Callable) | ||
if n.target.__module__ in ("_operator", "builtins"): | ||
continue | ||
check( | ||
n.target in _torch_to_thunder_function_map, lambda: f"Unexpected {n.target}, not registered in Thunder" | ||
) | ||
with gm.graph.inserting_before(n): | ||
thunder_node = gm.graph.call_function( | ||
_torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs | ||
) | ||
n.replace_all_uses_with(thunder_node) | ||
gm.graph.erase_node(n) | ||
else: | ||
if n.op == "call_module": | ||
raise RuntimeError( | ||
"Unexpected call_module detected inside a checkpoint. This should have been inlined in dynamo graphs" | ||
) | ||
gm.graph.lint() | ||
recompile_graph(gm) | ||
|
||
|
||
def checkpoint_converter(gm: torch.fx.GraphModule, sub_gm: torch.fx.GraphModule): | ||
""" | ||
Utility function to convert the GraphModule that uses activation checkpointing into a Thunder-traceable GraphModule. | ||
|
||
Args: | ||
gm: The parent GraphModule containing the submodule(sub_gm), as well as the GraphModule of the checkpointed function. | ||
sub_gm: the GraphModule containing the checkpoint operator | ||
|
||
Note: | ||
The GraphModule of the checkpointed function is updated inplace | ||
""" | ||
for n in sub_gm.graph.nodes: | ||
if n.op == "call_function": | ||
if n.target in (torch.ops.higher_order.tag_activation_checkpoint,): | ||
checkpoint_target_node = n.args[0] | ||
if checkpoint_target_node.op == "get_attr": | ||
function_module = getattr(checkpoint_target_node.graph.owning_module, checkpoint_target_node.target) | ||
else: | ||
function_module = getattr(gm, n.args[0].name) | ||
_checkpoint_function_converter(function_module) | ||
if n.op == "call_function" and n.target is torch.ops.higher_order.tag_activation_checkpoint: | ||
n.target = torch.utils.checkpoint.checkpoint | ||
|
||
|
||
def remove_empty_autocast(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Recompiling here was added in commit 0338afe when we did not have the graph splitting logic. Now we break the graph down in the subsequent code, so no need for recompile.