 from ...utils import logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
-from ..attention_processor import Attention
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
 from .vae import DecoderOutput, DiagonalGaussianDistribution
@@ -126,8 +125,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
         x = self.proj_out(x)
-
-        return x + identity
+
+        return x + identity


 class HunyuanImageRefinerUpsampleDCAE(nn.Module):
@@ -143,11 +142,11 @@ def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: b
     def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
         """
         Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
-
+
         Args:
             tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
             r1: temporal upsampling factor
-            r2: height upsampling factor
+            r2: height upsampling factor
             r3: width upsampling factor
         """
         b, packed_c, f, h, w = tensor.shape
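
For reference, a minimal PyTorch sketch of the rearrange this docstring describes, written as the exact inverse of the permute(0, 2, 4, 6, 1, 3, 5, 7) packing shown in the downsample hunk below; the split order of the packed channel axis and the free-function form are assumptions, not part of this commit:

import torch

def dcae_upsample_rearrange(tensor: torch.Tensor, r1: int = 1, r2: int = 2, r3: int = 2) -> torch.Tensor:
    # (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w), pixel-shuffle style
    b, packed_c, f, h, w = tensor.shape
    c = packed_c // (r1 * r2 * r3)
    tensor = tensor.reshape(b, r1, r2, r3, c, f, h, w)  # unpack the channel axis
    tensor = tensor.permute(0, 4, 1, 5, 2, 6, 3, 7)     # b, c, r1, f, r2, h, r3, w
    return tensor.reshape(b, c, r1 * f, r2 * h, r3 * w)

x = torch.randn(1, 4 * 8, 3, 5, 5)  # r1*r2*r3 = 4 factors packed into 32 channels
assert dcae_upsample_rearrange(x).shape == (1, 8, 3, 10, 10)
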
@@ -187,12 +186,11 @@ def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample:
         self.add_temporal_downsample = add_temporal_downsample
         self.group_size = factor * in_channels // out_channels

-
     @staticmethod
     def _dcae_downsample_rearrange(self, tensor, r1=1, r2=2, r3=2):
         """
         Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
-
+
         This packs spatial/temporal dimensions into channels (opposite of upsample)
         """
         b, c, packed_f, packed_h, packed_w = tensor.shape
@@ -202,7 +200,6 @@ def _dcae_downsample_rearrange(self, tensor, r1=1, r2=2, r3=2):
         tensor = tensor.permute(0, 2, 4, 6, 1, 3, 5, 7)
         return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)

-
     def forward(self, x: torch.Tensor):
         r1 = 2 if self.add_temporal_downsample else 1
         h = self.conv(x)
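
And the packing direction as a self-contained sketch; the permute is taken verbatim from the hunk above, while the pre-permute reshape into (b, c, r1, f, r2, h, r3, w) is inferred from the output shape and is an assumption:

import torch

def dcae_downsample_rearrange(tensor: torch.Tensor, r1: int = 1, r2: int = 2, r3: int = 2) -> torch.Tensor:
    # (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w), space-to-depth style
    b, c, packed_f, packed_h, packed_w = tensor.shape
    f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
    tensor = tensor.reshape(b, c, r1, f, r2, h, r3, w)  # split each packed axis
    tensor = tensor.permute(0, 2, 4, 6, 1, 3, 5, 7)     # b, r1, r2, r3, c, f, h, w
    return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)

x = torch.randn(2, 16, 4, 8, 8)
assert dcae_downsample_rearrange(x).shape == (2, 64, 4, 4, 4)  # h and w halved, channels x4
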
@@ -304,16 +301,13 @@ def __init__(
         self.gradient_checkpointing = False

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-
         hidden_states = self.resnets[0](hidden_states)

-
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
                 hidden_states = attn(hidden_states)
             hidden_states = resnet(hidden_states)

-
         return hidden_states


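
The loop above implies one more resnet than attention blocks, with attention entries allowed to be None. A toy stand-in using nn.Identity placeholders, just to make the data path explicit (illustrative only, not this module's real layers):

import torch
from torch import nn

class ToyMidBlock(nn.Module):
    def __init__(self, num_layers: int = 2):
        super().__init__()
        # resnets has num_layers + 1 entries; attentions pairs with resnets[1:]
        self.resnets = nn.ModuleList(nn.Identity() for _ in range(num_layers + 1))
        self.attentions = nn.ModuleList(nn.Identity() for _ in range(num_layers))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.resnets[0](hidden_states)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states)
        return hidden_states  # path: R0 -> A1 -> R1 -> A2 -> R2 -> ...
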
@@ -356,7 +350,6 @@ def __init__(
         self.gradient_checkpointing = False

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-
         for resnet in self.resnets:
             hidden_states = resnet(hidden_states)

@@ -461,7 +454,6 @@ def __init__(
                 )
                 input_channel = output_channel
             else:
-
                 add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
                 downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
                 down_block = HunyuanImageRefinerDownBlock3D(
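
Worked numbers for the add_temporal_downsample condition in this hunk, assuming the spatial_compression_ratio=16 default from this file and a temporal_compression_ratio of 4 (the latter is not shown in this diff):

import numpy as np

spatial_compression_ratio, temporal_compression_ratio = 16, 4
threshold = np.log2(spatial_compression_ratio // temporal_compression_ratio)  # log2(4) = 2.0
print([i >= threshold for i in range(4)])  # [False, False, True, True]
# i.e. only the later down blocks also halve the temporal (frame) dimension.
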
@@ -518,7 +510,7 @@ class HunyuanImageRefinerDecoder3D(nn.Module):
     def __init__(
         self,
         in_channels: int = 32,
-        out_channels: int = 3,
+        out_channels: int = 3,
         block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
         layers_per_block: int = 2,
         spatial_compression_ratio: int = 16,
@@ -574,10 +566,8 @@ def __init__(
         self.gradient_checkpointing = False

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-
         hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)

-
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)

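
A sketch of the channel-repeat shortcut feeding conv_in above; self.repeat is assumed to equal block_out_channels[0] // in_channels, which with the defaults in this file is 1024 // 32 = 32:

import torch

latent = torch.randn(1, 32, 1, 8, 8)                       # (batch, in_channels, frames, h, w)
repeat = 1024 // 32                                        # assumed block_out_channels[0] // in_channels
shortcut = latent.repeat_interleave(repeats=repeat, dim=1)
assert shortcut.shape == (1, 1024, 1, 8, 8)                # matches conv_in's output channels
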
@@ -598,8 +588,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

 class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
     r"""
-    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
-    Used for HunyuanImage-2.1 Refiner.
+    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
+    HunyuanImage-2.1 Refiner.

     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
     for all models (such as downloading or saving).
@@ -621,7 +611,7 @@ def __init__(
         upsample_match_channel: bool = True,
         scaling_factor: float = 1.03682,
     ) -> None:
-        super().__init__()
+        super().__init__()

         self.encoder = HunyuanImageRefinerEncoder3D(
             in_channels=in_channels,
@@ -655,7 +645,6 @@ def __init__(
         # intermediate tiles together, the memory requirement can be lowered.
         self.use_tiling = False

-
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256
         self.tile_sample_min_width = 256
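
Hypothetical usage of the flag above (not part of this commit): with these defaults, tiled decoding kicks in once a latent exceeds 256 // 16 = 16 per spatial side, i.e. for samples larger than 256x256 pixels.

vae.use_tiling = True  # `vae` being an AutoencoderKLHunyuanImageRefiner instance
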
@@ -763,7 +752,7 @@ def _decode(self, z: torch.Tensor) -> torch.Tensor:

         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z)
-
+
         dec = self.decoder(z)

         return dec
@@ -829,7 +818,7 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
             The latent representation of the encoded videos.
         """
         _, _, _, height, width = x.shape
-
+
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
         overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor))  # (256 // 16) * (1 - 0.25) = 12
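
Worked numbers for the three lines above, assuming tile_overlap_factor = 0.25 as the inline comment implies:

tile_latent_min_height = 256 // 16                         # 16 latent rows per tile
overlap_height = int(tile_latent_min_height * (1 - 0.25))  # stride of 12 rows between tile origins
# vertically adjacent tiles therefore share 16 - 12 = 4 latent rows, which get blended.
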
@@ -922,7 +911,6 @@ def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:

         return dec

-
     def forward(
         self,
         sample: torch.Tensor,
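
Finally, a hypothetical round trip through the model; the encode()/decode() signatures are assumed to follow the usual diffusers VAE pattern (AutoencoderKLOutput with a latent_dist, DecoderOutput with a sample) and have not been verified against this commit:

import torch

vae = AutoencoderKLHunyuanImageRefiner()
video = torch.randn(1, 3, 1, 256, 256)                    # (batch, channels, frames, height, width)
posterior = vae.encode(video).latent_dist                 # DiagonalGaussianDistribution
latents = posterior.sample() * vae.config.scaling_factor  # scaling_factor = 1.03682 per the config above
decoded = vae.decode(latents / vae.config.scaling_factor).sample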