Convert activation checkpointing tag with eager checkpointing function #2538
base: main
Conversation
This fixes #2501.
^ happening on each rank in this PR (ac5508a)
# Dynamo uses lazy generation of the underlying Python code, so we need to
# force recompilation of the GraphModule before passing it to Thunder.
recompile_graph(gm)
Recompiling here was added in commit 0338afe, when we did not have the graph splitting logic. Now we break the graph down in the subsequent code, so there is no need to recompile.
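For context, here is a minimal sketch of what forcing regeneration of a GraphModule's code looks like, assuming recompile_graph essentially boils down to gm.recompile() (the toy function and the relu-to-sigmoid swap are illustrative only, not code from this PR):

```python
import torch
import torch.fx

def f(x):
    return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(f)

# Mutate the graph: swap relu for sigmoid (purely illustrative).
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.relu:
        node.target = torch.sigmoid

# gm.forward still reflects the old graph until the code is regenerated.
gm.recompile()
print(gm.code)  # the generated forward now calls torch.sigmoid
```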
)
example_input_metadatas.append(list(example_input_metadata))
# Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
checkpoint_converter(split_gm, graph_module)
torch.utils.checkpoint.checkpoint is Thunder-traceable.
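A quick sanity check of that claim might look like the following (a minimal sketch; the use of thunder.jit as the entry point and the toy checkpointed function are my assumptions, not code from this PR):

```python
import torch
import torch.utils.checkpoint as checkpoint
import thunder

def body(t):
    return torch.sin(t).cos()

def fn(x):
    # Checkpointed region: activations are recomputed during backward.
    return checkpoint.checkpoint(body, x, use_reentrant=False)

jfn = thunder.jit(fn)
x = torch.randn(8, 8, requires_grad=True)
jfn(x).sum().backward()
```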
thank you @shino16
thunder/tests/test_dynamo.py
initial_mem = torch.cuda.memory_allocated()

x = torch.randn((1024 // 4, 1024, 1024), device="cuda", requires_grad=True)
It would be better to use a smaller input, as these tests run in parallel.
Ah that's a good point, thank you!
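For reference, a hedged sketch of what the memory measurement could look like with a smaller input (the shape and the checkpointed function here are placeholders, not the test's actual code):

```python
import torch
import torch.utils.checkpoint

def fn(x):
    return torch.utils.checkpoint.checkpoint(
        lambda t: torch.sin(t).cos().exp(), x, use_reentrant=False
    )

torch.cuda.reset_peak_memory_stats()
initial_mem = torch.cuda.memory_allocated()

# Much smaller than (1024 // 4, 1024, 1024), but still large enough to see the
# effect of recomputation without hogging a shared CI GPU.
x = torch.randn((16, 1024, 1024), device="cuda", requires_grad=True)
fn(x).sum().backward()

peak_above_baseline = torch.cuda.max_memory_allocated() - initial_mem
print(f"peak memory above baseline: {peak_above_baseline / 2**20:.1f} MiB")
```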
Args:
    gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified in place.

tag_activation_checkpoint only marks nodes for the torch.compile stack but does not execute actual checkpointing in eager mode.
It would be nice to mention that this function mutates the gm.
Thank you @shino16 for the fix. There are two tests, test_checkpoint_converter and test_checkpoint_converter_submodule, that test the old converter, but I think we can keep them, as they also seem to validate the functionality of convert_checkpoint_tags.
This is a workaround for #2527. torch.ops.higher_order.tag_activation_checkpoint does not perform activation checkpointing when run in eager mode, so we convert it back to torch.utils.checkpoint.checkpoint.
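For illustration, a minimal sketch of what such a conversion pass over an FX graph could look like (the function name, the kwargs handling, and the use_reentrant default are my assumptions, not the PR's exact implementation):

```python
import torch
import torch.fx
import torch.utils.checkpoint

def convert_checkpoint_tags_sketch(gm: torch.fx.GraphModule) -> None:
    """Mutates ``gm`` in place, replacing each tag_activation_checkpoint node
    with a call to torch.utils.checkpoint.checkpoint so that the wrapped region
    is actually recomputed during backward when the module runs eagerly."""
    for node in list(gm.graph.nodes):
        if (
            node.op == "call_function"
            and node.target is torch.ops.higher_order.tag_activation_checkpoint
        ):
            # The first argument is typically a get_attr node referring to the
            # submodule that holds the checkpointed body; it is callable, so it
            # can be handed straight to torch.utils.checkpoint.checkpoint.
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(
                    torch.utils.checkpoint.checkpoint,
                    args=node.args,
                    kwargs={"use_reentrant": False, **node.kwargs},
                )
            node.replace_all_uses_with(new_node)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()  # regenerate the GraphModule's Python code after the edit
```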