
Commit f8c1e4a

a-r-r-o-w and Foundsheep authored and committed
Hunyuan VAE tiling fixes and transformer docs (huggingface#10295)
* update
* udpate
* fix test
1 parent 7637180 commit f8c1e4a

File tree

3 files changed: +69 -4 lines changed

src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 4 additions & 4 deletions
@@ -792,12 +792,12 @@ def __init__(
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256
         self.tile_sample_min_width = 256
-        self.tile_sample_min_num_frames = 64
+        self.tile_sample_min_num_frames = 16
 
         # The minimal distance between two spatial tiles
         self.tile_sample_stride_height = 192
         self.tile_sample_stride_width = 192
-        self.tile_sample_stride_num_frames = 48
+        self.tile_sample_stride_num_frames = 12
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
@@ -1003,7 +1003,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
         for i in range(0, height, self.tile_sample_stride_height):
             row = []
             for j in range(0, width, self.tile_sample_stride_width):
-                tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
                 tile = self.encoder(tile)
                 tile = self.quant_conv(tile)
                 row.append(tile)
@@ -1020,7 +1020,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
             if j > 0:
                 tile = self.blend_h(row[j - 1], tile, blend_width)
             result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
-        result_rows.append(torch.cat(result_row, dim=-1))
+        result_rows.append(torch.cat(result_row, dim=4))
 
         enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
         return enc
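
For context, here is a minimal sketch of how these tiling thresholds are exercised when encoding a clip larger than a single tile. The checkpoint id and the `enable_tiling()` call follow the usual diffusers VAE API and are assumptions for illustration, not part of this commit:

```python
import torch
from diffusers import AutoencoderKLHunyuanVideo

# Assumed checkpoint location; substitute whichever HunyuanVideo VAE weights you use.
vae = AutoencoderKLHunyuanVideo.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

# Turn on tiled encoding/decoding. With the defaults fixed in this commit, tiling
# engages once the input exceeds 256x256 spatially or 16 frames temporally,
# with tile strides of 192/192/12.
vae.enable_tiling()

# A 33-frame, 512x512 clip in [batch, channels, frames, height, width] layout
# exceeds every threshold, so encode() takes the tiled path patched above.
video = torch.randn(1, 3, 33, 512, 512, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
print(latents.shape)
```

The lowered defaults mean frame-wise tiling now engages on much shorter clips than before (16 frames instead of 64, with a stride of 12 instead of 48).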

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 40 additions & 0 deletions
@@ -497,6 +497,46 @@ def forward(
 
 
 class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin):
+    r"""
+    A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
+
+    Args:
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        num_layers (`int`, defaults to `20`):
+            The number of layers of dual-stream blocks to use.
+        num_single_layers (`int`, defaults to `40`):
+            The number of layers of single-stream blocks to use.
+        num_refiner_layers (`int`, defaults to `2`):
+            The number of layers of refiner blocks to use.
+        mlp_ratio (`float`, defaults to `4.0`):
+            The ratio of the hidden layer size to the input size in the feedforward network.
+        patch_size (`int`, defaults to `2`):
+            The size of the spatial patches to use in the patch embedding layer.
+        patch_size_t (`int`, defaults to `1`):
+            The size of the temporal patches to use in the patch embedding layer.
+        qk_norm (`str`, defaults to `rms_norm`):
+            The normalization to use for the query and key projections in the attention layers.
+        guidance_embeds (`bool`, defaults to `True`):
+            Whether to use guidance embeddings in the model.
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        pooled_projection_dim (`int`, defaults to `768`):
+            The dimension of the pooled projection of the text embeddings.
+        rope_theta (`float`, defaults to `256.0`):
+            The value of theta to use in the RoPE layer.
+        rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions of the axes to use in the RoPE layer.
+    """
+
+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py

Lines changed: 25 additions & 0 deletions
@@ -43,10 +43,14 @@ def get_autoencoder_kl_hunyuan_video_config(self):
             "down_block_types": (
                 "HunyuanVideoDownBlock3D",
                 "HunyuanVideoDownBlock3D",
+                "HunyuanVideoDownBlock3D",
+                "HunyuanVideoDownBlock3D",
             ),
             "up_block_types": (
                 "HunyuanVideoUpBlock3D",
                 "HunyuanVideoUpBlock3D",
+                "HunyuanVideoUpBlock3D",
+                "HunyuanVideoUpBlock3D",
             ),
             "block_out_channels": (8, 8, 8, 8),
             "layers_per_block": 1,
@@ -154,6 +158,27 @@ def test_gradient_checkpointing_is_applied(self):
         }
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
+    # We need to overwrite this test because the base test does not account for the length of down_block_types
+    def test_forward_with_norm_groups(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["norm_num_groups"] = 16
+        init_dict["block_out_channels"] = (16, 16, 16, 16)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+        if isinstance(output, dict):
+            output = output.to_tuple()[0]
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
     @unittest.skip("Unsupported test.")
     def test_outputs_equivalence(self):
         pass
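
The configuration the overridden test builds can be sketched standalone. Only the arguments visible in the diff are set here; everything else is assumed to keep its library default:

```python
from diffusers import AutoencoderKLHunyuanVideo

# Mirrors the test override: four down/up blocks with one block_out_channels entry
# per block, and a group-norm size that divides every block width. Values are
# illustrative, taken from the test above.
vae = AutoencoderKLHunyuanVideo(
    down_block_types=("HunyuanVideoDownBlock3D",) * 4,
    up_block_types=("HunyuanVideoUpBlock3D",) * 4,
    block_out_channels=(16, 16, 16, 16),
    layers_per_block=1,
    norm_num_groups=16,
)
print(vae.config.block_out_channels)
```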
