
Commit eb0f131

[Feat] activation checkpointing (#584)
1 parent ce9b591 commit eb0f131

File tree

3 files changed: +114, -9 lines changed

fastvideo/v1/fastvideo_args.py
fastvideo/v1/training/activation_checkpoint.py (new)
fastvideo/v1/training/training_pipeline.py

fastvideo/v1/fastvideo_args.py

Lines changed: 6 additions & 0 deletions

@@ -414,6 +414,7 @@ class TrainingArgs(FastVideoArgs):
     lr_warmup_steps: int = 0
     max_grad_norm: float = 0.0
     gradient_checkpointing: bool = False
+    gradient_checkpointing_type: str = "full"
     selective_checkpointing: float = 0.0
     allow_tf32: bool = False
     mixed_precision: str = ""
@@ -615,6 +616,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument("--gradient-checkpointing",
                             action=StoreBoolean,
                             help="Whether to use gradient checkpointing")
+        parser.add_argument("--gradient-checkpointing-type",
+                            type=str,
+                            choices=["full", "ops", "block_skip"],
+                            default="full",
+                            help="Gradient checkpointing type")
         parser.add_argument("--selective-checkpointing",
                             type=float,
                             help="Selective checkpointing threshold")
fastvideo/v1/training/activation_checkpoint.py (new file)

Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
+import collections
+from enum import Enum
+from typing import Optional
+
+import torch
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    checkpoint_wrapper)
+
+TRANSFORMER_BLOCK_NAMES = [
+    "blocks",
+    "double_blocks",
+    "single_blocks",
+    "transformer_blocks",
+    "temporal_transformer_blocks",
+    "transformer_double_blocks",
+    "transformer_single_blocks",
+]
+
+
+class CheckpointType(str, Enum):
+    FULL = "full"
+    OPS = "ops"
+    BLOCK_SKIP = "block_skip"
+
+
+_SELECTIVE_ACTIVATION_CHECKPOINTING_OPS = {
+    torch.ops.aten.mm.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops._c10d_functional.reduce_scatter_tensor.default,
+}
+
+
+def apply_activation_checkpointing(
+        module: torch.nn.Module,
+        checkpointing_type: str = CheckpointType.FULL,
+        n_layer: int = 1) -> torch.nn.Module:
+    if checkpointing_type == CheckpointType.FULL:
+        module = _apply_activation_checkpointing_blocks(module)
+    elif checkpointing_type == CheckpointType.OPS:
+        module = _apply_activation_checkpointing_ops(
+            module, _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS)
+    elif checkpointing_type == CheckpointType.BLOCK_SKIP:
+        module = _apply_activation_checkpointing_blocks(module, n_layer)
+    else:
+        raise ValueError(
+            f"Checkpointing type '{checkpointing_type}' not supported. Supported types are {CheckpointType.__members__.keys()}"
+        )
+    return module
+
+
+def _apply_activation_checkpointing_blocks(
+        module: torch.nn.Module,
+        n_layer: Optional[int] = None) -> torch.nn.Module:
+    for transformer_block_name in TRANSFORMER_BLOCK_NAMES:
+        blocks: torch.nn.Module = getattr(module, transformer_block_name, None)
+        if blocks is None:
+            continue
+        for index, (layer_id, block) in enumerate(blocks.named_children()):
+            if n_layer is None or index % n_layer == 0:
+                block = checkpoint_wrapper(block, preserve_rng_state=False)
+                blocks.register_module(layer_id, block)
+    return module
+
+
+def _apply_activation_checkpointing_ops(module: torch.nn.Module,
+                                        ops) -> torch.nn.Module:
+    from torch.utils.checkpoint import (CheckpointPolicy,
+                                        create_selective_checkpoint_contexts)
+
+    def _get_custom_policy(meta: dict[str, int]) -> CheckpointPolicy:
+
+        def _custom_policy(ctx, func, *args, **kwargs):
+            mode = "recompute" if ctx.is_recompute else "forward"
+            mm_count_key = f"{mode}_mm_count"
+            if func == torch.ops.aten.mm.default:
+                meta[mm_count_key] += 1
+            # Saves output of all compute ops, except every second mm
+            to_save = func in ops and not (func == torch.ops.aten.mm.default
+                                           and meta[mm_count_key] % 2 == 0)
+            return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE
+
+        return _custom_policy
+
+    def selective_checkpointing_context_fn():
+        meta: dict[str, int] = collections.defaultdict(int)
+        return create_selective_checkpoint_contexts(_get_custom_policy(meta))
+
+    return checkpoint_wrapper(module,
+                              context_fn=selective_checkpointing_context_fn,
+                              preserve_rng_state=False)
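To see what the new helper does to a model, here is a hypothetical usage sketch (not part of the commit). TinyTransformer is invented for illustration; its blocks attribute is one of the names scanned via TRANSFORMER_BLOCK_NAMES, and the module path follows the import added in training_pipeline.py below.

```python
# Hypothetical usage sketch: a toy module whose "blocks" attribute matches
# one of the TRANSFORMER_BLOCK_NAMES entries, so its children get wrapped.
import torch
from torch import nn

from fastvideo.v1.training.activation_checkpoint import (
    apply_activation_checkpointing)


class TinyTransformer(nn.Module):  # invented for illustration

    def __init__(self, depth: int = 4, dim: int = 32):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = torch.relu(block(x))
        return x


model = TinyTransformer()
# "full" wraps every block; "block_skip" with n_layer=2 wraps every second
# block (indices 0 and 2), trading some memory for less recomputation.
model = apply_activation_checkpointing(model,
                                       checkpointing_type="block_skip",
                                       n_layer=2)
out = model(torch.randn(8, 32, requires_grad=True))
out.sum().backward()  # wrapped blocks recompute their activations here
```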

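The "ops" mode relies on PyTorch's selective activation checkpointing in torch.utils.checkpoint (available in recent PyTorch releases). Below is a simplified, self-contained sketch of that mechanism with an illustrative save-list and without the "every second mm" counting used above; it is not the commit's code.

```python
# Illustrative op-level selective checkpointing: ops in the save-list keep
# their outputs, everything else is recomputed during backward.
import torch
from torch.utils.checkpoint import (CheckpointPolicy, checkpoint,
                                     create_selective_checkpoint_contexts)

ops_to_save = {torch.ops.aten.mm.default}


def policy(ctx, func, *args, **kwargs):
    # MUST_SAVE keeps the op's output; PREFER_RECOMPUTE lets backward redo it.
    return (CheckpointPolicy.MUST_SAVE
            if func in ops_to_save else CheckpointPolicy.PREFER_RECOMPUTE)


def context_fn():
    return create_selective_checkpoint_contexts(policy)


def fn(x, w):
    return torch.relu(x @ w) @ w  # 2-D matmuls dispatch to aten.mm


x = torch.randn(16, 64, requires_grad=True)
w = torch.randn(64, 64, requires_grad=True)
out = checkpoint(fn, x, w, use_reentrant=False, context_fn=context_fn)
out.sum().backward()  # relu is re-executed; the saved mm outputs are reused
```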
fastvideo/v1/training/training_pipeline.py

Lines changed: 17 additions & 9 deletions

@@ -31,6 +31,8 @@
 from fastvideo.v1.logger import init_logger
 from fastvideo.v1.pipelines import (ComposedPipelineBase, ForwardBatch,
                                     TrainingBatch)
+from fastvideo.v1.training.activation_checkpoint import (
+    apply_activation_checkpointing)
 from fastvideo.v1.training.training_utils import (
     clip_grad_norm_while_handling_failing_dtensor_cases,
     compute_density_for_timestep_sampling, get_sigmas, load_checkpoint,
@@ -83,6 +85,11 @@ def initialize_training_pipeline(self, training_args: TrainingArgs):
         self.transformer.requires_grad_(True)
         self.transformer.train()
 
+        if training_args.gradient_checkpointing:
+            self.transformer = apply_activation_checkpointing(
+                self.transformer,
+                checkpointing_type=training_args.gradient_checkpointing_type)
+
         noise_scheduler = self.modules["scheduler"]
         params_to_optimize = self.transformer.parameters()
         params_to_optimize = list(
@@ -309,17 +316,18 @@ def _transformer_forward_and_compute_loss(
                 current_timestep=training_batch.current_timestep,
                 attn_metadata=training_batch.attn_metadata):
             model_pred = self.transformer(**input_kwargs)
-        if self.training_args.precondition_outputs:
-            model_pred = training_batch.noisy_model_input - model_pred * training_batch.sigmas
-        target = training_batch.latents if self.training_args.precondition_outputs else training_batch.noise - training_batch.latents
+            if self.training_args.precondition_outputs:
+                model_pred = training_batch.noisy_model_input - model_pred * training_batch.sigmas
+            target = training_batch.latents if self.training_args.precondition_outputs else training_batch.noise - training_batch.latents
+
+            # make sure no implicit broadcasting happens
+            assert model_pred.shape == target.shape, f"model_pred.shape: {model_pred.shape}, target.shape: {target.shape}"
+            loss = (torch.mean((model_pred.float() - target.float())**2) /
+                    self.training_args.gradient_accumulation_steps)
 
-        # make sure no implicit broadcasting happens
-        assert model_pred.shape == target.shape, f"model_pred.shape: {model_pred.shape}, target.shape: {target.shape}"
-        loss = (torch.mean((model_pred.float() - target.float())**2) /
-                self.training_args.gradient_accumulation_steps)
+            loss.backward()
+            avg_loss = loss.detach().clone()
 
-        loss.backward()
-        avg_loss = loss.detach().clone()
         # logger.info(f"rank: {self.rank}, avg_loss: {avg_loss.item()}",
         # local_main_process_only=False)
         world_group = get_world_group()
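The last hunk re-indents the loss computation and loss.backward() so they run under the same forward context as the transformer call; the removed and re-added lines are textually identical, so this reads as an indentation-only move. A small sketch (Block and the call counter are invented for illustration) of why that ordering matters once checkpointing is enabled: the wrapped block's forward is executed again during backward.

```python
# Hypothetical sketch: with checkpoint_wrapper, the block's forward runs a
# second time inside loss.backward(), which is why backward stays inside the
# same context as the original forward pass in the hunk above.
import torch
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper)

calls = {"forward": 0}


class Block(nn.Module):  # invented for illustration

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        calls["forward"] += 1
        return torch.relu(self.linear(x))


block = checkpoint_wrapper(Block(), preserve_rng_state=False)
loss = block(torch.randn(4, 8, requires_grad=True)).sum()
print(calls["forward"])  # 1: the real forward pass
loss.backward()
print(calls["forward"])  # 2: activations were recomputed during backward
```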
