@@ -22,7 +22,7 @@
 from einops import rearrange
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import logging, is_torch_version
+from ...utils import is_torch_version, logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
 from ..attention_processor import Attention, SpatialNorm
@@ -83,116 +83,35 @@ def forward(self, x):
 
 
 class UpsampleCausal3D(nn.Module):
-    """
-    A 3D upsampling layer with an optional convolution.
-    """
-
     def __init__(
         self,
-        channels: int,
-        use_conv: bool = False,
-        use_conv_transpose: bool = False,
+        in_channels: int,
         out_channels: Optional[int] = None,
-        name: str = "conv",
-        kernel_size: Optional[int] = None,
-        padding=1,
-        norm_type=None,
-        eps=None,
-        elementwise_affine=None,
-        bias=True,
-        interpolate=True,
-        upsample_factor=(2, 2, 2),
-    ):
+        bias: bool = True,
+        upsample_factor: Tuple[float, float, float] = (2, 2, 2),
+    ) -> None:
         super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.use_conv_transpose = use_conv_transpose
-        self.name = name
-        self.interpolate = interpolate
-        self.upsample_factor = upsample_factor
 
-        if norm_type == "ln_norm":
-            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
-        elif norm_type == "rms_norm":
-            self.norm = RMSNorm(channels, eps, elementwise_affine)
-        elif norm_type is None:
-            self.norm = None
-        else:
-            raise ValueError(f"unknown norm_type: {norm_type}")
-
-        conv = None
-        if use_conv_transpose:
-            assert False, "Not Implement yet"
-            if kernel_size is None:
-                kernel_size = 4
-            conv = nn.ConvTranspose2d(
-                channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
-            )
-        elif use_conv:
-            if kernel_size is None:
-                kernel_size = 3
-            conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
-
-        if name == "conv":
-            self.conv = conv
-        else:
-            self.Conv2d_0 = conv
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        output_size: Optional[int] = None,
-        scale: float = 1.0,
-    ) -> torch.Tensor:
-        assert hidden_states.shape[1] == self.channels
+        out_channels = out_channels or in_channels
+        self.upsample_factor = upsample_factor
 
-        if self.norm is not None:
-            assert False, "Not Implement yet"
-            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        self.conv = CausalConv3d(in_channels, out_channels, 3, 1, bias=bias)
 
-        if self.use_conv_transpose:
-            return self.conv(hidden_states)
-
-        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
-        dtype = hidden_states.dtype
-        if dtype == torch.bfloat16:
-            hidden_states = hidden_states.to(torch.float32)
-
-        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
-        if hidden_states.shape[0] >= 64:
-            hidden_states = hidden_states.contiguous()
-
-        # if `output_size` is passed we force the interpolation output
-        # size and do not make use of `scale_factor=2`
-        if self.interpolate:
-            B, C, T, H, W = hidden_states.shape
-            first_h, other_h = hidden_states.split((1, T - 1), dim=2)
-            if output_size is None:
-                if T > 1:
-                    other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
-
-                first_h = first_h.squeeze(2)
-                first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
-                first_h = first_h.unsqueeze(2)
-            else:
-                assert False, "Not Implement yet"
-                other_h = F.interpolate(other_h, size=output_size, mode="nearest")
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_frames = hidden_states.size(2)
+        first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2)
 
-            if T > 1:
-                hidden_states = torch.cat((first_h, other_h), dim=2)
-            else:
-                hidden_states = first_h
+        first_frame = F.interpolate(
+            first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest"
+        ).unsqueeze(2)
 
-        # If the input is bfloat16, we cast back to bfloat16
-        if dtype == torch.bfloat16:
-            hidden_states = hidden_states.to(dtype)
+        if num_frames > 1:
+            other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest")
+            hidden_states = torch.cat((first_frame, other_frames), dim=2)
+        else:
+            hidden_states = first_frame
 
-        if self.use_conv:
-            if self.name == "conv":
-                hidden_states = self.conv(hidden_states)
-            else:
-                hidden_states = self.Conv2d_0(hidden_states)
+        hidden_states = self.conv(hidden_states)
 
         return hidden_states
 
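The rewritten forward makes the causal convention explicit: the first frame has no temporal predecessor, so it is upsampled only spatially, while the remaining frames are upsampled in both time and space. A minimal standalone sketch of that splitting logic (the causal_upsample helper is hypothetical, written to mirror the new code above):

import torch
import torch.nn.functional as F

def causal_upsample(x: torch.Tensor, factor=(2, 2, 2)) -> torch.Tensor:
    # x: (B, C, T, H, W). Upsample frame 0 spatially only, and frames
    # 1..T-1 in time and space, so T frames become 1 + 2*(T - 1) = 2T - 1.
    num_frames = x.size(2)
    first, rest = x.split((1, num_frames - 1), dim=2)
    first = F.interpolate(first.squeeze(2), scale_factor=factor[1:], mode="nearest").unsqueeze(2)
    if num_frames > 1:
        rest = F.interpolate(rest, scale_factor=factor, mode="nearest")
        return torch.cat((first, rest), dim=2)
    return first

x = torch.randn(1, 4, 5, 8, 8)
print(causal_upsample(x).shape)  # torch.Size([1, 4, 9, 16, 16])

The 2T - 1 output length is what lets a stack of such layers map k + 1 latent frames back to 4k + 1 video frames, matching the causal encoder's temporal compression.
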
@@ -278,13 +197,10 @@ def __init__(
         eps: float = 1e-6,
         non_linearity: str = "swish",
         skip_time_act: bool = False,
-        # default, scale_shift, ada_group, spatial
         time_embedding_norm: str = "default",
         kernel: Optional[torch.Tensor] = None,
         output_scale_factor: float = 1.0,
         use_in_shortcut: Optional[bool] = None,
-        up: bool = False,
-        down: bool = False,
         conv_shortcut_bias: bool = True,
         conv_3d_out_channels: Optional[int] = None,
     ):
@@ -295,8 +211,6 @@ def __init__(
         out_channels = in_channels if out_channels is None else out_channels
         self.out_channels = out_channels
         self.use_conv_shortcut = conv_shortcut
-        self.up = up
-        self.down = down
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm
         self.skip_time_act = skip_time_act
@@ -340,12 +254,6 @@ def __init__(
 
         self.nonlinearity = get_activation(non_linearity)
 
-        self.upsample = self.downsample = None
-        if self.up:
-            self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
-        elif self.down:
-            self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
-
         self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
 
         self.conv_shortcut = None
@@ -372,18 +280,6 @@ def forward(
         hidden_states = self.norm1(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
-
-        if self.upsample is not None:
-            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
-            if hidden_states.shape[0] >= 64:
-                input_tensor = input_tensor.contiguous()
-                hidden_states = hidden_states.contiguous()
-            input_tensor = self.upsample(input_tensor, scale=scale)
-            hidden_states = self.upsample(hidden_states, scale=scale)
-        elif self.downsample is not None:
-            input_tensor = self.downsample(input_tensor, scale=scale)
-            hidden_states = self.downsample(hidden_states, scale=scale)
-
         hidden_states = self.conv1(hidden_states)
 
         if self.time_emb_proj is not None:
@@ -461,12 +357,6 @@ def __init__(
         ]
         attentions = []
 
-        if attention_head_dim is None:
-            logger.warn(
-                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
-            )
-            attention_head_dim = in_channels
-
         for _ in range(num_layers):
             if self.add_attention:
                 # assert False, "Not implemented yet"
@@ -634,7 +524,6 @@ def __init__(
                 [
                     UpsampleCausal3D(
                         out_channels,
-                        use_conv=True,
                         out_channels=out_channels,
                         upsample_factor=upsample_scale_factor,
                     )
@@ -662,12 +551,17 @@ class EncoderCausal3D(nn.Module):
     r"""
     Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
     """
-
+
     def __init__(
         self,
         in_channels: int = 3,
         out_channels: int = 3,
-        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D"),
+        down_block_types: Tuple[str, ...] = (
+            "DownEncoderBlockCausal3D",
+            "DownEncoderBlockCausal3D",
+            "DownEncoderBlockCausal3D",
+            "DownEncoderBlockCausal3D",
+        ),
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         layers_per_block: int = 2,
         norm_num_groups: int = 32,
@@ -678,7 +572,7 @@ def __init__(
         spatial_compression_ratio: int = 8,
     ) -> None:
         super().__init__()
-
+
         self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
         self.mid_block = None
         self.down_blocks = nn.ModuleList([])
@@ -717,7 +611,7 @@ def __init__(
                 resnet_groups=norm_num_groups,
                 downsample_padding=0,
             )
-
+
             self.down_blocks.append(down_block)
 
         self.mid_block = UNetMidBlockCausal3D(
@@ -778,7 +672,12 @@ def __init__(
         self,
         in_channels: int = 3,
         out_channels: int = 3,
-        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D"),
+        up_block_types: Tuple[str, ...] = (
+            "UpDecoderBlockCausal3D",
+            "UpDecoderBlockCausal3D",
+            "UpDecoderBlockCausal3D",
+            "UpDecoderBlockCausal3D",
+        ),
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         layers_per_block: int = 2,
         norm_num_groups: int = 32,
@@ -831,7 +730,7 @@ def __init__(
             upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
             upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
             upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
-
+
             up_block = UpDecoderBlockCausal3D(
                 num_layers=self.layers_per_block + 1,
                 in_channels=prev_output_channel,
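For reference, the factor assembly a few lines above composes a (T, H, W) tuple per block. A quick illustration with assumed flag values, outside the diff:

# Illustrative values: a block that upsamples spatially but not temporally.
add_spatial_upsample, add_time_upsample = True, False
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
print(tuple(upsample_scale_factor_T + upsample_scale_factor_HW))  # (1, 2, 2)
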
@@ -844,7 +743,7 @@ def __init__(
                 resnet_time_scale_shift=norm_type,
                 temb_channels=temb_channels,
             )
-
+
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
 
@@ -923,8 +822,8 @@ class DecoderOutput2(BaseOutput):
 
 class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
     r"""
-    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Introduced
-    in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
+    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
+    Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
 
     This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
     for all models (such as downloading or saving).
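Since this refactor leaves the public entry points intact, the model is used as before. A hedged usage sketch (the repo id and subfolder are assumptions for illustration; any HunyuanVideo checkpoint with a vae subfolder should work, assuming the class is exported from diffusers):

import torch
from diffusers import AutoencoderKLHunyuanVideo

# Repo id and subfolder assumed for illustration.
vae = AutoencoderKLHunyuanVideo.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

video = torch.randn(1, 3, 9, 64, 64, dtype=torch.float16, device="cuda")  # (B, C, T, H, W)
latents = vae.encode(video).latent_dist.sample()
reconstruction = vae.decode(latents).sample
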
@@ -1119,9 +1018,7 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
         return DecoderOutput(sample=dec)
 
     @apply_forward_hook
-    def decode(
-        self, z: torch.Tensor, return_dict: bool = True
-    ) -> Union[DecoderOutput, torch.Tensor]:
+    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         """
         Decode a batch of images/videos.
 
@@ -1229,9 +1126,7 @@ def spatial_tiled_encode(
 
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def spatial_tiled_decode(
-        self, z: torch.Tensor, return_dict: bool = True
-    ) -> Union[DecoderOutput, torch.Tensor]:
+    def spatial_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Decode a batch of images/videos using a tiled decoder.
 
@@ -1315,9 +1210,7 @@ def temporal_tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Au
 
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def temporal_tiled_decode(
-        self, z: torch.Tensor, return_dict: bool = True
-    ) -> Union[DecoderOutput, torch.Tensor]:
+    def temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        # Split z into overlapping tiles and decode them separately.
 
         B, C, T, H, W = z.shape
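temporal_tiled_decode decodes overlapping temporal tiles and cross-fades them where they meet. A sketch of the kind of linear blend such tiled decoders use along the time axis (a vectorized assumption here, not necessarily the exact loop in this file):

import torch

def blend_t(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Cross-fade the last `blend_extent` frames of tile `a` into the first
    # `blend_extent` frames of tile `b` with linearly increasing weight.
    blend_extent = min(a.shape[2], b.shape[2], blend_extent)
    weight = torch.arange(blend_extent, device=b.device, dtype=b.dtype) / blend_extent
    weight = weight.view(1, 1, -1, 1, 1)  # broadcast over (B, C, T, H, W)
    b[:, :, :blend_extent] = a[:, :, -blend_extent:] * (1 - weight) + b[:, :, :blend_extent] * weight
    return b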