Commit da07d86

add tests for framepack transformer model.
1 parent 2794029 commit da07d86

File tree: 2 files changed (+35 −263 lines)

src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -193,7 +193,7 @@ def __init__(
         if has_clean_x_embedder:
             self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
 
-        self.use_gradient_checkpointing = False
+        self.gradient_checkpointing = False
 
     def forward(
         self,
```
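This one-line rename matters because the shared model tests (and diffusers' `enable_gradient_checkpointing()` helper) expect the flag on a checkpointable module to be named `gradient_checkpointing`; a flag with any other name would never get flipped. Below is a minimal, self-contained sketch of that convention; the `TinyBlock` module is a hypothetical stand-in for illustration only, not the real framepack transformer block.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # The flag must be named `gradient_checkpointing` so that model-level
        # enable/disable helpers (and the shared tests) can find and toggle it.
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Recompute activations during backward instead of storing them.
            return checkpoint(self.proj, x, use_reentrant=False)
        return self.proj(x)


block = TinyBlock()
block.gradient_checkpointing = True
out = block(torch.randn(2, 8, requires_grad=True))
out.sum().backward()  # gradients flow through the checkpointed path
```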

tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py

Lines changed: 34 additions & 262 deletions
```diff
@@ -16,13 +16,9 @@
 
 import torch
 
-from diffusers import HunyuanVideoTransformer3DModel
+from diffusers import HunyuanVideoFramepackTransformer3DModel
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
-    is_torch_compile,
-    require_torch_2,
-    require_torch_gpu,
-    slow,
     torch_device,
 )
 
@@ -33,107 +29,39 @@
 
 
 class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
-    model_class = HunyuanVideoTransformer3DModel
+    model_class = HunyuanVideoFramepackTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
+    model_split_percents = [0.5, 0.7, 0.9]
 
     @property
     def dummy_input(self):
         batch_size = 1
         num_channels = 4
-        num_frames = 1
-        height = 16
-        width = 16
+        num_frames = 3
+        height = 4
+        width = 4
         text_encoder_embedding_dim = 16
+        image_encoder_embedding_dim = 16
         pooled_projection_dim = 8
         sequence_length = 12
 
         hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
-        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
         encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
         pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
         encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
-        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
-
-        return {
-            "hidden_states": hidden_states,
-            "timestep": timestep,
-            "encoder_hidden_states": encoder_hidden_states,
-            "pooled_projections": pooled_projections,
-            "encoder_attention_mask": encoder_attention_mask,
-            "guidance": guidance,
-        }
-
-    @property
-    def input_shape(self):
-        return (4, 1, 16, 16)
-
-    @property
-    def output_shape(self):
-        return (4, 1, 16, 16)
-
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 4,
-            "out_channels": 4,
-            "num_attention_heads": 2,
-            "attention_head_dim": 10,
-            "num_layers": 1,
-            "num_single_layers": 1,
-            "num_refiner_layers": 1,
-            "patch_size": 1,
-            "patch_size_t": 1,
-            "guidance_embeds": True,
-            "text_embed_dim": 16,
-            "pooled_projection_dim": 8,
-            "rope_axes_dim": (2, 4, 4),
-            "image_condition_type": None,
-        }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
-
-    def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"HunyuanVideoTransformer3DModel"}
-        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
-
-
-class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
-    model_class = HunyuanVideoTransformer3DModel
-    main_input_name = "hidden_states"
-    uses_custom_attn_processor = True
-
-    @property
-    def dummy_input(self):
-        batch_size = 1
-        num_channels = 8
-        num_frames = 1
-        height = 16
-        width = 16
-        text_encoder_embedding_dim = 16
-        pooled_projection_dim = 8
-        sequence_length = 12
-
-        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+        image_embeds = torch.randn((batch_size, sequence_length, image_encoder_embedding_dim)).to(torch_device)
+        indices_latents = torch.ones((3,)).to(torch_device)
+        latents_clean = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
+        indices_latents_clean = torch.ones((num_frames - 1,)).to(torch_device)
+        latents_history_2x = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
+        indices_latents_history_2x = torch.ones((num_frames - 1,)).to(torch_device)
+        latents_history_4x = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to(
+            torch_device
+        )
+        indices_latents_history_4x = torch.ones(((num_frames - 1) * 4,)).to(torch_device)
         timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
-        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
-        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
-        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
-        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
+        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
 
         return {
             "hidden_states": hidden_states,
@@ -142,203 +70,47 @@ def dummy_input(self):
             "pooled_projections": pooled_projections,
             "encoder_attention_mask": encoder_attention_mask,
             "guidance": guidance,
+            "image_embeds": image_embeds,
+            "indices_latents": indices_latents,
+            "latents_clean": latents_clean,
+            "indices_latents_clean": indices_latents_clean,
+            "latents_history_2x": latents_history_2x,
+            "indices_latents_history_2x": indices_latents_history_2x,
+            "latents_history_4x": latents_history_4x,
+            "indices_latents_history_4x": indices_latents_history_4x,
         }
 
     @property
     def input_shape(self):
-        return (8, 1, 16, 16)
+        return (4, 3, 4, 4)
 
     @property
     def output_shape(self):
-        return (4, 1, 16, 16)
+        return (4, 3, 4, 4)
 
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "in_channels": 8,
+            "in_channels": 4,
             "out_channels": 4,
             "num_attention_heads": 2,
             "attention_head_dim": 10,
             "num_layers": 1,
             "num_single_layers": 1,
             "num_refiner_layers": 1,
-            "patch_size": 1,
+            "patch_size": 2,
             "patch_size_t": 1,
             "guidance_embeds": True,
             "text_embed_dim": 16,
             "pooled_projection_dim": 8,
             "rope_axes_dim": (2, 4, 4),
             "image_condition_type": None,
+            "has_image_proj": True,
+            "image_proj_dim": 16,
+            "has_clean_x_embedder": True,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
-    def test_output(self):
-        super().test_output(expected_output_shape=(1, *self.output_shape))
-
-    def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"HunyuanVideoTransformer3DModel"}
-        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
-
-
-class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
-    model_class = HunyuanVideoTransformer3DModel
-    main_input_name = "hidden_states"
-    uses_custom_attn_processor = True
-
-    @property
-    def dummy_input(self):
-        batch_size = 1
-        num_channels = 2 * 4 + 1
-        num_frames = 1
-        height = 16
-        width = 16
-        text_encoder_embedding_dim = 16
-        pooled_projection_dim = 8
-        sequence_length = 12
-
-        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
-        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
-        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
-        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
-        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
-
-        return {
-            "hidden_states": hidden_states,
-            "timestep": timestep,
-            "encoder_hidden_states": encoder_hidden_states,
-            "pooled_projections": pooled_projections,
-            "encoder_attention_mask": encoder_attention_mask,
-        }
-
-    @property
-    def input_shape(self):
-        return (8, 1, 16, 16)
-
-    @property
-    def output_shape(self):
-        return (4, 1, 16, 16)
-
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 2 * 4 + 1,
-            "out_channels": 4,
-            "num_attention_heads": 2,
-            "attention_head_dim": 10,
-            "num_layers": 1,
-            "num_single_layers": 1,
-            "num_refiner_layers": 1,
-            "patch_size": 1,
-            "patch_size_t": 1,
-            "guidance_embeds": False,
-            "text_embed_dim": 16,
-            "pooled_projection_dim": 8,
-            "rope_axes_dim": (2, 4, 4),
-            "image_condition_type": "latent_concat",
-        }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
-
-    def test_output(self):
-        super().test_output(expected_output_shape=(1, *self.output_shape))
-
     def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"HunyuanVideoTransformer3DModel"}
+        expected_set = {"HunyuanVideoFramepackTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
-
-
-class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
-    model_class = HunyuanVideoTransformer3DModel
-    main_input_name = "hidden_states"
-    uses_custom_attn_processor = True
-
-    @property
-    def dummy_input(self):
-        batch_size = 1
-        num_channels = 2
-        num_frames = 1
-        height = 16
-        width = 16
-        text_encoder_embedding_dim = 16
-        pooled_projection_dim = 8
-        sequence_length = 12
-
-        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
-        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
-        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
-        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
-        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
-        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
-
-        return {
-            "hidden_states": hidden_states,
-            "timestep": timestep,
-            "encoder_hidden_states": encoder_hidden_states,
-            "pooled_projections": pooled_projections,
-            "encoder_attention_mask": encoder_attention_mask,
-            "guidance": guidance,
-        }
-
-    @property
-    def input_shape(self):
-        return (8, 1, 16, 16)
-
-    @property
-    def output_shape(self):
-        return (4, 1, 16, 16)
-
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 2,
-            "out_channels": 4,
-            "num_attention_heads": 2,
-            "attention_head_dim": 10,
-            "num_layers": 1,
-            "num_single_layers": 1,
-            "num_refiner_layers": 1,
-            "patch_size": 1,
-            "patch_size_t": 1,
-            "guidance_embeds": True,
-            "text_embed_dim": 16,
-            "pooled_projection_dim": 8,
-            "rope_axes_dim": (2, 4, 4),
-            "image_condition_type": "token_replace",
-        }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
-
-    def test_output(self):
-        super().test_output(expected_output_shape=(1, *self.output_shape))
-
-    def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"HunyuanVideoTransformer3DModel"}
-        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
```
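For orientation, this is roughly what the inherited ModelTesterMixin checks do with the tiny config above: build the model from `init_dict`, run the `dummy_input` dict through it, and compare the result against `output_shape`. The sketch below assumes the forward call returns a diffusers BaseOutput whose first element is the predicted latent; the helper name `run_tiny_forward` is illustrative, not part of the test suite.

```python
import torch
from diffusers import HunyuanVideoFramepackTransformer3DModel


def run_tiny_forward(init_dict: dict, inputs_dict: dict) -> torch.Size:
    # `init_dict` and `inputs_dict` are assumed to come from
    # prepare_init_args_and_inputs_for_common() of the test class above.
    model = HunyuanVideoFramepackTransformer3DModel(**init_dict)
    model = model.to(inputs_dict["hidden_states"].device)
    model.eval()

    with torch.no_grad():
        output = model(**inputs_dict)

    # diffusers models typically wrap the result in a BaseOutput; index 0 is
    # assumed to be the predicted latent tensor here.
    sample = output if torch.is_tensor(output) else output[0]
    # Expected to match (batch, *output_shape), i.e. (1, 4, 3, 4, 4) for this config.
    return sample.shape
```

The shrunken latent grid (4x4 spatial, 3 frames) keeps the clean/2x/4x history tensors small, so these dummy forward passes stay cheap enough for CPU test runs.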
