 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.utils.checkpoint

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
@@ -240,18 +241,51 @@ def __init__(
         self.resnets = nn.ModuleList(resnets)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.resnets[0](hidden_states)
-        for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                batch_size, num_channels, num_frames, height, width = hidden_states.shape
-                hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
-                attention_mask = prepare_causal_attention_mask(
-                    num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs
+            )
+
+            for attn, resnet in zip(self.attentions, self.resnets[1:]):
+                if attn is not None:
+                    batch_size, num_channels, num_frames, height, width = hidden_states.shape
+                    hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
+                    attention_mask = prepare_causal_attention_mask(
+                        num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
+                    )
+                    hidden_states = attn(hidden_states, attention_mask=attention_mask)
+                    hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
                 )
-                hidden_states = attn(hidden_states, attention_mask=attention_mask)
-                hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)

-            hidden_states = resnet(hidden_states)
+        else:
+            hidden_states = self.resnets[0](hidden_states)
+
+            for attn, resnet in zip(self.attentions, self.resnets[1:]):
+                if attn is not None:
+                    batch_size, num_channels, num_frames, height, width = hidden_states.shape
+                    hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
+                    attention_mask = prepare_causal_attention_mask(
+                        num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
+                    )
+                    hidden_states = attn(hidden_states, attention_mask=attention_mask)
+                    hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
+
+                hidden_states = resnet(hidden_states)

         return hidden_states

@@ -303,8 +337,26 @@ def __init__(
             self.downsamplers = None

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        for resnet in self.resnets:
-            hidden_states = resnet(hidden_states)
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            for resnet in self.resnets:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
+                )
+        else:
+            for resnet in self.resnets:
+                hidden_states = resnet(hidden_states)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
@@ -359,8 +411,27 @@ def __init__(
             self.upsamplers = None

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        for resnet in self.resnets:
-            hidden_states = resnet(hidden_states)
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            for resnet in self.resnets:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, **ckpt_kwargs
+                )
+
+        else:
+            for resnet in self.resnets:
+                hidden_states = resnet(hidden_states)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -401,6 +472,9 @@ def __init__(

         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
+            if down_block_type != "HunyuanVideoDownBlock3D":
+                raise ValueError(f"Unsupported down_block_type: {down_block_type}")
+
             input_channel = output_channel
             output_channel = block_out_channels[i]
             is_final_block = i == len(block_out_channels) - 1
@@ -454,27 +528,35 @@ def __init__(
         self.gradient_checkpointing = False

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # 1. Input layer
         hidden_states = self.conv_in(hidden_states)

-        use_reentrant = is_torch_version("<=", "1.11.0")
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

-        def create_block_forward(block):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                return lambda *inputs: torch.utils.checkpoint.checkpoint(
-                    lambda *x: block(*x), *inputs, use_reentrant=use_reentrant
+            for down_block in self.down_blocks:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block), hidden_states, **ckpt_kwargs
                 )
-            else:
-                return block

-        # 2. Down blocks
-        for down_block in self.down_blocks:
-            hidden_states = create_block_forward(down_block)(hidden_states)
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
+            )
+        else:
+            for down_block in self.down_blocks:
+                hidden_states = down_block(hidden_states)

-        # 3. Mid block
-        hidden_states = self.mid_block(hidden_states)
+            hidden_states = self.mid_block(hidden_states)

-        # 4. Output layers
         hidden_states = self.conv_norm_out(hidden_states)
         hidden_states = self.conv_act(hidden_states)
         hidden_states = self.conv_out(hidden_states)
@@ -501,7 +583,6 @@ def __init__(
         layers_per_block: int = 2,
         norm_num_groups: int = 32,
         act_fn: str = "silu",
-        norm_type: str = "group",
         mid_block_add_attention=True,
         time_compression_ratio: int = 4,
         spatial_compression_ratio: int = 8,
@@ -527,6 +608,9 @@ def __init__(
         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
+            if up_block_type != "HunyuanVideoUpBlock3D":
+                raise ValueError(f"Unsupported up_block_type: {up_block_type}")
+
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
             is_final_block = i == len(block_out_channels) - 1
@@ -569,36 +653,30 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.conv_in(hidden_states)

-        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

-            def create_custom_forward(module):
+            def create_custom_forward(module, return_dict=None):
                 def custom_forward(*inputs):
-                    return module(*inputs)
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)

                 return custom_forward

-            # up
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs
+            )
+
             for up_block in self.up_blocks:
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block),
-                    hidden_states,
-                    use_reentrant=False,
+                    create_custom_forward(up_block), hidden_states, **ckpt_kwargs
                 )
-        else:
-            # middle
-            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
-            hidden_states = hidden_states.to(upscale_dtype)
-
-            # up
-            for up_block in self.up_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
         else:
-            # middle
             hidden_states = self.mid_block(hidden_states)
-            hidden_states = hidden_states.to(upscale_dtype)

-            # up
             for up_block in self.up_blocks:
                 hidden_states = up_block(hidden_states)

@@ -643,7 +721,6 @@ def __init__(
         layers_per_block: int = 2,
         act_fn: str = "silu",
         norm_num_groups: int = 32,
-        sample_tsize: int = 64,
         scaling_factor: float = 0.476986,
         spatial_compression_ratio: int = 8,
         temporal_compression_ratio: int = 4,
@@ -700,11 +777,10 @@ def __init__(
         self.use_framewise_encoding = True
         self.use_framewise_decoding = True

-
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256
         self.tile_sample_min_width = 256
-
+
         # The minimal tile temporal batch size for temporal tiling to be used
         self.tile_sample_min_tsize = 64

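For reference, here is a minimal sketch of how the checkpointed branches above would be exercised during training. It assumes the VAE in this file is exposed as `AutoencoderKLHunyuanVideo` and inherits `enable_gradient_checkpointing()` from `ModelMixin`; neither name appears in this diff, so treat them as assumptions about the surrounding library.

```python
# Sketch only: class name, enable_gradient_checkpointing(), and the default config
# are assumptions about the surrounding diffusers code, not part of this diff.
import torch

from diffusers import AutoencoderKLHunyuanVideo

vae = AutoencoderKLHunyuanVideo()  # default config; real weights would come via from_pretrained(...)
vae.enable_gradient_checkpointing()  # flips gradient_checkpointing = True on the encoder/decoder blocks
vae.train()

# Tiny dummy clip: (batch, channels, frames, height, width); frames = 1 + 4k suits the causal 3D conv stack
video = torch.randn(1, 3, 5, 64, 64, requires_grad=True)

latents = vae.encode(video).latent_dist.sample()
recon = vae.decode(latents).sample

# With grads enabled and gradient_checkpointing set, the torch.utils.checkpoint
# branches added in this commit are the ones that run, recomputing activations
# during this backward pass instead of storing them.
recon.mean().backward()
```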