Skip to content

Commit 1e6ada6

Browse files
committed
add tests
1 parent 655dcda commit 1e6ada6

File tree

1 file changed

+364
-0
lines changed

1 file changed

+364
-0
lines changed
Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
# Copyright 2024 The HuggingFace Team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import inspect
16+
import unittest
17+
18+
import numpy as np
19+
import torch
20+
from PIL import Image
21+
from transformers import (
22+
CLIPImageProcessor,
23+
CLIPTextConfig,
24+
CLIPTextModel,
25+
CLIPTokenizer,
26+
LlamaConfig,
27+
LlamaModel,
28+
LlamaTokenizer,
29+
)
30+
31+
from diffusers import (
32+
AutoencoderKLHunyuanVideo,
33+
FlowMatchEulerDiscreteScheduler,
34+
HunyuanVideoImageToVideoPipeline,
35+
HunyuanVideoTransformer3DModel,
36+
)
37+
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
38+
39+
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
40+
41+
42+
enable_full_determinism()
43+
44+
45+
class HunyuanVideoImageToVideoPipelineFastTests(
46+
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase
47+
):
48+
pipeline_class = HunyuanVideoImageToVideoPipeline
49+
params = frozenset(
50+
["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
51+
)
52+
batch_params = frozenset(["prompt", "image"])
53+
required_optional_params = frozenset(
54+
[
55+
"num_inference_steps",
56+
"generator",
57+
"latents",
58+
"return_dict",
59+
"callback_on_step_end",
60+
"callback_on_step_end_tensor_inputs",
61+
]
62+
)
63+
supports_dduf = False
64+
65+
# there is no xformers processor for Flux
66+
test_xformers_attention = False
67+
test_layerwise_casting = True
68+
test_group_offloading = True
69+
70+
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
71+
torch.manual_seed(0)
72+
transformer = HunyuanVideoTransformer3DModel(
73+
in_channels=2 * 4 + 1,
74+
out_channels=4,
75+
num_attention_heads=2,
76+
attention_head_dim=10,
77+
num_layers=num_layers,
78+
num_single_layers=num_single_layers,
79+
num_refiner_layers=1,
80+
patch_size=1,
81+
patch_size_t=1,
82+
guidance_embeds=False,
83+
text_embed_dim=16,
84+
pooled_projection_dim=8,
85+
rope_axes_dim=(2, 4, 4),
86+
)
87+
88+
torch.manual_seed(0)
89+
vae = AutoencoderKLHunyuanVideo(
90+
in_channels=3,
91+
out_channels=3,
92+
latent_channels=4,
93+
down_block_types=(
94+
"HunyuanVideoDownBlock3D",
95+
"HunyuanVideoDownBlock3D",
96+
"HunyuanVideoDownBlock3D",
97+
"HunyuanVideoDownBlock3D",
98+
),
99+
up_block_types=(
100+
"HunyuanVideoUpBlock3D",
101+
"HunyuanVideoUpBlock3D",
102+
"HunyuanVideoUpBlock3D",
103+
"HunyuanVideoUpBlock3D",
104+
),
105+
block_out_channels=(8, 8, 8, 8),
106+
layers_per_block=1,
107+
act_fn="silu",
108+
norm_num_groups=4,
109+
scaling_factor=0.476986,
110+
spatial_compression_ratio=8,
111+
temporal_compression_ratio=4,
112+
mid_block_add_attention=True,
113+
)
114+
115+
torch.manual_seed(0)
116+
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
117+
118+
llama_text_encoder_config = LlamaConfig(
119+
bos_token_id=0,
120+
eos_token_id=2,
121+
hidden_size=16,
122+
intermediate_size=37,
123+
layer_norm_eps=1e-05,
124+
num_attention_heads=4,
125+
num_hidden_layers=2,
126+
pad_token_id=1,
127+
vocab_size=1000,
128+
hidden_act="gelu",
129+
projection_dim=32,
130+
)
131+
clip_text_encoder_config = CLIPTextConfig(
132+
bos_token_id=0,
133+
eos_token_id=2,
134+
hidden_size=8,
135+
intermediate_size=37,
136+
layer_norm_eps=1e-05,
137+
num_attention_heads=4,
138+
num_hidden_layers=2,
139+
pad_token_id=1,
140+
vocab_size=1000,
141+
hidden_act="gelu",
142+
projection_dim=32,
143+
)
144+
145+
torch.manual_seed(0)
146+
text_encoder = LlamaModel(llama_text_encoder_config)
147+
tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
148+
149+
torch.manual_seed(0)
150+
text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
151+
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
152+
153+
torch.manual_seed(0)
154+
image_processor = CLIPImageProcessor(
155+
crop_size=224,
156+
do_center_crop=True,
157+
do_normalize=True,
158+
do_resize=True,
159+
image_mean=[0.48145466, 0.4578275, 0.40821073],
160+
image_std=[0.26862954, 0.26130258, 0.27577711],
161+
resample=3,
162+
size=224,
163+
)
164+
165+
components = {
166+
"transformer": transformer,
167+
"vae": vae,
168+
"scheduler": scheduler,
169+
"text_encoder": text_encoder,
170+
"text_encoder_2": text_encoder_2,
171+
"tokenizer": tokenizer,
172+
"tokenizer_2": tokenizer_2,
173+
"image_processor": image_processor,
174+
}
175+
return components
176+
177+
def get_dummy_inputs(self, device, seed=0):
178+
if str(device).startswith("mps"):
179+
generator = torch.manual_seed(seed)
180+
else:
181+
generator = torch.Generator(device=device).manual_seed(seed)
182+
183+
image_height = 16
184+
image_width = 16
185+
image = Image.new("RGB", (image_width, image_height))
186+
inputs = {
187+
"image": image,
188+
"prompt": "dance monkey",
189+
"prompt_template": {
190+
"template": "{}",
191+
"crop_start": 0,
192+
},
193+
"generator": generator,
194+
"num_inference_steps": 2,
195+
"guidance_scale": 4.5,
196+
"height": image_height,
197+
"width": image_width,
198+
"num_frames": 9,
199+
"max_sequence_length": 16,
200+
"output_type": "pt",
201+
}
202+
return inputs
203+
204+
def test_inference(self):
205+
device = "cpu"
206+
207+
components = self.get_dummy_components()
208+
pipe = self.pipeline_class(**components)
209+
pipe.to(device)
210+
pipe.set_progress_bar_config(disable=None)
211+
212+
inputs = self.get_dummy_inputs(device)
213+
video = pipe(**inputs).frames
214+
generated_video = video[0]
215+
216+
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
217+
expected_video = torch.randn(9, 3, 16, 16)
218+
max_diff = np.abs(generated_video - expected_video).max()
219+
self.assertLessEqual(max_diff, 1e10)
220+
221+
def test_callback_inputs(self):
222+
sig = inspect.signature(self.pipeline_class.__call__)
223+
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
224+
has_callback_step_end = "callback_on_step_end" in sig.parameters
225+
226+
if not (has_callback_tensor_inputs and has_callback_step_end):
227+
return
228+
229+
components = self.get_dummy_components()
230+
pipe = self.pipeline_class(**components)
231+
pipe = pipe.to(torch_device)
232+
pipe.set_progress_bar_config(disable=None)
233+
self.assertTrue(
234+
hasattr(pipe, "_callback_tensor_inputs"),
235+
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
236+
)
237+
238+
def callback_inputs_subset(pipe, i, t, callback_kwargs):
239+
# iterate over callback args
240+
for tensor_name, tensor_value in callback_kwargs.items():
241+
# check that we're only passing in allowed tensor inputs
242+
assert tensor_name in pipe._callback_tensor_inputs
243+
244+
return callback_kwargs
245+
246+
def callback_inputs_all(pipe, i, t, callback_kwargs):
247+
for tensor_name in pipe._callback_tensor_inputs:
248+
assert tensor_name in callback_kwargs
249+
250+
# iterate over callback args
251+
for tensor_name, tensor_value in callback_kwargs.items():
252+
# check that we're only passing in allowed tensor inputs
253+
assert tensor_name in pipe._callback_tensor_inputs
254+
255+
return callback_kwargs
256+
257+
inputs = self.get_dummy_inputs(torch_device)
258+
259+
# Test passing in a subset
260+
inputs["callback_on_step_end"] = callback_inputs_subset
261+
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
262+
output = pipe(**inputs)[0]
263+
264+
# Test passing in a everything
265+
inputs["callback_on_step_end"] = callback_inputs_all
266+
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
267+
output = pipe(**inputs)[0]
268+
269+
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
270+
is_last = i == (pipe.num_timesteps - 1)
271+
if is_last:
272+
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
273+
return callback_kwargs
274+
275+
inputs["callback_on_step_end"] = callback_inputs_change_tensor
276+
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
277+
output = pipe(**inputs)[0]
278+
assert output.abs().sum() < 1e10
279+
280+
def test_attention_slicing_forward_pass(
281+
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
282+
):
283+
if not self.test_attention_slicing:
284+
return
285+
286+
components = self.get_dummy_components()
287+
pipe = self.pipeline_class(**components)
288+
for component in pipe.components.values():
289+
if hasattr(component, "set_default_attn_processor"):
290+
component.set_default_attn_processor()
291+
pipe.to(torch_device)
292+
pipe.set_progress_bar_config(disable=None)
293+
294+
generator_device = "cpu"
295+
inputs = self.get_dummy_inputs(generator_device)
296+
output_without_slicing = pipe(**inputs)[0]
297+
298+
pipe.enable_attention_slicing(slice_size=1)
299+
inputs = self.get_dummy_inputs(generator_device)
300+
output_with_slicing1 = pipe(**inputs)[0]
301+
302+
pipe.enable_attention_slicing(slice_size=2)
303+
inputs = self.get_dummy_inputs(generator_device)
304+
output_with_slicing2 = pipe(**inputs)[0]
305+
306+
if test_max_difference:
307+
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
308+
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
309+
self.assertLess(
310+
max(max_diff1, max_diff2),
311+
expected_max_diff,
312+
"Attention slicing should not affect the inference results",
313+
)
314+
315+
def test_vae_tiling(self, expected_diff_max: float = 0.2):
316+
# Seems to require higher tolerance than the other tests
317+
expected_diff_max = 0.6
318+
generator_device = "cpu"
319+
components = self.get_dummy_components()
320+
321+
pipe = self.pipeline_class(**components)
322+
pipe.to("cpu")
323+
pipe.set_progress_bar_config(disable=None)
324+
325+
# Without tiling
326+
inputs = self.get_dummy_inputs(generator_device)
327+
inputs["height"] = inputs["width"] = 128
328+
output_without_tiling = pipe(**inputs)[0]
329+
330+
# With tiling
331+
pipe.vae.enable_tiling(
332+
tile_sample_min_height=96,
333+
tile_sample_min_width=96,
334+
tile_sample_stride_height=64,
335+
tile_sample_stride_width=64,
336+
)
337+
inputs = self.get_dummy_inputs(generator_device)
338+
inputs["height"] = inputs["width"] = 128
339+
output_with_tiling = pipe(**inputs)[0]
340+
341+
self.assertLess(
342+
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
343+
expected_diff_max,
344+
"VAE tiling should not affect the inference results",
345+
)
346+
347+
# TODO(aryan): Create a dummy gemma model with smol vocab size
348+
@unittest.skip(
349+
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
350+
)
351+
def test_inference_batch_consistent(self):
352+
pass
353+
354+
@unittest.skip(
355+
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
356+
)
357+
def test_inference_batch_single_identical(self):
358+
pass
359+
360+
@unittest.skip(
361+
"Encode prompt currently does not work in isolation because of requiring image embeddings from image processor. The test does not handle this case, or we need to rewrite encode_prompt."
362+
)
363+
def test_encode_prompt_works_in_isolation(self):
364+
pass

0 commit comments

Comments
 (0)