Commit 2a72d20

grad checkpointing

1 parent f143b02 commit 2a72d20

1 file changed: +128 -52 lines

1 file changed

+128
-52
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 128 additions & 52 deletions
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.utils.checkpoint
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
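
As context (not part of the commit): a minimal standalone sketch of what the newly imported torch.utils.checkpoint.checkpoint call provides. The wrapped module's intermediate activations are dropped during the forward pass and recomputed from the saved inputs during backward, trading extra compute for lower peak memory. The toy block and tensor shapes below are made up for illustration.

import torch
import torch.nn as nn
import torch.utils.checkpoint

# Toy stand-in for a resnet/attention block; sizes are arbitrary.
block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Regular forward: intermediate activations stay alive until backward.
y_ref = block(x)

# Checkpointed forward: only the inputs are saved; activations are recomputed
# during backward. use_reentrant=False selects the non-reentrant implementation.
y_ckpt = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)

# Outputs (and gradients) match; only peak memory during backward differs.
torch.testing.assert_close(y_ref, y_ckpt)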
@@ -240,18 +241,51 @@ def __init__(
         self.resnets = nn.ModuleList(resnets)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.resnets[0](hidden_states)
-        for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                batch_size, num_channels, num_frames, height, width = hidden_states.shape
-                hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
-                attention_mask = prepare_causal_attention_mask(
-                    num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs
+            )
+
+            for attn, resnet in zip(self.attentions, self.resnets[1:]):
+                if attn is not None:
+                    batch_size, num_channels, num_frames, height, width = hidden_states.shape
+                    hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
+                    attention_mask = prepare_causal_attention_mask(
+                        num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
+                    )
+                    hidden_states = attn(hidden_states, attention_mask=attention_mask)
+                    hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
                 )
-                hidden_states = attn(hidden_states, attention_mask=attention_mask)
-                hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
 
-            hidden_states = resnet(hidden_states)
+        else:
+            hidden_states = self.resnets[0](hidden_states)
+
+            for attn, resnet in zip(self.attentions, self.resnets[1:]):
+                if attn is not None:
+                    batch_size, num_channels, num_frames, height, width = hidden_states.shape
+                    hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
+                    attention_mask = prepare_causal_attention_mask(
+                        num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
+                    )
+                    hidden_states = attn(hidden_states, attention_mask=attention_mask)
+                    hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
+
+                hidden_states = resnet(hidden_states)
 
         return hidden_states
 
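
A standalone sketch (not from the diff) of the create_custom_forward closure used above: torch.utils.checkpoint.checkpoint only forwards positional tensor inputs to the wrapped callable, so the module and any non-tensor keyword such as return_dict are captured in a closure instead. The toy nn.Linear is illustrative.

import torch
import torch.nn as nn
import torch.utils.checkpoint


def create_custom_forward(module, return_dict=None):
    # Capture the module (and an optional return_dict flag) so that checkpoint()
    # only has to pass the positional tensor arguments through.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

out = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x, use_reentrant=False)
out.sum().backward()  # gradients flow through the recomputed forward pass
assert x.grad is not None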
@@ -303,8 +337,26 @@ def __init__(
             self.downsamplers = None
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        for resnet in self.resnets:
-            hidden_states = resnet(hidden_states)
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            for resnet in self.resnets:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
+                )
+        else:
+            for resnet in self.resnets:
+                hidden_states = resnet(hidden_states)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
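
How the gradient_checkpointing flag is typically flipped on from user code (a sketch, assuming the VAE class in this file is exported as AutoencoderKLHunyuanVideo and inherits diffusers' ModelMixin, whose enable_gradient_checkpointing() sets gradient_checkpointing = True on supporting submodules):

from diffusers import AutoencoderKLHunyuanVideo  # assumes this class is exported

vae = AutoencoderKLHunyuanVideo()  # default config from this file; use from_pretrained(...) for real weights
vae.enable_gradient_checkpointing()  # provided by ModelMixin; flips gradient_checkpointing on the blocks
vae.train()

# The checkpointed branches above additionally require torch.is_grad_enabled(),
# so inference under torch.no_grad() still takes the cheap, non-recomputing path.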
@@ -359,8 +411,27 @@ def __init__(
             self.upsamplers = None
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        for resnet in self.resnets:
-            hidden_states = resnet(hidden_states)
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            for resnet in self.resnets:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
+                )
+
+        else:
+            for resnet in self.resnets:
+                hidden_states = resnet(hidden_states)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
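
The trade-off these branches buy, sketched on a generic stack of layers (not the HunyuanVideo blocks): with checkpointing, peak memory during backward drops because activations inside each checkpointed block are not stored, at the cost of roughly one extra forward pass of compute. Requires a CUDA device to report memory.

import torch
import torch.nn as nn
import torch.utils.checkpoint


def peak_backward_memory(use_checkpointing: bool) -> int:
    torch.cuda.reset_peak_memory_stats()
    blocks = nn.ModuleList([nn.Linear(2048, 2048) for _ in range(8)]).cuda()
    hidden = torch.randn(64, 2048, device="cuda", requires_grad=True)
    for block in blocks:
        if use_checkpointing:
            hidden = torch.utils.checkpoint.checkpoint(block, hidden, use_reentrant=False)
        else:
            hidden = block(hidden)
    hidden.sum().backward()
    return torch.cuda.max_memory_allocated()


if torch.cuda.is_available():
    print("without checkpointing:", peak_backward_memory(False), "bytes")
    print("with checkpointing:   ", peak_backward_memory(True), "bytes")  # noticeably lower peak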
@@ -401,6 +472,9 @@ def __init__(
 
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
+            if down_block_type != "HunyuanVideoDownBlock3D":
+                raise ValueError(f"Unsupported down_block_type: {down_block_type}")
+
             input_channel = output_channel
             output_channel = block_out_channels[i]
             is_final_block = i == len(block_out_channels) - 1
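
The new guard in isolation (illustrative values, not the real constructor): rejecting an unknown block type at configuration time gives a clear error instead of a harder-to-trace failure deeper in the network.

# Made-up config with one bad entry, to show the failure mode.
down_block_types = ("HunyuanVideoDownBlock3D", "SomeOtherBlock3D")

try:
    for down_block_type in down_block_types:
        if down_block_type != "HunyuanVideoDownBlock3D":
            raise ValueError(f"Unsupported down_block_type: {down_block_type}")
except ValueError as err:
    print(err)  # Unsupported down_block_type: SomeOtherBlock3D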
@@ -454,27 +528,35 @@ def __init__(
         self.gradient_checkpointing = False
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # 1. Input layer
         hidden_states = self.conv_in(hidden_states)
 
-        use_reentrant = is_torch_version("<=", "1.11.0")
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
 
-        def create_block_forward(block):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                return lambda *inputs: torch.utils.checkpoint.checkpoint(
-                    lambda *x: block(*x), *inputs, use_reentrant=use_reentrant
+            for down_block in self.down_blocks:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block), hidden_states, **ckpt_kwargs
                 )
-            else:
-                return block
 
-        # 2. Down blocks
-        for down_block in self.down_blocks:
-            hidden_states = create_block_forward(down_block)(hidden_states)
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
+            )
+        else:
+            for down_block in self.down_blocks:
+                hidden_states = down_block(hidden_states)
 
-        # 3. Mid block
-        hidden_states = self.mid_block(hidden_states)
+            hidden_states = self.mid_block(hidden_states)
 
-        # 4. Output layers
         hidden_states = self.conv_norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
         hidden_states = self.conv_out(hidden_states)
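
The ckpt_kwargs version gate on its own (a sketch): use_reentrant was added to torch.utils.checkpoint.checkpoint in PyTorch 1.11, so older versions fall back to the default reentrant implementation by passing no keyword at all. is_torch_version is the helper already imported in this file; the toy layer is illustrative.

from typing import Any, Dict

import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.utils import is_torch_version

# Pass use_reentrant=False only where PyTorch understands it (>= 1.11.0);
# otherwise fall back to the default (reentrant) checkpoint implementation.
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

layer = nn.Linear(4, 4)
x = torch.randn(1, 4, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(layer, x, **ckpt_kwargs)
out.sum().backward()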
@@ -501,7 +583,6 @@ def __init__(
         layers_per_block: int = 2,
         norm_num_groups: int = 32,
         act_fn: str = "silu",
-        norm_type: str = "group",
         mid_block_add_attention=True,
         time_compression_ratio: int = 4,
         spatial_compression_ratio: int = 8,
@@ -527,6 +608,9 @@ def __init__(
         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
+            if up_block_type != "HunyuanVideoUpBlock3D":
+                raise ValueError(f"Unsupported up_block_type: {up_block_type}")
+
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
             is_final_block = i == len(block_out_channels) - 1
@@ -569,36 +653,30 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.conv_in(hidden_states)
 
-        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
 
-            def create_custom_forward(module):
+            def create_custom_forward(module, return_dict=None):
                 def custom_forward(*inputs):
-                    return module(*inputs)
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
 
                 return custom_forward
 
-            # up
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
+            )
+
             for up_block in self.up_blocks:
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block),
-                    hidden_states,
-                    use_reentrant=False,
+                    create_custom_forward(up_block), hidden_states, **ckpt_kwargs
                 )
-        else:
-            # middle
-            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
-            hidden_states = hidden_states.to(upscale_dtype)
-
-            # up
-            for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
         else:
-            # middle
             hidden_states = self.mid_block(hidden_states)
-            hidden_states = hidden_states.to(upscale_dtype)
 
-            # up
             for up_block in self.up_blocks:
                 hidden_states = up_block(hidden_states)
 
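
Why the gate changed from self.training to torch.is_grad_enabled() (a standalone sketch): recomputation only pays off when a backward pass will actually run, and that is tracked by grad mode, which is independent of train/eval mode.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)

model.train()
with torch.no_grad():
    # Training mode but no autograd graph: a `self.training` gate would still
    # recompute activations for nothing; torch.is_grad_enabled() is False here.
    print(model.training, torch.is_grad_enabled())  # True False

model.eval()
# Eval mode with autograd on (e.g. gradients w.r.t. inputs): the new gate still
# checkpoints, whereas the old `self.training` gate would have skipped it.
print(model.training, torch.is_grad_enabled())  # False True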
@@ -643,7 +721,6 @@ def __init__(
         layers_per_block: int = 2,
         act_fn: str = "silu",
         norm_num_groups: int = 32,
-        sample_tsize: int = 64,
         scaling_factor: float = 0.476986,
         spatial_compression_ratio: int = 8,
         temporal_compression_ratio: int = 4,
@@ -700,11 +777,10 @@ def __init__(
         self.use_framewise_encoding = True
         self.use_framewise_decoding = True
 
-
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256
         self.tile_sample_min_width = 256
-
+
         # The minimal tile temporal batch size for temporal tiling to be used
         self.tile_sample_min_tsize = 64
 
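
How these thresholds are typically consumed (a sketch, assuming the class exposes the standard diffusers VAE tiling toggle enable_tiling(); the input shape is illustrative): inputs whose height or width exceed tile_sample_min_height/tile_sample_min_width, or whose frame count exceeds tile_sample_min_tsize, are processed tile by tile to bound peak memory.

import torch
from diffusers import AutoencoderKLHunyuanVideo  # assumes this class is exported

vae = AutoencoderKLHunyuanVideo()  # default config defined in this file
vae.enable_tiling()  # assumed standard diffusers tiling API

# (batch, channels, frames, height, width); 320 > tile_sample_min_height/width,
# so spatial tiling is expected to kick in during encode/decode.
video = torch.randn(1, 3, 5, 320, 320)
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()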