@@ -12,6 +12,8 @@
 import torch
 from torch.nn.modules.module import _addindent
 from torch.utils.weak import TensorWeakRef
+import torch.utils.checkpoint
+

 if torch.distributed.is_available():
     from torch.distributed.tensor import DTensor
@@ -119,9 +121,8 @@ class SubgraphInfo:

     Attributes:
         original_graph_module: The original graph module.
-        original_split_graph_module: The original split graph module before any transformations are applied.
-            Specifically, before the :func:`checkpoint_converter` replaces the Torch operators with Thunder symbols,
-            and before any submodules are compiled by Thunder.
+        original_split_graph_module: The original split graph module before any transformations are applied,
+            i.e. before any submodules are compiled by Thunder.
         split_graph_module: The graph module for the split subgraph. It contains the compiled thunder/inductor modules.
         thunder_compiled_fns: List of thunder optimized callables.
             This could be :obj:`None` if the graph module was not supported by thunder.
@@ -420,7 +421,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
         return False, split_reason

     # The higher order function must be fully supported by Thunder
-    if target in (torch.ops.higher_order.tag_activation_checkpoint, torch.ops.higher_order.autograd_function_apply):
+    if target in (torch.utils.checkpoint.checkpoint, torch.ops.higher_order.autograd_function_apply):
         m = node.graph.owning_module
         for arg_node in node.args:
             if arg_node.op == "get_attr":
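For these higher-order calls the wrapped body is not inlined into the outer graph; it is passed as a `get_attr` argument that names a submodule on the owning GraphModule, which is why the check walks `node.args` looking for `get_attr` nodes. A minimal sketch of that lookup (hypothetical helper, not code from this diff):

```python
# Hypothetical helper, not part of this diff: mirrors how the check above
# reaches the wrapped body of a higher-order call (e.g. a checkpoint call or
# autograd_function_apply) -- the body is a submodule named by a get_attr arg.
import torch.fx


def wrapped_bodies(node: torch.fx.Node):
    owner = node.graph.owning_module
    for arg in node.args:
        if isinstance(arg, torch.fx.Node) and arg.op == "get_attr":
            # get_attr targets of higher-order calls name wrapped GraphModules
            yield owner.get_submodule(arg.target)
```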
@@ -630,57 +631,15 @@ def _get_example_inputs_from_placeholder(
     return example_input_meta_to_input(example_value)


-def _checkpoint_function_converter(gm: torch.fx.GraphModule):
+def convert_checkpoint_tags(gm: torch.fx.GraphModule):
     """
-    Replace PyTorch operators in ``gm`` representing a checkpointed function with corresponding Thunder operators. The input ``gm`` is modified inplace.
+    Replace ``tag_activation_checkpoint`` operators in ``gm`` with ``torch.utils.checkpoint.checkpoint`` calls; ``gm`` is modified inplace.

-    Args:
-        gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified inplace.
+    ``tag_activation_checkpoint`` only marks nodes for recomputation within ``torch.compile``; it does not execute checkpointing itself.
     """
     for n in gm.graph.nodes:
-        # replace the torch operator in "call_function" node
-        if n.op == "call_function":
-            assert isinstance(n.target, Callable)
-            if n.target.__module__ in ("_operator", "builtins"):
-                continue
-            check(
-                n.target in _torch_to_thunder_function_map, lambda: f"Unexpected {n.target}, not registered in Thunder"
-            )
-            with gm.graph.inserting_before(n):
-                thunder_node = gm.graph.call_function(
-                    _torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs
-                )
-            n.replace_all_uses_with(thunder_node)
-            gm.graph.erase_node(n)
-        else:
-            if n.op == "call_module":
-                raise RuntimeError(
-                    "Unexpected call_module detected inside a checkpoint. This should have been inlined in dynamo graphs"
-                )
-    gm.graph.lint()
-    recompile_graph(gm)
-
-
-def checkpoint_converter(gm: torch.fx.GraphModule, sub_gm: torch.fx.GraphModule):
-    """
-    Utility function to convert the GraphModule that uses activation checkpointing into a Thunder-traceable GraphModule.
-
-    Args:
-        gm: The parent GraphModule containing the submodule(sub_gm), as well as the GraphModule of the checkpointed function.
-        sub_gm: the GraphModule containing the checkpoint operator
-
-    Note:
-        The GraphModule of the checkpointed function is updated inplace
-    """
-    for n in sub_gm.graph.nodes:
-        if n.op == "call_function":
-            if n.target in (torch.ops.higher_order.tag_activation_checkpoint,):
-                checkpoint_target_node = n.args[0]
-                if checkpoint_target_node.op == "get_attr":
-                    function_module = getattr(checkpoint_target_node.graph.owning_module, checkpoint_target_node.target)
-                else:
-                    function_module = getattr(gm, n.args[0].name)
-                _checkpoint_function_converter(function_module)
+        if n.op == "call_function" and n.target is torch.ops.higher_order.tag_activation_checkpoint:
+            n.target = torch.utils.checkpoint.checkpoint


 def remove_empty_autocast(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
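The retargeted node defers the actual recomputation to `torch.utils.checkpoint.checkpoint` at run time. Below is a minimal, self-contained sketch (not code from this diff) that hand-builds an FX graph shaped like the one Dynamo captures for a checkpointed region, applies the same retargeting loop as `convert_checkpoint_tags`, and runs the result; the `Body`/`wrap_body` names are invented for the example.

```python
# Illustrative sketch only. It hand-builds a graph with a get_attr to the
# wrapped submodule plus a tag_activation_checkpoint call, then applies the
# same retargeting as the new convert_checkpoint_tags pass.
import torch
import torch.fx as fx
import torch.utils.checkpoint


class Body(torch.nn.Module):
    # The region whose activations would be recomputed in backward.
    def forward(self, x):
        return torch.sin(x) * 2


root = torch.nn.Module()
root.wrap_body = fx.symbolic_trace(Body())  # wrapped region as a submodule

graph = fx.Graph()
x = graph.placeholder("x")
body = graph.get_attr("wrap_body")
# Dynamo emits this higher-order op; it only tags the region for recomputation.
tagged = graph.call_function(torch.ops.higher_order.tag_activation_checkpoint, (body, x))
graph.output(tagged)
gm = fx.GraphModule(root, graph)

# Same loop as the convert_checkpoint_tags pass in the diff above.
for n in gm.graph.nodes:
    if n.op == "call_function" and n.target is torch.ops.higher_order.tag_activation_checkpoint:
        n.target = torch.utils.checkpoint.checkpoint
gm.recompile()

# The generated forward now calls torch.utils.checkpoint.checkpoint(self.wrap_body, x),
# so Body really runs under activation checkpointing (PyTorch may warn that
# use_reentrant was not passed explicitly).
inp = torch.randn(4, requires_grad=True)
out = gm(inp)
out.sum().backward()  # Body is recomputed during backward
```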