
Commit d0036ff

refactor

1 parent d6c16ef

1 file changed: +56 -74 lines

src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
Lines changed: 56 additions & 74 deletions
@@ -12,14 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
 from typing import Optional, Tuple, Union

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
@@ -28,18 +26,20 @@
 from ..attention_processor import Attention
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
-from .vae import BaseOutput, DecoderOutput, DiagonalGaussianDistribution
+from .vae import DecoderOutput, DiagonalGaussianDistribution


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
-    seq_len = n_frame * n_hw
+def prepare_causal_attention_mask(
+    num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
+):
+    seq_len = num_frames * height_width
     mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
     for i in range(seq_len):
-        i_frame = i // n_hw
-        mask[i, : (i_frame + 1) * n_hw] = 0
+        i_frame = i // height_width
+        mask[i, : (i_frame + 1) * height_width] = 0
     if batch_size is not None:
         mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
     return mask
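The rename above changes argument names only; the masking logic is untouched. A minimal standalone sketch (not part of the commit) of what the helper produces, with the same loop inlined for num_frames=3 and height_width=2: each token may attend to every token in its own frame and in earlier frames, and is masked with -inf for later frames.

```python
# Standalone sketch of the causal mask built by prepare_causal_attention_mask
# (same logic as in the diff, inlined so it runs without diffusers).
import torch

num_frames, height_width = 3, 2
seq_len = num_frames * height_width

mask = torch.full((seq_len, seq_len), float("-inf"))
for i in range(seq_len):
    i_frame = i // height_width  # frame this query token belongs to
    mask[i, : (i_frame + 1) * height_width] = 0  # visible: own frame + earlier frames

print(mask)
# Rows 0-1 (frame 0) see columns 0-1; rows 2-3 (frame 1) see columns 0-3;
# rows 4-5 (frame 2) see all six columns.
```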
@@ -178,7 +178,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class UNetMidBlockCausal3D(nn.Module):
+class HunyuanVideoMidBlock3D(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -243,19 +243,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
-                B, C, T, H, W = hidden_states.shape
-                hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
+                batch_size, num_channels, num_frames, height, width = hidden_states.shape
+                hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
                 attention_mask = prepare_causal_attention_mask(
-                    T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
+                    num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
                 )
                 hidden_states = attn(hidden_states, attention_mask=attention_mask)
-                hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
+                hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
+
             hidden_states = resnet(hidden_states)

         return hidden_states


-class DownEncoderBlockCausal3D(nn.Module):
+class HunyuanVideoDownBlock3D(nn.Module):
     def __init__(
         self,
         in_channels: int,
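Since this hunk removes the einops dependency from the attention path, a quick self-check (not part of the commit, and assuming einops is still installed locally) that the permute/flatten rewrite reproduces the two rearrange calls it replaces:

```python
import torch
from einops import rearrange

x = torch.randn(2, 4, 3, 5, 6)  # (batch, channels, frames, height, width)

# "b c f h w -> b (f h w) c" versus permute + flatten
flat_einops = rearrange(x, "b c f h w -> b (f h w) c")
flat_torch = x.permute(0, 2, 3, 4, 1).flatten(1, 3)
assert torch.equal(flat_einops, flat_torch)

# "b (f h w) c -> b c f h w" versus unflatten + permute
back_einops = rearrange(flat_einops, "b (f h w) c -> b c f h w", f=3, h=5, w=6)
back_torch = flat_torch.unflatten(1, (3, 5, 6)).permute(0, 4, 1, 2, 3)
assert torch.equal(back_einops, back_torch) and torch.equal(back_torch, x)
```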
@@ -268,7 +269,7 @@ def __init__(
         add_downsample: bool = True,
         downsample_stride: int = 2,
         downsample_padding: int = 1,
-    ):
+    ) -> None:
         super().__init__()
         resnets = []

@@ -312,20 +313,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class UpDecoderBlockCausal3D(nn.Module):
+class HunyuanVideoUpBlock3D(nn.Module):
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        resolution_idx: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         resnet_eps: float = 1e-6,
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         add_upsample: bool = True,
-        upsample_scale_factor=(2, 2, 2),
-    ):
+        upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2),
+    ) -> None:
         super().__init__()
         resnets = []

@@ -358,8 +358,6 @@ def __init__(
         else:
             self.upsamplers = None

-        self.resolution_idx = resolution_idx
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         for resnet in self.resnets:
             hidden_states = resnet(hidden_states)
@@ -381,10 +379,10 @@ def __init__(
         in_channels: int = 3,
         out_channels: int = 3,
         down_block_types: Tuple[str, ...] = (
-            "DownEncoderBlockCausal3D",
-            "DownEncoderBlockCausal3D",
-            "DownEncoderBlockCausal3D",
-            "DownEncoderBlockCausal3D",
+            "HunyuanVideoDownBlock3D",
+            "HunyuanVideoDownBlock3D",
+            "HunyuanVideoDownBlock3D",
+            "HunyuanVideoDownBlock3D",
         ),
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         layers_per_block: int = 2,
@@ -424,7 +422,7 @@ def __init__(
             downsample_stride_T = (2,) if add_time_downsample else (1,)
             downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)

-            down_block = DownEncoderBlockCausal3D(
+            down_block = HunyuanVideoDownBlock3D(
                 num_layers=layers_per_block,
                 in_channels=input_channel,
                 out_channels=output_channel,
@@ -438,7 +436,7 @@ def __init__(

             self.down_blocks.append(down_block)

-        self.mid_block = UNetMidBlockCausal3D(
+        self.mid_block = HunyuanVideoMidBlock3D(
             in_channels=block_out_channels[-1],
             resnet_eps=1e-6,
             resnet_act_fn=act_fn,
@@ -494,10 +492,10 @@ def __init__(
         in_channels: int = 3,
         out_channels: int = 3,
         up_block_types: Tuple[str, ...] = (
-            "UpDecoderBlockCausal3D",
-            "UpDecoderBlockCausal3D",
-            "UpDecoderBlockCausal3D",
-            "UpDecoderBlockCausal3D",
+            "HunyuanVideoUpBlock3D",
+            "HunyuanVideoUpBlock3D",
+            "HunyuanVideoUpBlock3D",
+            "HunyuanVideoUpBlock3D",
         ),
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         layers_per_block: int = 2,
@@ -516,7 +514,7 @@ def __init__(
         self.up_blocks = nn.ModuleList([])

         # mid
-        self.mid_block = UNetMidBlockCausal3D(
+        self.mid_block = HunyuanVideoMidBlock3D(
             in_channels=block_out_channels[-1],
             resnet_eps=1e-6,
             resnet_act_fn=act_fn,
@@ -547,7 +545,7 @@ def __init__(
             upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
             upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)

-            up_block = UpDecoderBlockCausal3D(
+            up_block = HunyuanVideoUpBlock3D(
                 num_layers=self.layers_per_block + 1,
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
@@ -568,10 +566,8 @@ def __init__(

         self.gradient_checkpointing = False

-    def forward(self, sample: torch.Tensor) -> torch.Tensor:
-        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
-
-        sample = self.conv_in(sample)
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.conv_in(hidden_states)

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if self.training and self.gradient_checkpointing:
@@ -584,40 +580,34 @@ def custom_forward(*inputs):

                 # up
                 for up_block in self.up_blocks:
-                    sample = torch.utils.checkpoint.checkpoint(
+                    hidden_states = torch.utils.checkpoint.checkpoint(
                         create_custom_forward(up_block),
-                        sample,
+                        hidden_states,
                         use_reentrant=False,
                     )
             else:
                 # middle
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
-                sample = sample.to(upscale_dtype)
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
+                hidden_states = hidden_states.to(upscale_dtype)

                 # up
                 for up_block in self.up_blocks:
-                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
+                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
         else:
             # middle
-            sample = self.mid_block(sample)
-            sample = sample.to(upscale_dtype)
+            hidden_states = self.mid_block(hidden_states)
+            hidden_states = hidden_states.to(upscale_dtype)

             # up
             for up_block in self.up_blocks:
-                sample = up_block(sample)
+                hidden_states = up_block(hidden_states)

         # post-process
-        sample = self.conv_norm_out(sample)
-        sample = self.conv_act(sample)
-        sample = self.conv_out(sample)
-
-        return sample
-
+        hidden_states = self.conv_norm_out(hidden_states)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states)

-@dataclass
-class DecoderOutput2(BaseOutput):
-    sample: torch.Tensor
-    posterior: Optional[DiagonalGaussianDistribution] = None
+        return hidden_states


 class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
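The renamed checkpoint calls above follow the usual pattern of wrapping a block in create_custom_forward and passing it to torch.utils.checkpoint. A small self-contained sketch (assumptions: a generic nn.Module block and a dummy 5D input, not the actual decoder) of how this trades activation memory for recomputation during backward:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


def create_custom_forward(module):
    # Wrap a module so checkpoint() can call it with positional inputs only.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


block = nn.Conv3d(8, 8, kernel_size=3, padding=1)
hidden_states = torch.randn(1, 8, 4, 16, 16, requires_grad=True)

# Activations inside `block` are not stored; they are recomputed on backward.
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), hidden_states, use_reentrant=False
)
hidden_states.sum().backward()
```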
@@ -638,16 +628,16 @@ def __init__(
         out_channels: int = 3,
         latent_channels: int = 16,
         down_block_types: Tuple[str, ...] = (
-            "DownEncoderBlockCausal3D",
-            "DownEncoderBlockCausal3D",
-            "DownEncoderBlockCausal3D",
-            "DownEncoderBlockCausal3D",
+            "HunyuanVideoDownBlock3D",
+            "HunyuanVideoDownBlock3D",
+            "HunyuanVideoDownBlock3D",
+            "HunyuanVideoDownBlock3D",
         ),
         up_block_types: Tuple[str, ...] = (
-            "UpDecoderBlockCausal3D",
-            "UpDecoderBlockCausal3D",
-            "UpDecoderBlockCausal3D",
-            "UpDecoderBlockCausal3D",
+            "HunyuanVideoUpBlock3D",
+            "HunyuanVideoUpBlock3D",
+            "HunyuanVideoUpBlock3D",
+            "HunyuanVideoUpBlock3D",
         ),
         block_out_channels: Tuple[int] = (128, 256, 512, 512),
         layers_per_block: int = 2,
@@ -1050,9 +1040,8 @@ def forward(
         sample: torch.Tensor,
         sample_posterior: bool = False,
         return_dict: bool = True,
-        return_posterior: bool = False,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput2, torch.Tensor]:
+    ) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Args:
             sample (`torch.Tensor`): Input sample.
@@ -1067,14 +1056,7 @@ def forward(
             z = posterior.sample(generator=generator)
         else:
             z = posterior.mode()
-        dec = self.decode(z).sample
-
+        dec = self.decode(z)
         if not return_dict:
-            if return_posterior:
-                return (dec, posterior)
-            else:
-                return (dec,)
-        if return_posterior:
-            return DecoderOutput2(sample=dec, posterior=posterior)
-        else:
-            return DecoderOutput2(sample=dec)
+            return (dec,)
+        return dec
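With DecoderOutput2 and return_posterior gone, callers only deal with the stock DecoderOutput (or a plain tuple). A hedged usage sketch, not from the commit: `vae` is assumed to be an already-constructed AutoencoderKLHunyuanVideo, `video` a 5D (batch, channels, frames, height, width) float tensor, and decode is assumed to keep its usual return_dict=True default so it yields a DecoderOutput.

```python
import torch

with torch.no_grad():
    output = vae(video, sample_posterior=False, return_dict=True)
    reconstruction = output.sample  # decoded video tensor

    output_tuple = vae(video, return_dict=False)
    reconstruction = output_tuple[0].sample  # the tuple wraps whatever decode() returned
```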
