Skip to content

Commit f706457

Browse files
committed
Further adjust test
1 parent 5171e23 commit f706457

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/integration/model_bridge/test_bridge_integration.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,10 @@ def capture_pattern_hook(tensor, hook):
318318
len(pattern_tensor.shape) == 4
319319
), f"Pattern tensor should be 4D, got {len(pattern_tensor.shape)}D"
320320

321-
n_heads_dim, pos_q_dim, pos_k_dim = pattern_tensor.shape
321+
batch_dim, n_heads_dim, pos_q_dim, pos_k_dim = pattern_tensor.shape
322+
323+
# Verify the batch dimension is 1
324+
assert batch_dim == 1, f"Batch dimension should be 1, got {batch_dim}"
322325

323326
# Verify dimensions make sense
324327
assert (

0 commit comments

Comments
 (0)