@@ -87,3 +87,70 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HunyuanVideoTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+    model_class = HunyuanVideoTransformer3DModel
+    main_input_name = "hidden_states"
+    uses_custom_attn_processor = True
+
+    @property
+    def dummy_input(self):
+        batch_size = 1
+        num_channels = 8
+        num_frames = 1
+        height = 16
+        width = 16
+        text_encoder_embedding_dim = 16
+        pooled_projection_dim = 8
+        sequence_length = 12
+
+        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+        pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
+        encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
+        guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
+
+        return {
+            "hidden_states": hidden_states,
+            "timestep": timestep,
+            "encoder_hidden_states": encoder_hidden_states,
+            "pooled_projections": pooled_projections,
+            "encoder_attention_mask": encoder_attention_mask,
+            "guidance": guidance,
+        }
+
+    @property
+    def input_shape(self):
+        return (8, 1, 16, 16)
+
+    @property
+    def output_shape(self):
+        return (4, 1, 16, 16)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "in_channels": 8,
+            "out_channels": 4,
+            "num_attention_heads": 2,
+            "attention_head_dim": 10,
+            "num_layers": 1,
+            "num_single_layers": 1,
+            "num_refiner_layers": 1,
+            "patch_size": 1,
+            "patch_size_t": 1,
+            "guidance_embeds": True,
+            "text_embed_dim": 16,
+            "pooled_projection_dim": 8,
+            "rope_axes_dim": (2, 4, 4),
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_output(self):
+        super().test_output(expected_output_shape=(1, *self.output_shape))
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"HunyuanVideoTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)