
import torch

-from diffusers import HunyuanVideoTransformer3DModel
+from diffusers import HunyuanVideoFramepackTransformer3DModel
from diffusers.utils.testing_utils import (
    enable_full_determinism,
-    is_torch_compile,
-    require_torch_2,
-    require_torch_gpu,
-    slow,
    torch_device,
)



class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
-    model_class = HunyuanVideoTransformer3DModel
+    model_class = HunyuanVideoFramepackTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True
+    model_split_percents = [0.5, 0.7, 0.9]

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 4
-        num_frames = 1
-        height = 16
-        width = 16
+        num_frames = 3
+        height = 4
+        width = 4
        text_encoder_embedding_dim = 16
+        image_encoder_embedding_dim = 16
        pooled_projection_dim = 8
        sequence_length = 12

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
-        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
-        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
-
-        return {
-            "hidden_states": hidden_states,
-            "timestep": timestep,
-            "encoder_hidden_states": encoder_hidden_states,
-            "pooled_projections": pooled_projections,
-            "encoder_attention_mask": encoder_attention_mask,
-            "guidance": guidance,
-        }
-
-    @property
-    def input_shape(self):
-        return (4, 1, 16, 16)
-
-    @property
-    def output_shape(self):
-        return (4, 1, 16, 16)
-
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 4,
-            "out_channels": 4,
-            "num_attention_heads": 2,
-            "attention_head_dim": 10,
-            "num_layers": 1,
-            "num_single_layers": 1,
-            "num_refiner_layers": 1,
-            "patch_size": 1,
-            "patch_size_t": 1,
-            "guidance_embeds": True,
-            "text_embed_dim": 16,
-            "pooled_projection_dim": 8,
-            "rope_axes_dim": (2, 4, 4),
-            "image_condition_type": None,
-        }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
-
-    def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"HunyuanVideoTransformer3DModel"}
-        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
-
-
-class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
-    model_class = HunyuanVideoTransformer3DModel
-    main_input_name = "hidden_states"
-    uses_custom_attn_processor = True
-
-    @property
-    def dummy_input(self):
-        batch_size = 1
-        num_channels = 8
-        num_frames = 1
-        height = 16
-        width = 16
-        text_encoder_embedding_dim = 16
-        pooled_projection_dim = 8
-        sequence_length = 12
-
-        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+        image_embeds = torch.randn((batch_size, sequence_length, image_encoder_embedding_dim)).to(torch_device)
+        indices_latents = torch.ones((3,)).to(torch_device)
+        latents_clean = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
+        indices_latents_clean = torch.ones((num_frames - 1,)).to(torch_device)
+        latents_history_2x = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
+        indices_latents_history_2x = torch.ones((num_frames - 1,)).to(torch_device)
+        latents_history_4x = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to(
+            torch_device
+        )
+        indices_latents_history_4x = torch.ones(((num_frames - 1) * 4,)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
-        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
-        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
-        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
-        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
+        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)

        return {
            "hidden_states": hidden_states,
@@ -142,203 +70,47 @@ def dummy_input(self):
            "pooled_projections": pooled_projections,
            "encoder_attention_mask": encoder_attention_mask,
            "guidance": guidance,
+            "image_embeds": image_embeds,
+            "indices_latents": indices_latents,
+            "latents_clean": latents_clean,
+            "indices_latents_clean": indices_latents_clean,
+            "latents_history_2x": latents_history_2x,
+            "indices_latents_history_2x": indices_latents_history_2x,
+            "latents_history_4x": latents_history_4x,
+            "indices_latents_history_4x": indices_latents_history_4x,
        }

    @property
    def input_shape(self):
-        return (8, 1, 16, 16)
+        return (4, 3, 4, 4)

    @property
    def output_shape(self):
-        return (4, 1, 16, 16)
+        return (4, 3, 4, 4)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
-            "in_channels": 8,
+            "in_channels": 4,
            "out_channels": 4,
            "num_attention_heads": 2,
            "attention_head_dim": 10,
            "num_layers": 1,
            "num_single_layers": 1,
            "num_refiner_layers": 1,
-            "patch_size": 1,
+            "patch_size": 2,
            "patch_size_t": 1,
            "guidance_embeds": True,
            "text_embed_dim": 16,
            "pooled_projection_dim": 8,
            "rope_axes_dim": (2, 4, 4),
            "image_condition_type": None,
+            "has_image_proj": True,
+            "image_proj_dim": 16,
+            "has_clean_x_embedder": True,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

-    def test_output(self):
-        super().test_output(expected_output_shape=(1, *self.output_shape))
-
-    def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"HunyuanVideoTransformer3DModel"}
-        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
-
-
-class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
-    model_class = HunyuanVideoTransformer3DModel
-    main_input_name = "hidden_states"
-    uses_custom_attn_processor = True
-
-    @property
-    def dummy_input(self):
-        batch_size = 1
-        num_channels = 2 * 4 + 1
-        num_frames = 1
-        height = 16
-        width = 16
-        text_encoder_embedding_dim = 16
-        pooled_projection_dim = 8
-        sequence_length = 12
-
-        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
-        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
-        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
-        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
-        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
-
-        return {
-            "hidden_states": hidden_states,
-            "timestep": timestep,
-            "encoder_hidden_states": encoder_hidden_states,
-            "pooled_projections": pooled_projections,
-            "encoder_attention_mask": encoder_attention_mask,
-        }
-
-    @property
-    def input_shape(self):
-        return (8, 1, 16, 16)
-
-    @property
-    def output_shape(self):
-        return (4, 1, 16, 16)
-
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 2 * 4 + 1,
-            "out_channels": 4,
-            "num_attention_heads": 2,
-            "attention_head_dim": 10,
-            "num_layers": 1,
-            "num_single_layers": 1,
-            "num_refiner_layers": 1,
-            "patch_size": 1,
-            "patch_size_t": 1,
-            "guidance_embeds": False,
-            "text_embed_dim": 16,
-            "pooled_projection_dim": 8,
-            "rope_axes_dim": (2, 4, 4),
-            "image_condition_type": "latent_concat",
-        }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
-
-    def test_output(self):
-        super().test_output(expected_output_shape=(1, *self.output_shape))
-
    def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"HunyuanVideoTransformer3DModel"}
+        expected_set = {"HunyuanVideoFramepackTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
-
-
-class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
-    model_class = HunyuanVideoTransformer3DModel
-    main_input_name = "hidden_states"
-    uses_custom_attn_processor = True
-
-    @property
-    def dummy_input(self):
-        batch_size = 1
-        num_channels = 2
-        num_frames = 1
-        height = 16
-        width = 16
-        text_encoder_embedding_dim = 16
-        pooled_projection_dim = 8
-        sequence_length = 12
-
-        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
-        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
-        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
-        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
-        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
-        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
-
-        return {
-            "hidden_states": hidden_states,
-            "timestep": timestep,
-            "encoder_hidden_states": encoder_hidden_states,
-            "pooled_projections": pooled_projections,
-            "encoder_attention_mask": encoder_attention_mask,
-            "guidance": guidance,
-        }
-
-    @property
-    def input_shape(self):
-        return (8, 1, 16, 16)
-
-    @property
-    def output_shape(self):
-        return (4, 1, 16, 16)
-
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 2,
-            "out_channels": 4,
-            "num_attention_heads": 2,
-            "attention_head_dim": 10,
-            "num_layers": 1,
-            "num_single_layers": 1,
-            "num_refiner_layers": 1,
-            "patch_size": 1,
-            "patch_size_t": 1,
-            "guidance_embeds": True,
-            "text_embed_dim": 16,
-            "pooled_projection_dim": 8,
-            "rope_axes_dim": (2, 4, 4),
-            "image_condition_type": "token_replace",
-        }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
-
-    def test_output(self):
-        super().test_output(expected_output_shape=(1, *self.output_shape))
-
-    def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"HunyuanVideoTransformer3DModel"}
-        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
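
For context, here is a minimal standalone sketch of how the new framepack test configuration could be exercised outside the test harness. It is illustrative only: the constructor arguments and forward keyword names are copied from the init_dict and dummy_input in the diff above, while the return-value handling and the CPU-only run are assumptions about the current diffusers API rather than something this PR specifies.

# Illustrative sketch, not part of the committed test file.
# Constructor kwargs and forward kwargs mirror the diff above; the
# return-value handling below is an assumption about the output type.
import torch

from diffusers import HunyuanVideoFramepackTransformer3DModel

model = HunyuanVideoFramepackTransformer3DModel(
    in_channels=4,
    out_channels=4,
    num_attention_heads=2,
    attention_head_dim=10,
    num_layers=1,
    num_single_layers=1,
    num_refiner_layers=1,
    patch_size=2,
    patch_size_t=1,
    guidance_embeds=True,
    text_embed_dim=16,
    pooled_projection_dim=8,
    rope_axes_dim=(2, 4, 4),
    image_condition_type=None,
    has_image_proj=True,
    image_proj_dim=16,
    has_clean_x_embedder=True,
).eval()

batch_size, num_channels, num_frames, height, width = 1, 4, 3, 4, 4
sequence_length = 12

inputs = {
    "hidden_states": torch.randn(batch_size, num_channels, num_frames, height, width),
    "timestep": torch.randint(0, 1000, (batch_size,)),
    "encoder_hidden_states": torch.randn(batch_size, sequence_length, 16),
    "pooled_projections": torch.randn(batch_size, 8),
    "encoder_attention_mask": torch.ones(batch_size, sequence_length),
    "guidance": torch.randint(0, 1000, (batch_size,)),
    "image_embeds": torch.randn(batch_size, sequence_length, 16),
    "indices_latents": torch.ones(num_frames),
    "latents_clean": torch.randn(batch_size, num_channels, num_frames - 1, height, width),
    "indices_latents_clean": torch.ones(num_frames - 1),
    "latents_history_2x": torch.randn(batch_size, num_channels, num_frames - 1, height, width),
    "indices_latents_history_2x": torch.ones(num_frames - 1),
    "latents_history_4x": torch.randn(batch_size, num_channels, (num_frames - 1) * 4, height, width),
    "indices_latents_history_4x": torch.ones((num_frames - 1) * 4),
}

with torch.no_grad():
    out = model(**inputs)

# diffusers models usually return a BaseOutput with a `.sample` field;
# fall back to tuple indexing if this version returns a plain tuple.
sample = out.sample if hasattr(out, "sample") else out[0]
print(sample.shape)  # the test's output_shape suggests (1, 4, 3, 4, 4)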