@@ -205,12 +205,14 @@ def test_layered_model_with_mask(self):
205205 init_dict = {
206206 "patch_size" : 2 ,
207207 "in_channels" : 16 ,
208- "out_channels" : 16 ,
208+ "out_channels" : 4 ,
209209 "num_layers" : 2 ,
210- "attention_head_dim" : 128 ,
211- "num_attention_heads" : 4 ,
210+ "attention_head_dim" : 16 ,
211+ "num_attention_heads" : 3 ,
212212 "joint_attention_dim" : 16 ,
213+ "axes_dims_rope" : (8 , 4 , 4 ), # Must match attention_head_dim (8+4+4=16)
213214 "use_layer3d_rope" : True , # Enable layered RoPE
215+ "use_additional_t_cond" : True , # Enable additional time conditioning
214216 }
215217
216218 model = self .model_class (** init_dict ).to (torch_device )
@@ -236,6 +238,9 @@ def test_layered_model_with_mask(self):
236238
237239 timestep = torch .tensor ([1.0 ]).to (torch_device )
238240
241+ # additional_t_cond input, used with use_additional_t_cond=True: an integer index (0 or 1) into the embedding
242+ addition_t_cond = torch .tensor ([0 ], dtype = torch .long ).to (torch_device )
243+
239244 # Layer structure: 4 layers + 1 condition image
240245 img_shapes = [
241246 [
@@ -254,6 +259,7 @@ def test_layered_model_with_mask(self):
254259 encoder_hidden_states_mask = encoder_hidden_states_mask ,
255260 timestep = timestep ,
256261 img_shapes = img_shapes ,
262+ additional_t_cond = addition_t_cond ,
257263 )
258264
259265 self .assertEqual (output .sample .shape [1 ], hidden_states .shape [1 ])
0 commit comments