@shino16 (Collaborator) commented Sep 25, 2025

This is a workaround for #2527. torch.ops.higher_order.tag_activation_checkpoint does not perform activation checkpointing when run in eager mode, so we convert it back to torch.utils.checkpoint.checkpoint.
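For readers unfamiliar with the mechanics, here is a minimal, illustrative sketch of what such a conversion could look like on a Dynamo-produced FX graph. This is not the PR's actual implementation; the function name and the use_reentrant=False choice are assumptions made for the example.

```python
import torch
from torch.fx import GraphModule


def convert_tag_nodes_sketch(gm: GraphModule) -> None:
    """Illustrative only: rewrite tag_activation_checkpoint call nodes so the
    graph calls torch.utils.checkpoint.checkpoint, which does recompute
    activations when the GraphModule runs eagerly. Mutates gm in place."""
    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and node.target is torch.ops.higher_order.tag_activation_checkpoint
        ):
            # The higher-order op's first argument is the checkpointed
            # sub-GraphModule and the rest are its runtime inputs, so the call
            # already lines up with checkpoint(fn, *args).
            node.target = torch.utils.checkpoint.checkpoint
            node.kwargs = {**node.kwargs, "use_reentrant": False}
    gm.recompile()
```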

@shino16 (Collaborator, Author) commented Sep 25, 2025

This fixes #2501.

on main (280c57e)

[rank2]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB. GPU 2 has a total capacity of 139.72 GiB of which 3.71 GiB is free. Including non-PyTorch memory, this process has 0 bytes memory in use. Of the allocated memory 127.71 GiB is allocated by PyTorch, and 6.76 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

^ happening on each rank

in this PR (ac5508a)

Model name: Gemma-2-27b
Seq Length: 8192
Micro BS: 1
Global BS: 8
Number of Layers: 46
Number of parameters: 3.55B
Distributed Mode: fsdp
Sharding Mode: zero3
Bucketing: block
Compiler: dynamo_thunder
Low Precision Mode: none
Average iter time: 9991.12 ms
Memory used: 81.76 GB
Tokens/s: 6555.90
Tokens/s/GPU: 819.49
TFLOP/s: 1192.38


# Dynamo uses lazy generation of the underlying Python code, so we need to
# force recompilation of the GraphModule before passing it to Thunder.
recompile_graph(gm)
@shino16 (Collaborator, Author) commented:
Recompiling here was added in commit 0338afe, when we did not yet have the graph-splitting logic. Now we break the graph down in the subsequent code, so the recompile is no longer needed.
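As a side note, here is a toy illustration of what recompilation does for a mutated torch.fx.GraphModule, independent of Thunder's recompile_graph helper: the Python generated for forward goes stale after graph edits until recompile() regenerates it.

```python
import torch
from torch import fx


def f(x):
    return torch.relu(x) + 1


gm = fx.symbolic_trace(f)

# Swap relu for gelu directly in the graph; gm.code / gm.forward still reflect
# the old graph until we regenerate them.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.relu:
        node.target = torch.nn.functional.gelu

gm.recompile()  # regenerate the Python code from the mutated graph
print(gm.code)
```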

)
example_input_metadatas.append(list(example_input_metadata))
# Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
checkpoint_converter(split_gm, graph_module)
@shino16 (Collaborator, Author) commented:
torch.utils.checkpoint.checkpoint is Thunder-traceable.
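A minimal sketch of relying on that traceability; whether this exact pattern runs end to end depends on the Thunder version, so treat it as an assumption rather than a guaranteed API contract.

```python
import torch
from torch.utils.checkpoint import checkpoint

import thunder


def fn(x, w):
    def block(t, w):
        return torch.nn.functional.gelu(t @ w)

    # Activations of `block` are recomputed during backward instead of stored.
    return checkpoint(block, x, w, use_reentrant=False)


jfn = thunder.jit(fn)
x = torch.randn(16, 16, requires_grad=True)
w = torch.randn(16, 16, requires_grad=True)
jfn(x, w).sum().backward()
```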

@shino16 force-pushed the inductor-checkpoint branch from 2bdfeae to 49f813b on September 26, 2025 at 11:44
@shino16 marked this pull request as ready for review on September 26, 2025 at 11:46
@KaelanDt (Collaborator) left a comment:
thank you @shino16


initial_mem = torch.cuda.memory_allocated()

x = torch.randn((1024 // 4, 1024, 1024), device="cuda", requires_grad=True)
A reviewer (Collaborator) commented:
It would be better to use a smaller input, as these tests run in parallel.
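(For scale: that tensor is (1024 // 4) × 1024 × 1024 = 268,435,456 float32 elements, i.e. exactly 1 GiB, before counting gradients or saved activations, so shrinking it meaningfully reduces pressure when tests share a GPU.)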

@shino16 (Collaborator, Author) replied:
Ah that's a good point, thank you!

Args:
gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified in place.
tag_activation_checkpoint only marks nodes for the torch.compile stack but does not perform actual checkpointing in eager mode.
A reviewer (Collaborator) commented:
It would be nice to mention that this function mutates the gm.

@kshitij12345 (Collaborator) left a comment:
LGTM, thanks @shino16.

Let's also wait for a review from @kiya00.

@kiya00 (Collaborator) left a comment:

Thank you @shino16 for the fix. There are two tests, test_checkpoint_converter and test_checkpoint_converter_submodule, that exercise the old converter, but I think we can keep them, as they also seem to validate the functionality of convert_checkpoint_tags.
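For context, here is a hypothetical, self-contained snippet (not one of the repository's tests) showing how such tests can obtain a graph containing the tag node: capture what Dynamo hands to the backend, then assert on the node targets before and after conversion.

```python
import torch
from torch.utils.checkpoint import checkpoint

captured = []


def capture_backend(gm, example_inputs):
    # Stash the FX graph that Dynamo produced and run it unchanged.
    captured.append(gm)
    return gm


def fn(x):
    return checkpoint(lambda t: torch.sin(t) * 2, x, use_reentrant=False)


torch.compile(fn, backend=capture_backend)(torch.randn(4, requires_grad=True))

# Dynamo represents the checkpointed region with the tag HOP; a converter test
# can assert these nodes exist here and are gone after the conversion runs.
assert any(
    n.op == "call_function"
    and n.target is torch.ops.higher_order.tag_activation_checkpoint
    for n in captured[0].graph.nodes
)
```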
