|
4 | 4 |
|
class TestGQA:
    """Grouped-query attention checks for ``nnx.dot_product_attention``.

    Both tests call the function with a query whose head count differs from
    the key/value head count.
    """

    def test_gqa_broadcasting(self):
        """Output keeps the query head count when K/V have fewer heads."""
        B, T, S = 2, 4, 5
        D = 8

        # GQA Config: Query=6 heads, Key/Value=3 heads (Ratio=2)
        num_heads_q = 6
        num_heads_kv = 3

        k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
        query = jax.random.normal(k1, (B, T, num_heads_q, D))
        key = jax.random.normal(k2, (B, S, num_heads_kv, D))
        value = jax.random.normal(k3, (B, S, num_heads_kv, D))

        output = nnx.dot_product_attention(query, key, value)

        # The head axis of the result must follow the query (6), not K/V (3).
        assert output.shape == (B, T, num_heads_q, D)

    def test_gqa_invalid_heads(self):
        """A query head count that is not a multiple of K/V heads is rejected.

        Expects a ``ValueError`` whose message contains ``"must be multiple"``.
        """
        B, T, D = 1, 4, 8
        query = jnp.ones((B, T, 5, D))  # 5 query heads
        key = jnp.ones((B, T, 2, D))  # 2 key/value heads; 5 % 2 != 0
        value = key

        try:
            nnx.dot_product_attention(query, key, value)
        except ValueError as e:
            assert "must be multiple" in str(e)
        else:
            # Raised explicitly rather than via ``assert False`` so the
            # "no exception" case still fails when asserts are stripped (-O).
            raise AssertionError("Should have raised ValueError")
0 commit comments