import collections
from enum import Enum
from typing import Callable, Optional

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper)

# Attribute names under which supported models expose their transformer blocks.
TRANSFORMER_BLOCK_NAMES = [
    "blocks",
    "double_blocks",
    "single_blocks",
    "transformer_blocks",
    "temporal_transformer_blocks",
    "transformer_double_blocks",
    "transformer_single_blocks",
]


class CheckpointType(str, Enum):
    FULL = "full"
    OPS = "ops"
    BLOCK_SKIP = "block_skip"


# Ops whose outputs are saved (rather than recomputed) under selective
# activation checkpointing.
_SELECTIVE_ACTIVATION_CHECKPOINTING_OPS = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
}

def apply_activation_checkpointing(
        module: torch.nn.Module,
        checkpointing_type: str = CheckpointType.FULL,
        n_layer: int = 1) -> torch.nn.Module:
    """Apply activation checkpointing to ``module``.

    "full" wraps every transformer block, "ops" selectively saves only the
    outputs of a small allowlist of ops, and "block_skip" wraps every
    n_layer-th transformer block.
    """
    if checkpointing_type == CheckpointType.FULL:
        module = _apply_activation_checkpointing_blocks(module)
    elif checkpointing_type == CheckpointType.OPS:
        module = _apply_activation_checkpointing_ops(
            module, _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS)
    elif checkpointing_type == CheckpointType.BLOCK_SKIP:
        module = _apply_activation_checkpointing_blocks(module, n_layer)
    else:
        raise ValueError(
            f"Checkpointing type '{checkpointing_type}' is not supported. "
            f"Supported types are {[t.value for t in CheckpointType]}.")
    return module


def _apply_activation_checkpointing_blocks(
        module: torch.nn.Module,
        n_layer: Optional[int] = None) -> torch.nn.Module:
    """Wrap transformer blocks with activation checkpointing.

    If ``n_layer`` is given, only every n_layer-th block is wrapped;
    otherwise all blocks are wrapped.
    """
    for transformer_block_name in TRANSFORMER_BLOCK_NAMES:
        blocks: Optional[torch.nn.Module] = getattr(
            module, transformer_block_name, None)
        if blocks is None:
            continue
        for index, (layer_id, block) in enumerate(blocks.named_children()):
            if n_layer is None or index % n_layer == 0:
                block = checkpoint_wrapper(block, preserve_rng_state=False)
                blocks.register_module(layer_id, block)
    return module


def _apply_activation_checkpointing_ops(module: torch.nn.Module,
                                        ops) -> torch.nn.Module:
    # Imported locally; these selective-checkpointing APIs require a recent
    # PyTorch release.
    from torch.utils.checkpoint import (CheckpointPolicy,
                                        create_selective_checkpoint_contexts)

    def _get_custom_policy(meta: dict[str, int]) -> Callable:

        def _custom_policy(ctx, func, *args, **kwargs):
            mode = "recompute" if ctx.is_recompute else "forward"
            mm_count_key = f"{mode}_mm_count"
            if func == torch.ops.aten.mm.default:
                meta[mm_count_key] += 1
            # Save the outputs of the allowlisted ops, except every second mm,
            # which is recomputed instead.
            to_save = func in ops and not (func == torch.ops.aten.mm.default
                                           and meta[mm_count_key] % 2 == 0)
            return (CheckpointPolicy.MUST_SAVE
                    if to_save else CheckpointPolicy.PREFER_RECOMPUTE)

        return _custom_policy

    def selective_checkpointing_context_fn():
        meta: dict[str, int] = collections.defaultdict(int)
        return create_selective_checkpoint_contexts(_get_custom_policy(meta))

    return checkpoint_wrapper(module,
                              context_fn=selective_checkpointing_context_fn,
                              preserve_rng_state=False)
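
For reference, a minimal usage sketch (not part of this commit): ToyModel and its transformer_blocks attribute are illustrative assumptions, and the sketch assumes apply_activation_checkpointing from the module above is in scope.

import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical model exposing its blocks under a recognized attribute name."""

    def __init__(self, dim: int = 64, depth: int = 4) -> None:
        super().__init__()
        self.transformer_blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.GELU())
            for _ in range(depth))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.transformer_blocks:
            x = block(x)
        return x


model = ToyModel()
# "block_skip" with n_layer=2 wraps every second block; "full" wraps all
# blocks, "ops" applies selective op-level checkpointing instead.
model = apply_activation_checkpointing(model, "block_skip", n_layer=2)
out = model(torch.randn(8, 64, requires_grad=True))
out.sum().backward()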