 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
 from typing import Optional, Tuple, Union

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention_processor import Attention
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
-from .vae import BaseOutput, DecoderOutput, DiagonalGaussianDistribution
+from .vae import DecoderOutput, DiagonalGaussianDistribution


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
-    seq_len = n_frame * n_hw
+def prepare_causal_attention_mask(
+    num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
+):
+    seq_len = num_frames * height_width
     mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
     for i in range(seq_len):
-        i_frame = i // n_hw
-        mask[i, : (i_frame + 1) * n_hw] = 0
+        i_frame = i // height_width
+        mask[i, : (i_frame + 1) * height_width] = 0
     if batch_size is not None:
         mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
     return mask
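
For reference, the helper above builds a block-causal mask: every query token in frame f can attend to all tokens of frames 0 through f and to nothing later. A minimal sketch of its output (not part of the diff), assuming num_frames=2 and height_width=2, so seq_len=4:

    mask = prepare_causal_attention_mask(2, 2, torch.float32, torch.device("cpu"))
    # tensor([[0., 0., -inf, -inf],
    #         [0., 0., -inf, -inf],
    #         [0., 0., 0., 0.],
    #         [0., 0., 0., 0.]])
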
@@ -178,7 +178,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class UNetMidBlockCausal3D(nn.Module):
+class HunyuanVideoMidBlock3D(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -243,19 +243,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
-                B, C, T, H, W = hidden_states.shape
-                hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
+                batch_size, num_channels, num_frames, height, width = hidden_states.shape
+                hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
                 attention_mask = prepare_causal_attention_mask(
-                    T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
+                    num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
                 )
                 hidden_states = attn(hidden_states, attention_mask=attention_mask)
-                hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
+                hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
+
             hidden_states = resnet(hidden_states)

         return hidden_states


-class DownEncoderBlockCausal3D(nn.Module):
+class HunyuanVideoDownBlock3D(nn.Module):
     def __init__(
         self,
         in_channels: int,
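
As a sanity check on the rewrite in the hunk above (not part of the diff): the pure-PyTorch permute/flatten calls reproduce the einops rearranges they replace, for any 5D tensor. A minimal sketch:

    import torch
    from einops import rearrange

    x = torch.randn(2, 8, 3, 4, 5)  # (batch, channels, frames, height, width)
    flat = x.permute(0, 2, 3, 4, 1).flatten(1, 3)  # "b c f h w -> b (f h w) c"
    assert torch.equal(flat, rearrange(x, "b c f h w -> b (f h w) c"))
    back = flat.unflatten(1, (3, 4, 5)).permute(0, 4, 1, 2, 3)  # and back again
    assert torch.equal(back, x)
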
@@ -268,7 +269,7 @@ def __init__(
         add_downsample: bool = True,
         downsample_stride: int = 2,
         downsample_padding: int = 1,
-    ):
+    ) -> None:
         super().__init__()
         resnets = []

@@ -312,20 +313,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class UpDecoderBlockCausal3D(nn.Module):
+class HunyuanVideoUpBlock3D(nn.Module):
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        resolution_idx: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         resnet_eps: float = 1e-6,
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         add_upsample: bool = True,
-        upsample_scale_factor=(2, 2, 2),
-    ):
+        upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2),
+    ) -> None:
         super().__init__()
         resnets = []

@@ -358,8 +358,6 @@ def __init__(
         else:
             self.upsamplers = None

-        self.resolution_idx = resolution_idx
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         for resnet in self.resnets:
             hidden_states = resnet(hidden_states)
@@ -381,10 +379,10 @@ def __init__(
         in_channels: int = 3,
         out_channels: int = 3,
         down_block_types: Tuple[str, ...] = (
-            "DownEncoderBlockCausal3D",
-            "DownEncoderBlockCausal3D",
-            "DownEncoderBlockCausal3D",
-            "DownEncoderBlockCausal3D",
+            "HunyuanVideoDownBlock3D",
+            "HunyuanVideoDownBlock3D",
+            "HunyuanVideoDownBlock3D",
+            "HunyuanVideoDownBlock3D",
         ),
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         layers_per_block: int = 2,
@@ -424,7 +422,7 @@ def __init__(
             downsample_stride_T = (2,) if add_time_downsample else (1,)
             downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)

-            down_block = DownEncoderBlockCausal3D(
+            down_block = HunyuanVideoDownBlock3D(
                 num_layers=layers_per_block,
                 in_channels=input_channel,
                 out_channels=output_channel,
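
As a side note (not part of the diff), the stride tuple composes per axis with the temporal factor first; assuming downsample_stride_HW is (2, 2) for a spatially downsampling block (the HW tuple is defined outside this hunk):

    tuple((2,) + (2, 2))  # add_time_downsample=True  -> stride (2, 2, 2)
    tuple((1,) + (2, 2))  # add_time_downsample=False -> stride (1, 2, 2)
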
@@ -438,7 +436,7 @@ def __init__(

             self.down_blocks.append(down_block)

-        self.mid_block = UNetMidBlockCausal3D(
+        self.mid_block = HunyuanVideoMidBlock3D(
             in_channels=block_out_channels[-1],
             resnet_eps=1e-6,
             resnet_act_fn=act_fn,
@@ -494,10 +492,10 @@ def __init__(
         in_channels: int = 3,
         out_channels: int = 3,
         up_block_types: Tuple[str, ...] = (
-            "UpDecoderBlockCausal3D",
-            "UpDecoderBlockCausal3D",
-            "UpDecoderBlockCausal3D",
-            "UpDecoderBlockCausal3D",
+            "HunyuanVideoUpBlock3D",
+            "HunyuanVideoUpBlock3D",
+            "HunyuanVideoUpBlock3D",
+            "HunyuanVideoUpBlock3D",
         ),
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         layers_per_block: int = 2,
@@ -516,7 +514,7 @@ def __init__(
         self.up_blocks = nn.ModuleList([])

         # mid
-        self.mid_block = UNetMidBlockCausal3D(
+        self.mid_block = HunyuanVideoMidBlock3D(
             in_channels=block_out_channels[-1],
             resnet_eps=1e-6,
             resnet_act_fn=act_fn,
@@ -547,7 +545,7 @@ def __init__(
             upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
             upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)

-            up_block = UpDecoderBlockCausal3D(
+            up_block = HunyuanVideoUpBlock3D(
                 num_layers=self.layers_per_block + 1,
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
@@ -568,10 +566,8 @@ def __init__(

         self.gradient_checkpointing = False

-    def forward(self, sample: torch.Tensor) -> torch.Tensor:
-        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
-
-        sample = self.conv_in(sample)
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.conv_in(hidden_states)

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if self.training and self.gradient_checkpointing:
@@ -584,40 +580,34 @@ def custom_forward(*inputs):

                 # up
                 for up_block in self.up_blocks:
-                    sample = torch.utils.checkpoint.checkpoint(
+                    hidden_states = torch.utils.checkpoint.checkpoint(
                         create_custom_forward(up_block),
-                        sample,
+                        hidden_states,
                         use_reentrant=False,
                     )
             else:
                 # middle
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
-                sample = sample.to(upscale_dtype)
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
+                hidden_states = hidden_states.to(upscale_dtype)

                 # up
                 for up_block in self.up_blocks:
-                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
+                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
         else:
             # middle
-            sample = self.mid_block(sample)
-            sample = sample.to(upscale_dtype)
+            hidden_states = self.mid_block(hidden_states)
+            hidden_states = hidden_states.to(upscale_dtype)

             # up
             for up_block in self.up_blocks:
-                sample = up_block(sample)
+                hidden_states = up_block(hidden_states)

         # post-process
-        sample = self.conv_norm_out(sample)
-        sample = self.conv_act(sample)
-        sample = self.conv_out(sample)
-
-        return sample
-
+        hidden_states = self.conv_norm_out(hidden_states)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states)

-@dataclass
-class DecoderOutput2(BaseOutput):
-    sample: torch.Tensor
-    posterior: Optional[DiagonalGaussianDistribution] = None
+        return hidden_states


 class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
@@ -638,16 +628,16 @@ def __init__(
         out_channels: int = 3,
         latent_channels: int = 16,
         down_block_types: Tuple[str, ...] = (
-            "DownEncoderBlockCausal3D",
-            "DownEncoderBlockCausal3D",
-            "DownEncoderBlockCausal3D",
-            "DownEncoderBlockCausal3D",
+            "HunyuanVideoDownBlock3D",
+            "HunyuanVideoDownBlock3D",
+            "HunyuanVideoDownBlock3D",
+            "HunyuanVideoDownBlock3D",
         ),
         up_block_types: Tuple[str, ...] = (
-            "UpDecoderBlockCausal3D",
-            "UpDecoderBlockCausal3D",
-            "UpDecoderBlockCausal3D",
-            "UpDecoderBlockCausal3D",
+            "HunyuanVideoUpBlock3D",
+            "HunyuanVideoUpBlock3D",
+            "HunyuanVideoUpBlock3D",
+            "HunyuanVideoUpBlock3D",
         ),
         block_out_channels: Tuple[int] = (128, 256, 512, 512),
         layers_per_block: int = 2,
@@ -1050,9 +1040,8 @@ def forward(
         sample: torch.Tensor,
         sample_posterior: bool = False,
         return_dict: bool = True,
-        return_posterior: bool = False,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput2, torch.Tensor]:
+    ) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Args:
             sample (`torch.Tensor`): Input sample.
@@ -1067,14 +1056,7 @@ def forward(
             z = posterior.sample(generator=generator)
         else:
             z = posterior.mode()
-        dec = self.decode(z).sample
-
+        dec = self.decode(z)
         if not return_dict:
-            if return_posterior:
-                return (dec, posterior)
-            else:
-                return (dec,)
-        if return_posterior:
-            return DecoderOutput2(sample=dec, posterior=posterior)
-        else:
-            return DecoderOutput2(sample=dec)
+            return (dec,)
+        return dec
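
After this change, forward returns the output of decode directly (a DecoderOutput when return_dict=True, a one-element tuple otherwise) and the DecoderOutput2 wrapper is gone. A minimal round-trip sketch (not part of the diff), assuming the default config instantiates cleanly and the input is a 5D video tensor of shape (batch, channels, frames, height, width); the exact frame-count and resolution constraints imposed by the causal convolutions are not visible in this diff:

    import torch

    vae = AutoencoderKLHunyuanVideo()
    video = torch.randn(1, 3, 9, 64, 64)
    output = vae(video, sample_posterior=True)
    reconstruction = output.sample  # DecoderOutput field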