Skip to content

Commit 35eaf6f

Browse files
committed
style: clean up comments
1 parent 819c390 commit 35eaf6f

File tree

2 files changed

+5 -11 lines changed

2 files changed

+5 -11 lines changed

flax/nnx/nn/attention.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ def dot_product_attention(
239239
query, key, value = promote_dtype((query, key, value), dtype=dtype) # type: ignore[bad-unpacking]
240240
dtype = query.dtype
241241

242-
# GQA: Broadcast value heads to match query heads if needed.
243-
# 1. Handle Key Broadcasting
242+
# broadcast value heads to match query heads if needed.
243+
# handle key broadcasting
244244
if query.ndim == key.ndim and query.shape[-2] != key.shape[-2]:
245245
q_heads = query.shape[-2]
246246
k_heads = key.shape[-2]
@@ -249,7 +249,7 @@ def dot_product_attention(
249249
n_rep = q_heads // k_heads
250250
key = jnp.repeat(key, n_rep, axis=-2)
251251

252-
# 2. Handle Value Broadcasting
252+
# handle value broadcasting
253253
if query.ndim == value.ndim and query.shape[-2] != value.shape[-2]:
254254
q_heads = query.shape[-2]
255255
v_heads = value.shape[-2]

tests/nnx/nn/gqa_test.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,30 @@
44

55
class TestGQA:
66
def test_gqa_broadcasting(self):
7-
# 1. Define Shapes
87
B, T, S = 2, 4, 5
98
D = 8
109

1110
# GQA Config: Query=6 heads, Key/Value=3 heads (Ratio=2)
1211
num_heads_q = 6
1312
num_heads_kv = 3
1413

15-
# 2. Create Inputs
1614
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
1715
query = jax.random.normal(k1, (B, T, num_heads_q, D))
1816
key = jax.random.normal(k2, (B, S, num_heads_kv, D))
1917
value = jax.random.normal(k3, (B, S, num_heads_kv, D))
2018

21-
# 3. Run Attention (Should not crash)
2219
output = nnx.dot_product_attention(query, key, value)
2320

24-
# 4. Verify Output Shape matches Query heads (6), not Key heads (3)
2521
assert output.shape == (B, T, num_heads_q, D)
2622

2723
def test_gqa_invalid_heads(self):
28-
# Test that it raises an error if heads aren't divisible
2924
B, T, D = 1, 4, 8
30-
query = jnp.ones((B, T, 5, D)) # 5 heads
31-
key = jnp.ones((B, T, 2, D)) # 2 heads (5 is not divisible by 2)
25+
query = jnp.ones((B, T, 5, D))
26+
key = jnp.ones((B, T, 2, D))
3227
value = key
3328

3429
try:
3530
nnx.dot_product_attention(query, key, value)
3631
assert False, "Should have raised ValueError"
3732
except ValueError as e:
38-
# Adjusted to match the actual error message in attention.py
3933
assert "must be multiple" in str(e)

0 commit comments

Comments (0)