# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import HunyuanVideoTransformer3DModel
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    is_torch_compile,
    require_torch_2,
    require_torch_gpu,
    slow,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin


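# Make PyTorch operations deterministic so test outputs are reproducible across runs.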
enable_full_determinism()


class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
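    """Model tests for the base text-to-video HunyuanVideoTransformer3DModel configuration."""
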
    model_class = HunyuanVideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 4
        num_frames = 1
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        pooled_projection_dim = 8
        sequence_length = 12

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
        # Embedded guidance scale; required because the init config sets guidance_embeds=True.
        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)

        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_projections,
            "encoder_attention_mask": encoder_attention_mask,
            "guidance": guidance,
        }

    @property
    def input_shape(self):
        return (4, 1, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 4,
            "out_channels": 4,
            "num_attention_heads": 2,
            "attention_head_dim": 10,
            "num_layers": 1,
            "num_single_layers": 1,
            "num_refiner_layers": 1,
            "patch_size": 1,
            "patch_size_t": 1,
            "guidance_embeds": True,
            "text_embed_dim": 16,
            "pooled_projection_dim": 8,
            "rope_axes_dim": (2, 4, 4),
            "image_condition_type": None,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"HunyuanVideoTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @require_torch_gpu
    @require_torch_2
    @is_torch_compile
    @slow
    def test_torch_compile_recompilation_and_graph_break(self):
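        # fullgraph=True makes torch.compile raise on any graph break, and
        # error_on_recompile=True turns a recompilation on the second forward pass
        # into a hard failure, so this test verifies the model compiles into a
        # single stable graph.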
        torch._dynamo.reset()
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict).to(torch_device)
        model = torch.compile(model, fullgraph=True)

        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)


class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
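    """Model tests for the SkyReels image-to-video configuration, which doubles the
    input channels (8 in, 4 out) to take channel-wise concatenated image-condition latents."""
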
    model_class = HunyuanVideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 8
        num_frames = 1
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        pooled_projection_dim = 8
        sequence_length = 12

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)

        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_projections,
            "encoder_attention_mask": encoder_attention_mask,
            "guidance": guidance,
        }

    @property
    def input_shape(self):
        return (8, 1, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 8,
            "out_channels": 4,
            "num_attention_heads": 2,
            "attention_head_dim": 10,
            "num_layers": 1,
            "num_single_layers": 1,
            "num_refiner_layers": 1,
            "patch_size": 1,
            "patch_size_t": 1,
            "guidance_embeds": True,
            "text_embed_dim": 16,
            "pooled_projection_dim": 8,
            "rope_axes_dim": (2, 4, 4),
            "image_condition_type": None,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

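    # in_channels (8) differs from out_channels (4), so the expected output shape is
    # passed explicitly instead of relying on the default input-shape comparison.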
    def test_output(self):
        super().test_output(expected_output_shape=(1, *self.output_shape))

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"HunyuanVideoTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @require_torch_gpu
    @require_torch_2
    @is_torch_compile
    @slow
    def test_torch_compile_recompilation_and_graph_break(self):
        torch._dynamo.reset()
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict).to(torch_device)
        model = torch.compile(model, fullgraph=True)

        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)


class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
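    """Model tests for the image-to-video configuration with
    image_condition_type="latent_concat": 2 * 4 + 1 = 9 input channels (two sets of
    latents plus a mask channel) and 4 output channels. guidance_embeds is False,
    so dummy_input carries no guidance tensor."""
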
    model_class = HunyuanVideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 2 * 4 + 1
        num_frames = 1
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        pooled_projection_dim = 8
        sequence_length = 12

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_projections,
            "encoder_attention_mask": encoder_attention_mask,
        }

    @property
    def input_shape(self):
        return (8, 1, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 2 * 4 + 1,
            "out_channels": 4,
            "num_attention_heads": 2,
            "attention_head_dim": 10,
            "num_layers": 1,
            "num_single_layers": 1,
            "num_refiner_layers": 1,
            "patch_size": 1,
            "patch_size_t": 1,
            "guidance_embeds": False,
            "text_embed_dim": 16,
            "pooled_projection_dim": 8,
            "rope_axes_dim": (2, 4, 4),
            "image_condition_type": "latent_concat",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_output(self):
        super().test_output(expected_output_shape=(1, *self.output_shape))

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"HunyuanVideoTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @require_torch_gpu
    @require_torch_2
    @is_torch_compile
    @slow
    def test_torch_compile_recompilation_and_graph_break(self):
        torch._dynamo.reset()
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict).to(torch_device)
        model = torch.compile(model, fullgraph=True)

        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)


class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
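    """Model tests for the image-to-video configuration with
    image_condition_type="token_replace", which injects the image condition by
    replacing first-frame tokens rather than concatenating extra latent channels."""
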
    model_class = HunyuanVideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 2
        num_frames = 1
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        pooled_projection_dim = 8
        sequence_length = 12

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)

        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_projections,
            "encoder_attention_mask": encoder_attention_mask,
            "guidance": guidance,
        }

    @property
    def input_shape(self):
        return (8, 1, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 2,
            "out_channels": 4,
            "num_attention_heads": 2,
            "attention_head_dim": 10,
            "num_layers": 1,
            "num_single_layers": 1,
            "num_refiner_layers": 1,
            "patch_size": 1,
            "patch_size_t": 1,
            "guidance_embeds": True,
            "text_embed_dim": 16,
            "pooled_projection_dim": 8,
            "rope_axes_dim": (2, 4, 4),
            "image_condition_type": "token_replace",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_output(self):
        super().test_output(expected_output_shape=(1, *self.output_shape))

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"HunyuanVideoTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)