2929from ..modeling_utils import ModelMixin
3030from .vae import DecoderOutput , DiagonalGaussianDistribution
3131
32- #YiYi TODO: remove this
33- from einops import rearrange
34-
3532
3633logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
3734
@@ -500,7 +497,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
500497
501498 hidden_states = self .mid_block (hidden_states )
502499
503- short_cut = rearrange (hidden_states , "b (c r) f h w -> b c r f h w" , r = self .group_size ).mean (dim = 2 )
500+ # short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
501+ batch_size , _ , frame , height , width = hidden_states .shape
502+ short_cut = hidden_states .view (batch_size , - 1 , self .group_size , frame , height , width ).mean (dim = 2 )
504503
505504 hidden_states = self .norm_out (hidden_states )
506505 hidden_states = self .conv_act (hidden_states )
@@ -513,7 +512,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
513512
514513class HunyuanImageRefinerDecoder3D (nn .Module ):
515514 r"""
516- Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603) .
515+ Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner .
517516 """
518517
519518 def __init__ (
@@ -600,7 +599,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
600599class AutoencoderKLHunyuanImageRefiner (ModelMixin , ConfigMixin ):
601600 r"""
602601 A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
603- Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603) .
602+ Used for HunyuanImage-2.1 Refiner. .
604603
605604 This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
606605 for all models (such as downloading or saving).
0 commit comments