Skip to content

Commit e13231c

Browse files
committed
add model tests
1 parent e978876 commit e13231c

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

tests/models/transformers/test_models_transformer_hunyuan_video.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,68 @@ def test_output(self):
154154
def test_gradient_checkpointing_is_applied(self):
    """Check that gradient checkpointing gets applied to the transformer module."""
    # The mixin asserts that checkpointing was enabled on exactly this class.
    super().test_gradient_checkpointing_is_applied(
        expected_set={"HunyuanVideoTransformer3DModel"}
    )
157+
158+
159+
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
    """Shared model tests for the image-to-video HunyuanVideo 3D transformer.

    Configures a tiny model (1 layer, 2 heads) so the mixin's generic tests
    run quickly on random inputs.
    """

    model_class = HunyuanVideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        """Return a minimal batch of random forward-pass inputs on the test device."""
        batch = 1
        # presumably 2 * 4 latent channels plus 1 conditioning/mask channel
        # for the image-to-video variant — confirm against the pipeline.
        channels = 2 * 4 + 1
        frames = 1
        height, width = 16, 16
        text_dim = 16
        pooled_dim = 8
        seq_len = 12

        # Tensor creation order matches the original so RNG draws are identical.
        return {
            "hidden_states": torch.randn((batch, channels, frames, height, width)).to(torch_device),
            "timestep": torch.randint(0, 1000, size=(batch,)).to(torch_device),
            "encoder_hidden_states": torch.randn((batch, seq_len, text_dim)).to(torch_device),
            "pooled_projections": torch.randn((batch, pooled_dim)).to(torch_device),
            "encoder_attention_mask": torch.ones((batch, seq_len)).to(torch_device),
        }

    @property
    def input_shape(self):
        # NOTE(review): dummy_input uses 2 * 4 + 1 = 9 channels while this
        # reports 8 — verify whether the mixin actually consumes this value.
        return (8, 1, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        """Return (model init kwargs, forward inputs) for the mixin's generic tests."""
        init_dict = dict(
            in_channels=2 * 4 + 1,
            out_channels=4,
            num_attention_heads=2,
            attention_head_dim=10,
            num_layers=1,
            num_single_layers=1,
            num_refiner_layers=1,
            patch_size=1,
            patch_size_t=1,
            guidance_embeds=False,
            text_embed_dim=16,
            pooled_projection_dim=8,
            rope_axes_dim=(2, 4, 4),
        )
        return init_dict, self.dummy_input

    def test_output(self):
        """Output shape should be (batch,) + output_shape."""
        super().test_output(expected_output_shape=(1, *self.output_shape))

    def test_gradient_checkpointing_is_applied(self):
        """Gradient checkpointing must be enabled on the transformer class."""
        super().test_gradient_checkpointing_is_applied(
            expected_set={"HunyuanVideoTransformer3DModel"}
        )

0 commit comments

Comments
 (0)