Skip to content

Commit 02864b5

Browse files
committed
remove einops
1 parent 419c99d commit 02864b5

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@
2929
from ..modeling_utils import ModelMixin
3030
from .vae import DecoderOutput, DiagonalGaussianDistribution
3131

32-
#YiYi TODO: remove this
33-
from einops import rearrange
34-
3532

3633
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3734

@@ -500,7 +497,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
500497

501498
hidden_states = self.mid_block(hidden_states)
502499

503-
short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
500+
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
501+
batch_size, _, frame, height, width = hidden_states.shape
502+
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
504503

505504
hidden_states = self.norm_out(hidden_states)
506505
hidden_states = self.conv_act(hidden_states)
@@ -513,7 +512,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
513512

514513
class HunyuanImageRefinerDecoder3D(nn.Module):
515514
r"""
516-
Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
515+
Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
517516
"""
518517

519518
def __init__(
@@ -600,7 +599,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
600599
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
601600
r"""
602601
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
603-
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
602+
Used for HunyuanImage-2.1 Refiner..
604603
605604
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
606605
for all models (such as downloading or saving).

0 commit comments

Comments
 (0)