Skip to content

Commit 1be67d4

Browse files
committed
fix test
1 parent 02c3a19 commit 1be67d4

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,14 @@ def get_autoencoder_kl_hunyuan_video_config(self):
4343
"down_block_types": (
4444
"HunyuanVideoDownBlock3D",
4545
"HunyuanVideoDownBlock3D",
46+
"HunyuanVideoDownBlock3D",
47+
"HunyuanVideoDownBlock3D",
4648
),
4749
"up_block_types": (
4850
"HunyuanVideoUpBlock3D",
4951
"HunyuanVideoUpBlock3D",
52+
"HunyuanVideoUpBlock3D",
53+
"HunyuanVideoUpBlock3D",
5054
),
5155
"block_out_channels": (8, 8, 8, 8),
5256
"layers_per_block": 1,
@@ -154,6 +158,27 @@ def test_gradient_checkpointing_is_applied(self):
154158
}
155159
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
156160

161+
# We need to overwrite this test because the base test does not account length of down_block_types
162+
def test_forward_with_norm_groups(self):
163+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
164+
165+
init_dict["norm_num_groups"] = 16
166+
init_dict["block_out_channels"] = (16, 16, 16, 16)
167+
168+
model = self.model_class(**init_dict)
169+
model.to(torch_device)
170+
model.eval()
171+
172+
with torch.no_grad():
173+
output = model(**inputs_dict)
174+
175+
if isinstance(output, dict):
176+
output = output.to_tuple()[0]
177+
178+
self.assertIsNotNone(output)
179+
expected_shape = inputs_dict["sample"].shape
180+
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
181+
157182
@unittest.skip("Unsupported test.")
158183
def test_outputs_equivalence(self):
159184
pass

0 commit comments

Comments
 (0)