Commit a752c75

update
1 parent b3c7ce7 commit a752c75

File tree: 3 files changed, +340 -2 lines changed

docs/source/en/api/pipelines/hunyuan_video.md

Lines changed: 0 additions & 1 deletion
```diff
@@ -32,7 +32,6 @@ Recommendations for inference:
 - For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
 - For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
 
-
 ## Available models
 
 The following models are available for the [`HunyuanVideoPipeline`](text-to-video) pipeline:
```
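For reference, the `shift` recommendation in the doc above can be applied by rebuilding the scheduler from its config. This is a minimal sketch, not part of the commit; the checkpoint id is an assumption (any diffusers-format HunyuanVideo checkpoint would do):

```python
import torch

from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoPipeline

# Checkpoint id is an assumption; substitute your own diffusers-format HunyuanVideo weights.
pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
)

# Lower `shift` tends to work better for smaller-resolution videos,
# higher `shift` for larger ones; the HunyuanVideo default is 7.0.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=3.0)
```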

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -219,6 +219,7 @@ def __init__(
 
         self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
+        self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
     # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_llama_prompt_embeds
@@ -440,7 +441,7 @@ def prepare_latents(
         else:
             image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
 
-        image_latents = torch.cat(image_latents, dim=0).to(dtype)
+        image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
 
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial
```
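The functional change above is that the encoded conditioning-image latents are now multiplied by the VAE scaling factor (falling back to `0.476986` when no VAE is loaded), so they live in the same scaled latent space the transformer operates in and the decode step later undoes. A minimal sketch of that convention follows; the helper names and the free-standing `vae`/`image` objects are illustrative only, not part of the pipeline:

```python
import torch


def encode_image_to_latents(vae, image: torch.Tensor, generator=None) -> torch.Tensor:
    # `vae` is an AutoencoderKL-style model; `image` is a preprocessed batch in [-1, 1].
    latents = vae.encode(image).latent_dist.sample(generator=generator)
    # Scale into the model's latent space, mirroring the `* self.vae_scaling_factor` fix above.
    return latents * vae.config.scaling_factor


def decode_latents(vae, latents: torch.Tensor) -> torch.Tensor:
    # Undo the scaling before decoding, as the pipeline already does at the end of sampling.
    return vae.decode(latents / vae.config.scaling_factor).sample
```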
Lines changed: 338 additions & 0 deletions
New file (full contents; path not shown in this diff view):

```python
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, LlamaConfig, LlamaModel, LlamaTokenizer

from diffusers import (
    AutoencoderKLHunyuanVideo,
    FlowMatchEulerDiscreteScheduler,
    HunyuanSkyreelsImageToVideoPipeline,
    HunyuanVideoTransformer3DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np


enable_full_determinism()


class HunyuanSkyreelsImageToVideoPipelineFastTests(
    PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase
):
    pipeline_class = HunyuanSkyreelsImageToVideoPipeline
    params = frozenset(
        ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
    )
    batch_params = frozenset(["prompt", "image"])
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    supports_dduf = False

    # there is no xformers processor for Flux
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        torch.manual_seed(0)
        transformer = HunyuanVideoTransformer3DModel(
            in_channels=8,
            out_channels=4,
            num_attention_heads=2,
            attention_head_dim=10,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            num_refiner_layers=1,
            patch_size=1,
            patch_size_t=1,
            guidance_embeds=True,
            text_embed_dim=16,
            pooled_projection_dim=8,
            rope_axes_dim=(2, 4, 4),
        )

        torch.manual_seed(0)
        vae = AutoencoderKLHunyuanVideo(
            in_channels=3,
            out_channels=3,
            latent_channels=4,
            down_block_types=(
                "HunyuanVideoDownBlock3D",
                "HunyuanVideoDownBlock3D",
                "HunyuanVideoDownBlock3D",
                "HunyuanVideoDownBlock3D",
            ),
            up_block_types=(
                "HunyuanVideoUpBlock3D",
                "HunyuanVideoUpBlock3D",
                "HunyuanVideoUpBlock3D",
                "HunyuanVideoUpBlock3D",
            ),
            block_out_channels=(8, 8, 8, 8),
            layers_per_block=1,
            act_fn="silu",
            norm_num_groups=4,
            scaling_factor=0.476986,
            spatial_compression_ratio=8,
            temporal_compression_ratio=4,
            mid_block_add_attention=True,
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)

        llama_text_encoder_config = LlamaConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=16,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=2,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )
        clip_text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=8,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=2,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )

        torch.manual_seed(0)
        text_encoder = LlamaModel(llama_text_encoder_config)
        tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")

        torch.manual_seed(0)
        text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image_height = 16
        image_width = 16
        image = Image.new("RGB", (image_width, image_height))
        inputs = {
            "image": image,
            "prompt": "dance monkey",
            "prompt_template": {
                "template": "{}",
                "crop_start": 0,
            },
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 4.5,
            "height": 16,
            "width": 16,
            # 4 * k + 1 is the recommendation
            "num_frames": 9,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames
        generated_video = video[0]

        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
        expected_video = torch.randn(9, 3, 16, 16)
        max_diff = np.abs(generated_video - expected_video).max()
        self.assertLessEqual(max_diff, 1e10)

    def test_callback_inputs(self):
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs

            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        # Test passing in a subset
        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        output = pipe(**inputs)[0]

        # Test passing in everything
        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        # Seems to require higher tolerance than the other tests
        expected_diff_max = 0.6
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )

    # TODO(aryan): Create a dummy gemma model with smol vocab size
    @unittest.skip(
        "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to an embedding lookup error. This test uses a long prompt that causes the error."
    )
    def test_inference_batch_consistent(self):
        pass

    @unittest.skip(
        "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to an embedding lookup error. This test uses a long prompt that causes the error."
    )
    def test_inference_batch_single_identical(self):
        pass
```
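For context, a rough sketch of how the pipeline exercised by these fast tests is used end to end. This is not part of the commit: the checkpoint id and image path are placeholders, and the resolution and frame count are only examples following the `4 * k + 1` frame recommendation seen in the test inputs above.

```python
import torch

from diffusers import HunyuanSkyreelsImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Placeholder checkpoint id: substitute a diffusers-format SkyReels I2V checkpoint.
pipe = HunyuanSkyreelsImageToVideoPipeline.from_pretrained(
    "path/or/repo-id/of-skyreels-i2v-checkpoint", torch_dtype=torch.float16
)
pipe.vae.enable_tiling()
pipe.to("cuda")

image = load_image("path/to/first_frame.png")  # placeholder conditioning image
output = pipe(
    image=image,
    prompt="a cat walks on the grass, realistic style",
    height=544,
    width=960,
    num_frames=97,  # 4 * k + 1 frames, as recommended in the test inputs
    num_inference_steps=30,
    guidance_scale=6.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24)
```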
