Skip to content

Commit fc93747

Browse files
committed
smaller values
1 parent 8de799c commit fc93747

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

tests/models/transformers/test_models_transformer_qwenimage.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -205,12 +205,14 @@ def test_layered_model_with_mask(self):
205205
init_dict = {
206206
"patch_size": 2,
207207
"in_channels": 16,
208-
"out_channels": 16,
208+
"out_channels": 4,
209209
"num_layers": 2,
210-
"attention_head_dim": 128,
211-
"num_attention_heads": 4,
210+
"attention_head_dim": 16,
211+
"num_attention_heads": 3,
212212
"joint_attention_dim": 16,
213+
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
213214
"use_layer3d_rope": True, # Enable layered RoPE
215+
"use_additional_t_cond": True, # Enable additional time conditioning
214216
}
215217

216218
model = self.model_class(**init_dict).to(torch_device)
@@ -236,6 +238,9 @@ def test_layered_model_with_mask(self):
236238

237239
timestep = torch.tensor([1.0]).to(torch_device)
238240

241+
# additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
242+
addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
243+
239244
# Layer structure: 4 layers + 1 condition image
240245
img_shapes = [
241246
[
@@ -254,6 +259,7 @@ def test_layered_model_with_mask(self):
254259
encoder_hidden_states_mask=encoder_hidden_states_mask,
255260
timestep=timestep,
256261
img_shapes=img_shapes,
262+
additional_t_cond=addition_t_cond,
257263
)
258264

259265
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])

0 commit comments

Comments (0)