@@ -154,3 +154,68 @@ def test_output(self):
154154    def  test_gradient_checkpointing_is_applied (self ):
155155        expected_set  =  {"HunyuanVideoTransformer3DModel" }
156156        super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
157+ 
158+ 
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
    """Model-level tests for the image-to-video configuration of
    ``HunyuanVideoTransformer3DModel``.

    Uses a tiny model config (1 layer, small dims) so forward passes stay fast.
    """

    model_class = HunyuanVideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        """Return a minimal set of forward-pass inputs placed on `torch_device`."""
        batch = 1
        # 2 * 4 + 1 = 9 channels; matches `in_channels` in the init dict below.
        channels = 2 * 4 + 1
        frames, height, width = 1, 16, 16
        text_dim = 16
        pooled_dim = 8
        seq_len = 12

        return {
            "hidden_states": torch.randn((batch, channels, frames, height, width)).to(torch_device),
            "timestep": torch.randint(0, 1000, size=(batch,)).to(torch_device),
            "encoder_hidden_states": torch.randn((batch, seq_len, text_dim)).to(torch_device),
            "pooled_projections": torch.randn((batch, pooled_dim)).to(torch_device),
            "encoder_attention_mask": torch.ones((batch, seq_len)).to(torch_device),
        }

    @property
    def input_shape(self):
        # (channels, frames, height, width) — batch dim excluded.
        # NOTE(review): 8 differs from the 9 input channels fed in dummy_input;
        # presumably only used for shape bookkeeping by the mixin — confirm.
        return (8, 1, 16, 16)

    @property
    def output_shape(self):
        # Matches `out_channels` below; batch dim excluded.
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        """Return a tiny model config together with matching dummy inputs."""
        init_dict = {
            "in_channels": 2 * 4 + 1,
            "out_channels": 4,
            "num_attention_heads": 2,
            "attention_head_dim": 10,
            "num_layers": 1,
            "num_single_layers": 1,
            "num_refiner_layers": 1,
            "patch_size": 1,
            "patch_size_t": 1,
            "guidance_embeds": False,
            "text_embed_dim": 16,
            "pooled_projection_dim": 8,
            "rope_axes_dim": (2, 4, 4),
        }
        return init_dict, self.dummy_input

    def test_output(self):
        # Prepend the batch dimension to the per-sample output shape.
        expected = (1, *self.output_shape)
        super().test_output(expected_output_shape=expected)

    def test_gradient_checkpointing_is_applied(self):
        """Gradient checkpointing should be applied to the top-level transformer module."""
        super().test_gradient_checkpointing_is_applied(
            expected_set={"HunyuanVideoTransformer3DModel"}
        )
0 commit comments