Commit 2794029

Commit message: start.
1 parent: 53bd367

1 file changed: 344 additions, 0 deletions
@@ -0,0 +1,344 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import HunyuanVideoTransformer3DModel
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    is_torch_compile,
    require_torch_2,
    require_torch_gpu,
    slow,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
    model_class = HunyuanVideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 4
        num_frames = 1
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        pooled_projection_dim = 8
        sequence_length = 12

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)

        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_projections,
            "encoder_attention_mask": encoder_attention_mask,
            "guidance": guidance,
        }

    @property
    def input_shape(self):
        return (4, 1, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 4,
            "out_channels": 4,
            "num_attention_heads": 2,
            "attention_head_dim": 10,
            "num_layers": 1,
            "num_single_layers": 1,
            "num_refiner_layers": 1,
            "patch_size": 1,
            "patch_size_t": 1,
            "guidance_embeds": True,
            "text_embed_dim": 16,
            "pooled_projection_dim": 8,
            "rope_axes_dim": (2, 4, 4),
            "image_condition_type": None,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"HunyuanVideoTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @require_torch_gpu
    @require_torch_2
    @is_torch_compile
    @slow
    def test_torch_compile_recompilation_and_graph_break(self):
        torch._dynamo.reset()
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict).to(torch_device)
        model = torch.compile(model, fullgraph=True)

        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)
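
The tiny config in prepare_init_args_and_inputs_for_common, combined with dummy_input, fully specifies a forward pass that can also be exercised outside the test harness. A minimal standalone sketch, assuming the forward accepts return_dict=False and yields the denoised latents as the first tuple element (a common diffusers convention, not confirmed by this diff):

    import torch

    from diffusers import HunyuanVideoTransformer3DModel

    # Same tiny config as init_dict above; small enough to run on CPU.
    model = HunyuanVideoTransformer3DModel(
        in_channels=4,
        out_channels=4,
        num_attention_heads=2,
        attention_head_dim=10,
        num_layers=1,
        num_single_layers=1,
        num_refiner_layers=1,
        patch_size=1,
        patch_size_t=1,
        guidance_embeds=True,
        text_embed_dim=16,
        pooled_projection_dim=8,
        rope_axes_dim=(2, 4, 4),
        image_condition_type=None,
    )

    with torch.no_grad():
        sample = model(
            hidden_states=torch.randn(1, 4, 1, 16, 16),  # (batch, channels, frames, height, width)
            timestep=torch.randint(0, 1000, (1,)),
            encoder_hidden_states=torch.randn(1, 12, 16),  # (batch, sequence_length, text_embed_dim)
            pooled_projections=torch.randn(1, 8),
            encoder_attention_mask=torch.ones(1, 12),
            guidance=torch.randint(0, 1000, (1,)).float(),
            return_dict=False,  # assumed kwarg; an output object with .sample is the usual alternative
        )[0]

    assert sample.shape == (1, 4, 1, 16, 16)  # batch dim plus output_shape from the test above
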
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
    model_class = HunyuanVideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 8
        num_frames = 1
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        pooled_projection_dim = 8
        sequence_length = 12

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)

        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_projections,
            "encoder_attention_mask": encoder_attention_mask,
            "guidance": guidance,
        }

    @property
    def input_shape(self):
        return (8, 1, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 8,
            "out_channels": 4,
            "num_attention_heads": 2,
            "attention_head_dim": 10,
            "num_layers": 1,
            "num_single_layers": 1,
            "num_refiner_layers": 1,
            "patch_size": 1,
            "patch_size_t": 1,
            "guidance_embeds": True,
            "text_embed_dim": 16,
            "pooled_projection_dim": 8,
            "rope_axes_dim": (2, 4, 4),
            "image_condition_type": None,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_output(self):
        super().test_output(expected_output_shape=(1, *self.output_shape))

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"HunyuanVideoTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @require_torch_gpu
    @require_torch_2
    @is_torch_compile
    @slow
    def test_torch_compile_recompilation_and_graph_break(self):
        torch._dynamo.reset()
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict).to(torch_device)
        model = torch.compile(model, fullgraph=True)

        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)
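
The Skyreels suite reuses the same architecture with in_channels=8 but out_channels=4, so input_shape is (8, 1, 16, 16) while the predicted latents keep 4 channels. That doubling is consistent with an image-condition latent concatenated channel-wise with the noisy video latent; a hedged sketch of that layout (the concat ordering is an assumption, not shown in this diff):

    import torch

    video_latents = torch.randn(1, 4, 1, 16, 16)  # noisy video latents (batch, C, F, H, W)
    image_latents = torch.randn(1, 4, 1, 16, 16)  # image-condition latents, same spatial layout
    hidden_states = torch.cat([image_latents, video_latents], dim=1)  # assumed ordering
    print(hidden_states.shape)  # torch.Size([1, 8, 1, 16, 16]), matching input_shape above
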
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
    model_class = HunyuanVideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 2 * 4 + 1
        num_frames = 1
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        pooled_projection_dim = 8
        sequence_length = 12

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_projections,
            "encoder_attention_mask": encoder_attention_mask,
        }

    @property
    def input_shape(self):
        return (8, 1, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 2 * 4 + 1,
            "out_channels": 4,
            "num_attention_heads": 2,
            "attention_head_dim": 10,
            "num_layers": 1,
            "num_single_layers": 1,
            "num_refiner_layers": 1,
            "patch_size": 1,
            "patch_size_t": 1,
            "guidance_embeds": False,
            "text_embed_dim": 16,
            "pooled_projection_dim": 8,
            "rope_axes_dim": (2, 4, 4),
            "image_condition_type": "latent_concat",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_output(self):
        super().test_output(expected_output_shape=(1, *self.output_shape))

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"HunyuanVideoTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @require_torch_gpu
    @require_torch_2
    @is_torch_compile
    @slow
    def test_torch_compile_recompilation_and_graph_break(self):
        torch._dynamo.reset()
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict).to(torch_device)
        model = torch.compile(model, fullgraph=True)

        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
    model_class = HunyuanVideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 2
        num_frames = 1
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        pooled_projection_dim = 8
        sequence_length = 12

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)

        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_projections,
            "encoder_attention_mask": encoder_attention_mask,
            "guidance": guidance,
        }

    @property
    def input_shape(self):
        return (8, 1, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 2,
            "out_channels": 4,
            "num_attention_heads": 2,
            "attention_head_dim": 10,
            "num_layers": 1,
            "num_single_layers": 1,
            "num_refiner_layers": 1,
            "patch_size": 1,
            "patch_size_t": 1,
            "guidance_embeds": True,
            "text_embed_dim": 16,
            "pooled_projection_dim": 8,
            "rope_axes_dim": (2, 4, 4),
            "image_condition_type": "token_replace",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_output(self):
        super().test_output(expected_output_shape=(1, *self.output_shape))

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"HunyuanVideoTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
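
All four suites are plain unittest.TestCase subclasses layered on ModelTesterMixin, so any one of them can be run on its own. A minimal sketch, assuming the classes are importable from wherever this new test module lands (the file path is not shown in this commit); the compile tests stay skipped unless the @require_torch_gpu, @require_torch_2, @is_torch_compile, and @slow gates are all satisfied:

    import unittest

    # Hypothetical direct run of a single suite; the import path depends on
    # where this test module lives in the repository.
    suite = unittest.TestLoader().loadTestsFromTestCase(HunyuanVideoTransformer3DTests)
    unittest.TextTestRunner(verbosity=2).run(suite)
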
