@@ -43,10 +43,14 @@ def get_autoencoder_kl_hunyuan_video_config(self):
4343 "down_block_types" : (
4444 "HunyuanVideoDownBlock3D" ,
4545 "HunyuanVideoDownBlock3D" ,
46+ "HunyuanVideoDownBlock3D" ,
47+ "HunyuanVideoDownBlock3D" ,
4648 ),
4749 "up_block_types" : (
4850 "HunyuanVideoUpBlock3D" ,
4951 "HunyuanVideoUpBlock3D" ,
52+ "HunyuanVideoUpBlock3D" ,
53+ "HunyuanVideoUpBlock3D" ,
5054 ),
5155 "block_out_channels" : (8 , 8 , 8 , 8 ),
5256 "layers_per_block" : 1 ,
@@ -154,6 +158,27 @@ def test_gradient_checkpointing_is_applied(self):
154158 }
155159 super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
156160
# Overridden: the base-class version of this test does not account for the
# number of entries in `down_block_types`, so it would pass a
# `block_out_channels` tuple of the wrong length for this 4-block model.
def test_forward_with_norm_groups(self):
    """Forward pass with a custom ``norm_num_groups``; output must match input shape."""
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    # One channel entry per down/up block (this config has four of each).
    init_dict["norm_num_groups"] = 16
    init_dict["block_out_channels"] = (16, 16, 16, 16)

    model = self.model_class(**init_dict)
    model.to(torch_device)
    model.eval()

    with torch.no_grad():
        result = model(**inputs_dict)

    # Model may return an output dataclass/dict; unwrap to the sample tensor.
    if isinstance(result, dict):
        result = result.to_tuple()[0]

    self.assertIsNotNone(result)
    expected_shape = inputs_dict["sample"].shape
    self.assertEqual(result.shape, expected_shape, "Input and output shapes do not match")
181+
157182 @unittest .skip ("Unsupported test." )
158183 def test_outputs_equivalence (self ):
159184 pass
0 commit comments