Commit 83c08e4

Commit message: tests
1 parent 3613b23 commit 83c08e4

File tree

1 file changed: +67 -0 lines changed


tests/models/transformers/test_models_transformer_hunyuan_video.py

Lines changed: 67 additions & 0 deletions
@@ -87,3 +87,70 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HunyuanVideoTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+    model_class = HunyuanVideoTransformer3DModel
+    main_input_name = "hidden_states"
+    uses_custom_attn_processor = True
+
+    @property
+    def dummy_input(self):
+        batch_size = 1
+        num_channels = 8
+        num_frames = 1
+        height = 16
+        width = 16
+        text_encoder_embedding_dim = 16
+        pooled_projection_dim = 8
+        sequence_length = 12
+
+        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
+        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
+        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
+
+        return {
+            "hidden_states": hidden_states,
+            "timestep": timestep,
+            "encoder_hidden_states": encoder_hidden_states,
+            "pooled_projections": pooled_projections,
+            "encoder_attention_mask": encoder_attention_mask,
+            "guidance": guidance,
+        }
+
+    @property
+    def input_shape(self):
+        return (8, 1, 16, 16)
+
+    @property
+    def output_shape(self):
+        return (4, 1, 16, 16)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "in_channels": 8,
+            "out_channels": 4,
+            "num_attention_heads": 2,
+            "attention_head_dim": 10,
+            "num_layers": 1,
+            "num_single_layers": 1,
+            "num_refiner_layers": 1,
+            "patch_size": 1,
+            "patch_size_t": 1,
+            "guidance_embeds": True,
+            "text_embed_dim": 16,
+            "pooled_projection_dim": 8,
+            "rope_axes_dim": (2, 4, 4),
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_output(self):
+        super().test_output(expected_output_shape=(1, *self.output_shape))
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"HunyuanVideoTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
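For orientation, below is a minimal standalone sketch of the forward pass that these hooks feed into ModelTesterMixin's common tests. The config and input values are copied verbatim from the diff above; the forward_smoke_test wrapper and the assumption that the model output exposes a .sample tensor (as other diffusers transformer models do) are not part of this commit.

import torch

from diffusers import HunyuanVideoTransformer3DModel

# Mirror of the test utility's device selection (assumption: CPU fallback is acceptable here).
torch_device = "cuda" if torch.cuda.is_available() else "cpu"


def forward_smoke_test():
    # Tiny config copied from prepare_init_args_and_inputs_for_common above.
    model = HunyuanVideoTransformer3DModel(
        in_channels=8,
        out_channels=4,
        num_attention_heads=2,
        attention_head_dim=10,
        num_layers=1,
        num_single_layers=1,
        num_refiner_layers=1,
        patch_size=1,
        patch_size_t=1,
        guidance_embeds=True,
        text_embed_dim=16,
        pooled_projection_dim=8,
        rope_axes_dim=(2, 4, 4),
    ).to(torch_device)

    # Inputs copied from dummy_input above.
    batch_size, sequence_length = 1, 12
    inputs = {
        "hidden_states": torch.randn((batch_size, 8, 1, 16, 16)).to(torch_device),
        "timestep": torch.randint(0, 1000, size=(batch_size,)).to(torch_device),
        "encoder_hidden_states": torch.randn((batch_size, sequence_length, 16)).to(torch_device),
        "pooled_projections": torch.randn((batch_size, 8)).to(torch_device),
        "encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
        "guidance": torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32),
    }

    with torch.no_grad():
        # Assumption: the output object carries a .sample tensor, as with
        # other diffusers transformer models.
        output = model(**inputs).sample

    # test_output asserts the batched output_shape, i.e. (1, *output_shape) == (1, 4, 1, 16, 16).
    assert output.shape == (batch_size, 4, 1, 16, 16)

In a diffusers checkout, the new tests themselves would typically be run with pytest tests/models/transformers/test_models_transformer_hunyuan_video.py.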
