@@ -80,6 +80,7 @@ def prepare_init_args_and_inputs_for_common(self):
8080            "text_embed_dim" : 16 ,
8181            "pooled_projection_dim" : 8 ,
8282            "rope_axes_dim" : (2 , 4 , 4 ),
83+             "image_condition_type" : None ,
8384        }
8485        inputs_dict  =  self .dummy_input 
8586        return  init_dict , inputs_dict 
@@ -144,6 +145,7 @@ def prepare_init_args_and_inputs_for_common(self):
144145            "text_embed_dim" : 16 ,
145146            "pooled_projection_dim" : 8 ,
146147            "rope_axes_dim" : (2 , 4 , 4 ),
148+             "image_condition_type" : None ,
147149        }
148150        inputs_dict  =  self .dummy_input 
149151        return  init_dict , inputs_dict 
@@ -209,6 +211,75 @@ def prepare_init_args_and_inputs_for_common(self):
209211            "text_embed_dim" : 16 ,
210212            "pooled_projection_dim" : 8 ,
211213            "rope_axes_dim" : (2 , 4 , 4 ),
214+             "image_condition_type" : "latent_concat" ,
215+         }
216+         inputs_dict  =  self .dummy_input 
217+         return  init_dict , inputs_dict 
218+ 
219+     def  test_output (self ):
220+         super ().test_output (expected_output_shape = (1 , * self .output_shape ))
221+ 
222+     def  test_gradient_checkpointing_is_applied (self ):
223+         expected_set  =  {"HunyuanVideoTransformer3DModel" }
224+         super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
225+ 
226+ 
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
    """Common model tests for ``HunyuanVideoTransformer3DModel`` configured with
    ``image_condition_type="token_replace"`` (image-to-video variant).

    Uses deliberately tiny shapes so the shared ``ModelTesterMixin`` checks stay fast.
    """

    model_class = HunyuanVideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        # Minimal dimensions for a single-sample forward pass.
        batch, channels, frames, height, width = 1, 2, 1, 16, 16
        text_dim = 16
        pooled_dim = 8
        seq_len = 12

        # NOTE: values are generated in the same order as the keys below, so the
        # RNG draw sequence matches other dummy-input builders in this file.
        return {
            "hidden_states": torch.randn((batch, channels, frames, height, width)).to(torch_device),
            "timestep": torch.randint(0, 1000, size=(batch,)).to(torch_device),
            "encoder_hidden_states": torch.randn((batch, seq_len, text_dim)).to(torch_device),
            "pooled_projections": torch.randn((batch, pooled_dim)).to(torch_device),
            "encoder_attention_mask": torch.ones((batch, seq_len)).to(torch_device),
            "guidance": torch.randint(0, 1000, size=(batch,)).to(torch_device, dtype=torch.float32),
        }

    @property
    def input_shape(self):
        # (channels, frames, height, width) as expected by the shared tests.
        return (8, 1, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        """Return a tiny model config plus matching dummy inputs for the common tests."""
        init_dict = dict(
            in_channels=2,
            out_channels=4,
            num_attention_heads=2,
            attention_head_dim=10,
            num_layers=1,
            num_single_layers=1,
            num_refiner_layers=1,
            patch_size=1,
            patch_size_t=1,
            guidance_embeds=True,
            text_embed_dim=16,
            pooled_projection_dim=8,
            rope_axes_dim=(2, 4, 4),
            image_condition_type="token_replace",
        )
        return init_dict, self.dummy_input