
Commit 1b27c3a

refactor upsample
1 parent 6915d62 commit 1b27c3a

2 files changed: +44, -153 lines


scripts/convert_hunyuan_video_to_diffusers.py

Lines changed: 2 additions & 4 deletions
@@ -109,9 +109,7 @@ def remap_single_transformer_blocks_(key, state_dict):
     "single_blocks": remap_single_transformer_blocks_,
 }
 
-VAE_KEYS_RENAME_DICT = {
-
-}
+VAE_KEYS_RENAME_DICT = {}
 
 VAE_SPECIAL_KEYS_REMAP = {}
 
@@ -208,7 +206,7 @@ def get_args():
     transformer = transformer.to(dtype=dtype)
     if not args.save_pipeline:
         transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
-
+
     if args.vae_ckpt_path is not None:
         vae = convert_vae(args.vae_ckpt_path)
         if not args.save_pipeline:
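For context, the rename dictionaries in this conversion script are applied as substring replacements over checkpoint keys, so an empty VAE_KEYS_RENAME_DICT simply passes the VAE state-dict keys through unchanged. Below is a minimal sketch of that pattern; the helper name rename_keys and the exact replacement loop are illustrative assumptions, not this script's literal code.

from typing import Dict

import torch


def rename_keys(state_dict: Dict[str, torch.Tensor], rename_dict: Dict[str, str]) -> Dict[str, torch.Tensor]:
    # Apply every (old, new) substring replacement to each checkpoint key.
    # With an empty rename_dict (as for the VAE above), every key is kept as-is.
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in rename_dict.items():
            new_key = new_key.replace(old, new)
        renamed[new_key] = value
    return renamed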

src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 42 additions & 149 deletions
@@ -22,7 +22,7 @@
 from einops import rearrange
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import logging, is_torch_version
+from ...utils import is_torch_version, logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
 from ..attention_processor import Attention, SpatialNorm
@@ -83,116 +83,35 @@ def forward(self, x):
 
 
 class UpsampleCausal3D(nn.Module):
-    """
-    A 3D upsampling layer with an optional convolution.
-    """
-
     def __init__(
         self,
-        channels: int,
-        use_conv: bool = False,
-        use_conv_transpose: bool = False,
+        in_channels: int,
         out_channels: Optional[int] = None,
-        name: str = "conv",
-        kernel_size: Optional[int] = None,
-        padding=1,
-        norm_type=None,
-        eps=None,
-        elementwise_affine=None,
-        bias=True,
-        interpolate=True,
-        upsample_factor=(2, 2, 2),
-    ):
+        bias: bool = True,
+        upsample_factor: Tuple[float, float, float] = (2, 2, 2),
+    ) -> None:
         super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.use_conv_transpose = use_conv_transpose
-        self.name = name
-        self.interpolate = interpolate
-        self.upsample_factor = upsample_factor
 
-        if norm_type == "ln_norm":
-            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
-        elif norm_type == "rms_norm":
-            self.norm = RMSNorm(channels, eps, elementwise_affine)
-        elif norm_type is None:
-            self.norm = None
-        else:
-            raise ValueError(f"unknown norm_type: {norm_type}")
-
-        conv = None
-        if use_conv_transpose:
-            assert False, "Not Implement yet"
-            if kernel_size is None:
-                kernel_size = 4
-            conv = nn.ConvTranspose2d(
-                channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
-            )
-        elif use_conv:
-            if kernel_size is None:
-                kernel_size = 3
-            conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
-
-        if name == "conv":
-            self.conv = conv
-        else:
-            self.Conv2d_0 = conv
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        output_size: Optional[int] = None,
-        scale: float = 1.0,
-    ) -> torch.Tensor:
-        assert hidden_states.shape[1] == self.channels
+        out_channels = out_channels or in_channels
+        self.upsample_factor = upsample_factor
 
-        if self.norm is not None:
-            assert False, "Not Implement yet"
-            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        self.conv = CausalConv3d(in_channels, out_channels, 3, 1, bias=bias)
 
-        if self.use_conv_transpose:
-            return self.conv(hidden_states)
-
-        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
-        dtype = hidden_states.dtype
-        if dtype == torch.bfloat16:
-            hidden_states = hidden_states.to(torch.float32)
-
-        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
-        if hidden_states.shape[0] >= 64:
-            hidden_states = hidden_states.contiguous()
-
-        # if `output_size` is passed we force the interpolation output
-        # size and do not make use of `scale_factor=2`
-        if self.interpolate:
-            B, C, T, H, W = hidden_states.shape
-            first_h, other_h = hidden_states.split((1, T - 1), dim=2)
-            if output_size is None:
-                if T > 1:
-                    other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
-
-                first_h = first_h.squeeze(2)
-                first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
-                first_h = first_h.unsqueeze(2)
-            else:
-                assert False, "Not Implement yet"
-                other_h = F.interpolate(other_h, size=output_size, mode="nearest")
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_frames = hidden_states.size(2)
+        first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2)
 
-            if T > 1:
-                hidden_states = torch.cat((first_h, other_h), dim=2)
-            else:
-                hidden_states = first_h
+        first_frame = F.interpolate(
+            first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest"
+        ).unsqueeze(2)
 
-        # If the input is bfloat16, we cast back to bfloat16
-        if dtype == torch.bfloat16:
-            hidden_states = hidden_states.to(dtype)
+        if num_frames > 1:
+            other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest")
+            hidden_states = torch.cat((first_frame, other_frames), dim=2)
+        else:
+            hidden_states = first_frame
 
-        if self.use_conv:
-            if self.name == "conv":
-                hidden_states = self.conv(hidden_states)
-            else:
-                hidden_states = self.Conv2d_0(hidden_states)
+        hidden_states = self.conv(hidden_states)
 
         return hidden_states
 
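In the refactored forward, the first frame is upsampled only spatially while the remaining frames are upsampled across time, height, and width, so with the default upsample_factor of (2, 2, 2) a (B, C, T, H, W) input becomes (B, C_out, 2*T - 1, 2*H, 2*W) after the convolution. A standalone sketch of just this interpolation logic follows; the CausalConv3d projection is omitted and the example shape is an arbitrary assumption.

import torch
import torch.nn.functional as F

# Reproduce only the split/interpolate/concat part of the refactored UpsampleCausal3D.forward.
upsample_factor = (2, 2, 2)
hidden_states = torch.randn(1, 8, 5, 16, 16)  # (B, C, T, H, W)

num_frames = hidden_states.size(2)
first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2)

# The first (causal) frame is upsampled only in H and W, so it stays a single frame.
first_frame = F.interpolate(first_frame.squeeze(2), scale_factor=upsample_factor[1:], mode="nearest").unsqueeze(2)

if num_frames > 1:
    # The remaining frames are upsampled in T, H and W together.
    other_frames = F.interpolate(other_frames, scale_factor=upsample_factor, mode="nearest")
    out = torch.cat((first_frame, other_frames), dim=2)
else:
    out = first_frame

print(out.shape)  # torch.Size([1, 8, 9, 32, 32]); the output has 1 + 2 * (T - 1) frames
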
@@ -278,13 +197,10 @@ def __init__(
         eps: float = 1e-6,
         non_linearity: str = "swish",
         skip_time_act: bool = False,
-        # default, scale_shift, ada_group, spatial
         time_embedding_norm: str = "default",
         kernel: Optional[torch.Tensor] = None,
         output_scale_factor: float = 1.0,
         use_in_shortcut: Optional[bool] = None,
-        up: bool = False,
-        down: bool = False,
         conv_shortcut_bias: bool = True,
         conv_3d_out_channels: Optional[int] = None,
     ):
@@ -295,8 +211,6 @@ def __init__(
         out_channels = in_channels if out_channels is None else out_channels
         self.out_channels = out_channels
         self.use_conv_shortcut = conv_shortcut
-        self.up = up
-        self.down = down
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm
         self.skip_time_act = skip_time_act
@@ -340,12 +254,6 @@ def __init__(
 
         self.nonlinearity = get_activation(non_linearity)
 
-        self.upsample = self.downsample = None
-        if self.up:
-            self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
-        elif self.down:
-            self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
-
         self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
 
         self.conv_shortcut = None
@@ -372,18 +280,6 @@ def forward(
         hidden_states = self.norm1(hidden_states)
 
         hidden_states = self.nonlinearity(hidden_states)
-
-        if self.upsample is not None:
-            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
-            if hidden_states.shape[0] >= 64:
-                input_tensor = input_tensor.contiguous()
-                hidden_states = hidden_states.contiguous()
-            input_tensor = self.upsample(input_tensor, scale=scale)
-            hidden_states = self.upsample(hidden_states, scale=scale)
-        elif self.downsample is not None:
-            input_tensor = self.downsample(input_tensor, scale=scale)
-            hidden_states = self.downsample(hidden_states, scale=scale)
-
         hidden_states = self.conv1(hidden_states)
 
         if self.time_emb_proj is not None:
@@ -461,12 +357,6 @@ def __init__(
         ]
         attentions = []
 
-        if attention_head_dim is None:
-            logger.warn(
-                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
-            )
-            attention_head_dim = in_channels
-
         for _ in range(num_layers):
             if self.add_attention:
                 # assert False, "Not implemented yet"
@@ -634,7 +524,6 @@ def __init__(
                 [
                     UpsampleCausal3D(
                         out_channels,
-                        use_conv=True,
                         out_channels=out_channels,
                         upsample_factor=upsample_scale_factor,
                     )
@@ -662,12 +551,17 @@ class EncoderCausal3D(nn.Module):
     r"""
     Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
     """
-
+
     def __init__(
         self,
         in_channels: int = 3,
         out_channels: int = 3,
-        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D"),
+        down_block_types: Tuple[str, ...] = (
+            "DownEncoderBlockCausal3D",
+            "DownEncoderBlockCausal3D",
+            "DownEncoderBlockCausal3D",
+            "DownEncoderBlockCausal3D",
+        ),
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         layers_per_block: int = 2,
         norm_num_groups: int = 32,
@@ -678,7 +572,7 @@ def __init__(
         spatial_compression_ratio: int = 8,
     ) -> None:
         super().__init__()
-
+
         self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
         self.mid_block = None
         self.down_blocks = nn.ModuleList([])
@@ -717,7 +611,7 @@ def __init__(
                 resnet_groups=norm_num_groups,
                 downsample_padding=0,
             )
-
+
             self.down_blocks.append(down_block)
 
         self.mid_block = UNetMidBlockCausal3D(
@@ -778,7 +672,12 @@ def __init__(
         self,
         in_channels: int = 3,
         out_channels: int = 3,
-        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D"),
+        up_block_types: Tuple[str, ...] = (
+            "UpDecoderBlockCausal3D",
+            "UpDecoderBlockCausal3D",
+            "UpDecoderBlockCausal3D",
+            "UpDecoderBlockCausal3D",
+        ),
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         layers_per_block: int = 2,
         norm_num_groups: int = 32,
@@ -831,7 +730,7 @@ def __init__(
             upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
             upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
             upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
-
+
             up_block = UpDecoderBlockCausal3D(
                 num_layers=self.layers_per_block + 1,
                 in_channels=prev_output_channel,
@@ -844,7 +743,7 @@ def __init__(
                 resnet_time_scale_shift=norm_type,
                 temb_channels=temb_channels,
            )
-
+
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
 
@@ -923,8 +822,8 @@ class DecoderOutput2(BaseOutput):
 
 class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
     r"""
-    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Introduced
-    in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
+    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
+    Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
 
     This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
     for all models (such as downloading or saving).
@@ -1119,9 +1018,7 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
         return DecoderOutput(sample=dec)
 
     @apply_forward_hook
-    def decode(
-        self, z: torch.Tensor, return_dict: bool = True
-    ) -> Union[DecoderOutput, torch.Tensor]:
+    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         """
         Decode a batch of images/videos.
 
@@ -1229,9 +1126,7 @@ def spatial_tiled_encode(
 
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def spatial_tiled_decode(
-        self, z: torch.Tensor, return_dict: bool = True
-    ) -> Union[DecoderOutput, torch.Tensor]:
+    def spatial_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Decode a batch of images/videos using a tiled decoder.
 
@@ -1315,9 +1210,7 @@ def temporal_tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Au
 
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def temporal_tiled_decode(
-        self, z: torch.Tensor, return_dict: bool = True
-    ) -> Union[DecoderOutput, torch.Tensor]:
+    def temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         # Split z into overlapping tiles and decode them separately.
 
         B, C, T, H, W = z.shape
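Given the simplified decode signature above, a minimal usage sketch follows. Hedged assumptions: AutoencoderKLHunyuanVideo is importable from the diffusers top-level namespace, the default config uses 16 latent channels, and the latent shape below is arbitrary; the .sample attribute comes from the DecoderOutput(sample=dec) return shown in the diff.

import torch

from diffusers import AutoencoderKLHunyuanVideo

# Build the VAE from its default config; loading a converted checkpoint via
# from_pretrained(...) would be used in practice.
vae = AutoencoderKLHunyuanVideo().eval()

# (B, latent_channels, T, H, W); 16 channels and the small spatial size are assumptions.
z = torch.randn(1, 16, 3, 32, 32)

with torch.no_grad():
    video = vae.decode(z, return_dict=True).sample

print(video.shape)  # decoded pixel-space video, (B, 3, T', H', W')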
