Skip to content

Commit d0c61e0

Browse files
committed
autoencoder test
1 parent 2a72d20 commit d0c61e0

File tree

2 files changed

+171
-8
lines changed

2 files changed

+171
-8
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -240,6 +240,8 @@ def __init__(
240240
self.attentions = nn.ModuleList(attentions)
241241
self.resnets = nn.ModuleList(resnets)
242242

243+
self.gradient_checkpointing = False
244+
243245
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
244246
if torch.is_grad_enabled() and self.gradient_checkpointing:
245247

@@ -336,6 +338,8 @@ def __init__(
336338
else:
337339
self.downsamplers = None
338340

341+
self.gradient_checkpointing = False
342+
339343
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
340344
if torch.is_grad_enabled() and self.gradient_checkpointing:
341345

@@ -410,6 +414,8 @@ def __init__(
410414
else:
411415
self.upsamplers = None
412416

417+
self.gradient_checkpointing = False
418+
413419
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
414420
if torch.is_grad_enabled() and self.gradient_checkpointing:
415421

@@ -440,7 +446,7 @@ def custom_forward(*inputs):
440446
return hidden_states
441447

442448

443-
class EncoderCausal3D(nn.Module):
449+
class HunyuanVideoEncoder3D(nn.Module):
444450
r"""
445451
Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
446452
"""
@@ -564,7 +570,7 @@ def custom_forward(*inputs):
564570
return hidden_states
565571

566572

567-
class DecoderCausal3D(nn.Module):
573+
class HunyuanVideoDecoder3D(nn.Module):
568574
r"""
569575
Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
570576
"""
@@ -730,7 +736,7 @@ def __init__(
730736

731737
self.time_compression_ratio = temporal_compression_ratio
732738

733-
self.encoder = EncoderCausal3D(
739+
self.encoder = HunyuanVideoEncoder3D(
734740
in_channels=in_channels,
735741
out_channels=latent_channels,
736742
down_block_types=down_block_types,
@@ -744,7 +750,7 @@ def __init__(
744750
spatial_compression_ratio=spatial_compression_ratio,
745751
)
746752

747-
self.decoder = DecoderCausal3D(
753+
self.decoder = HunyuanVideoDecoder3D(
748754
in_channels=latent_channels,
749755
out_channels=out_channels,
750756
up_block_types=up_block_types,
@@ -789,7 +795,7 @@ def __init__(
789795
self.tile_sample_stride_width = 192
790796

791797
def _set_gradient_checkpointing(self, module, value=False):
792-
if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
798+
if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)):
793799
module.gradient_checkpointing = value
794800

795801
def enable_tiling(
@@ -1151,7 +1157,5 @@ def forward(
11511157
z = posterior.sample(generator=generator)
11521158
else:
11531159
z = posterior.mode()
1154-
dec = self.decode(z)
1155-
if not return_dict:
1156-
return (dec,)
1160+
dec = self.decode(z, return_dict=return_dict)
11571161
return dec
Lines changed: 159 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,159 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import torch
19+
20+
from diffusers import AutoencoderKLHunyuanVideo
21+
from diffusers.utils.testing_utils import (
22+
enable_full_determinism,
23+
floats_tensor,
24+
torch_device,
25+
)
26+
27+
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
28+
29+
30+
enable_full_determinism()
31+
32+
33+
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    """Tests for `AutoencoderKLHunyuanVideo` built from a tiny configuration.

    The shared checks come from `ModelTesterMixin` / `UNetTesterMixin`; this class
    supplies the init arguments and dummy input they need, and adds checks that
    toggling tiling and slicing on and off does not change the model output.
    """

    model_class = AutoencoderKLHunyuanVideo
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_kl_hunyuan_video_config(self):
        """Return the keyword arguments used to build the small test model."""
        # NOTE(review): two block types are listed per direction while
        # `block_out_channels` has four entries — confirm this mismatch is intended.
        return {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 4,
            "down_block_types": (
                "HunyuanVideoDownBlock3D",
                "HunyuanVideoDownBlock3D",
            ),
            "up_block_types": (
                "HunyuanVideoUpBlock3D",
                "HunyuanVideoUpBlock3D",
            ),
            "block_out_channels": (8, 8, 8, 8),
            "layers_per_block": 1,
            "act_fn": "silu",
            "norm_num_groups": 4,
            "scaling_factor": 0.476986,
            "spatial_compression_ratio": 8,
            "temporal_compression_ratio": 4,
            "mid_block_add_attention": True,
        }

    @property
    def dummy_input(self):
        """Return a random `sample` tensor of shape (batch, channels, frames, height, width)."""
        batch_size = 2
        num_frames = 9
        num_channels = 3
        sizes = (16, 16)

        image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        """Per-sample input shape (no batch dimension), matching `dummy_input`."""
        return (3, 9, 16, 16)

    @property
    def output_shape(self):
        """Per-sample output shape (no batch dimension); identical to `input_shape`."""
        return (3, 9, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        """Return `(init_dict, inputs_dict)` as consumed by the common mixin tests."""
        init_dict = self.get_autoencoder_kl_hunyuan_video_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def _seeded_forward(self, model, inputs_dict):
        """Run `model` with the global RNG and the passed generator both seeded to 0.

        Returns the first element of the model's tuple output, detached and moved
        to CPU so that results from different runs can be compared directly.
        """
        torch.manual_seed(0)
        output = model(**inputs_dict, generator=torch.manual_seed(0))[0]
        return output.detach().cpu()

    def test_enable_disable_tiling(self):
        """Tiling must not noticeably change the output, and disabling it must restore it."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        inputs_dict["return_dict"] = False

        output_without_tiling = self._seeded_forward(model, inputs_dict)

        model.enable_tiling()
        output_with_tiling = self._seeded_forward(model, inputs_dict)

        # Use the absolute difference: a plain `(a - b).max()` would let an
        # arbitrarily large negative deviation pass the bound.
        self.assertLess(
            (output_without_tiling - output_with_tiling).abs().max().item(),
            0.5,
            "VAE tiling should not affect the inference results",
        )

        model.disable_tiling()
        output_without_tiling_2 = self._seeded_forward(model, inputs_dict)

        # Compare element-wise. Comparing `a.all()` with `b.all()` only checks that
        # both outputs are (or are not) entirely non-zero, which is almost always true.
        self.assertTrue(
            torch.allclose(output_without_tiling, output_without_tiling_2, atol=self.base_precision),
            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
        )

    def test_enable_disable_slicing(self):
        """Slicing must not noticeably change the output, and disabling it must restore it."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        inputs_dict["return_dict"] = False

        output_without_slicing = self._seeded_forward(model, inputs_dict)

        model.enable_slicing()
        output_with_slicing = self._seeded_forward(model, inputs_dict)

        # Absolute difference, for the same reason as in the tiling test.
        self.assertLess(
            (output_without_slicing - output_with_slicing).abs().max().item(),
            0.5,
            "VAE slicing should not affect the inference results",
        )

        model.disable_slicing()
        output_without_slicing_2 = self._seeded_forward(model, inputs_dict)

        # Element-wise comparison instead of comparing two `.all()` booleans.
        self.assertTrue(
            torch.allclose(output_without_slicing, output_without_slicing_2, atol=self.base_precision),
            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
        )

    def test_gradient_checkpointing_is_applied(self):
        """Delegate to the mixin check with the module class names expected to checkpoint."""
        expected_set = {
            "HunyuanVideoDecoder3D",
            "HunyuanVideoDownBlock3D",
            "HunyuanVideoEncoder3D",
            "HunyuanVideoMidBlock3D",
            "HunyuanVideoUpBlock3D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @unittest.skip("Unsupported test.")
    def test_outputs_equivalence(self):
        pass

0 commit comments

Comments
 (0)