71 changes: 59 additions & 12 deletions flax/nnx/nn/attention.py
@@ -109,16 +109,43 @@ def dot_product_attention_weights(

assert query.ndim == key.ndim, 'q, k must have same rank.'
assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

# check if we need to broadcast Key heads to match Query heads
is_gqa = False
if query.shape[-2] != key.shape[-2]:
q_heads = query.shape[-2]
k_heads = key.shape[-2]

if q_heads % k_heads != 0:
raise ValueError(
  f"Query heads ({q_heads}) must be a multiple of "
  f"Key heads ({k_heads}) for Grouped Query Attention."
)

n_rep = q_heads // k_heads
is_gqa = True
# Reshape Query: [..., Q, H_k * n_rep, D] -> [..., Q, H_k, n_rep, D]
query = query.reshape(
  query.shape[:-2] + (k_heads, n_rep, query.shape[-1])
)

# Contract depth, sharing each Key head across its group of Query heads:
# query [..., Q, H_k, G, D] x key [..., K, H_k, D] -> [..., H_k, G, Q, K]
einsum_str = '...qhgd,...khd->...hgqk'
else:
q_heads = query.shape[-2]
einsum_str = '...qhd,...khd->...hqk'

# calculate attention matrix
depth = query.shape[-1]
query = query / jnp.sqrt(depth).astype(dtype)

# attn weight shape is (batch..., num_heads, q_length, kv_length)
attn_weights = jnp.einsum(
'...qhd,...khd->...hqk', query, key, precision=precision
)
attn_weights = jnp.einsum(einsum_str, query, key, precision=precision)

if is_gqa:
  # Flatten the (H_k, n_rep) head axes back into H_q query heads.
  attn_weights = attn_weights.reshape(
    attn_weights.shape[:-4]
    + (q_heads, attn_weights.shape[-2], attn_weights.shape[-1])
  )

# apply attention bias: masking, dropout, proximity bias, etc.
if bias is not None:
@@ -145,9 +172,12 @@ def dot_product_attention_weights(
# apply attention dropout
if not deterministic and dropout_rate > 0.0:
keep_prob = 1.0 - dropout_rate
if broadcast_dropout:
# dropout is broadcast across the batch + head dimensions
dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) # type: ignore
else:
keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) # type: ignore
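
For readers following the hunk above: the grouped einsum over a reshaped query is equivalent to explicitly repeating the key heads and using the original ungrouped contraction. A minimal standalone sketch of that equivalence (toy shapes; the names B, Q, K, D, H_q, H_kv, G are mine, not from the PR):

import jax
import jax.numpy as jnp

# 8 query heads share 2 key heads, i.e. groups of G = 4 query heads per key head.
B, Q, K, D = 2, 3, 5, 4
H_q, H_kv = 8, 2
G = H_q // H_kv

kq, kk = jax.random.split(jax.random.key(0))
query = jax.random.normal(kq, (B, Q, H_q, D))
key = jax.random.normal(kk, (B, K, H_kv, D))

# Grouped path: view query heads as (H_kv, G) and contract against shared key heads.
q_grouped = query.reshape(B, Q, H_kv, G, D)
w_grouped = jnp.einsum('...qhgd,...khd->...hgqk', q_grouped, key)  # [B, H_kv, G, Q, K]
weights = w_grouped.reshape(B, H_q, Q, K)

# Reference path: repeat each key head G times, then use the ungrouped einsum.
key_rep = jnp.repeat(key, G, axis=-2)  # [B, K, H_q, D]
weights_ref = jnp.einsum('...qhd,...khd->...hqk', query, key_rep)

print(weights.shape)                                  # (2, 8, 3, 5)
print(jnp.allclose(weights, weights_ref, atol=1e-5))  # True
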
@@ -224,17 +254,15 @@ def dot_product_attention(
"""
query, key, value = promote_dtype((query, key, value), dtype=dtype) # type: ignore[bad-unpacking]
dtype = query.dtype

assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
assert (
query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
), 'q, k, v batch dims must match.'
assert (
query.shape[-2] == key.shape[-2] == value.shape[-2]
), 'q, k, v num_heads must match.'
assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'

# Criteria that invoke the more optimized dot product attention
if dropout_rate == 0.0 and module == None:
if dropout_rate == 0.0 and module is None:
# make sure qkv batch are compressed to one dim
query_shape = query.shape
if len(query_shape) > 4:
@@ -267,9 +295,28 @@ def reshape_4d(x):
)

# return weighted sum over values for each query position
return jnp.einsum(
'...hqk,...khd->...qhd', attn_weights, value, precision=precision
)
# check if we need to broadcast Value heads to match Query heads (GQA)
if attn_weights.shape[-3] != value.shape[-2]:
q_heads = attn_weights.shape[-3]
v_heads = value.shape[-2]
if q_heads % v_heads != 0:
  raise ValueError(
    f"Query heads ({q_heads}) must be a multiple of "
    f"Value heads ({v_heads}) for Grouped Query Attention."
  )

n_rep = q_heads // v_heads
# Reshape weights: [..., H_q, Q, K] -> [..., H_v, n_rep, Q, K]
attn_weights = attn_weights.reshape(
  attn_weights.shape[:-3] + (v_heads, n_rep) + attn_weights.shape[-2:]
)
# Contract keys, sharing each Value head across its group of Query heads:
# weights [..., H_v, G, Q, K] x value [..., K, H_v, D] -> [..., Q, H_v, G, D]
out = jnp.einsum(
  '...hgqk,...khd->...qhgd', attn_weights, value, precision=precision
)
# Flatten heads back: [..., Q, H_v, n_rep, D] -> [..., Q, H_q, D]
out = out.reshape(out.shape[:-3] + (q_heads, out.shape[-1]))
else:
out = jnp.einsum(
'...hqk,...khd->...qhd', attn_weights, value, precision=precision
)

return out


class MultiHeadAttention(Module):
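
Taken together, the two functions above compute the attention weights on grouped heads and then contract the values with the same grouping. A minimal end-to-end sketch (my own toy shapes and variable names) checking that this grouped formulation agrees with jax.nn.dot_product_attention, which supports GQA natively and is what the parity test below compares against:

import jax
import jax.numpy as jnp

B, T, S, D = 2, 4, 6, 8
H_q, H_kv = 4, 2
G = H_q // H_kv

k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
query = jax.random.normal(k1, (B, T, H_q, D))
key = jax.random.normal(k2, (B, S, H_kv, D))
value = jax.random.normal(k3, (B, S, H_kv, D))

# Weights: scale, group query heads as (H_kv, G), contract depth against shared key heads.
q_grouped = (query / jnp.sqrt(D)).reshape(B, T, H_kv, G, D)
logits = jnp.einsum('...qhgd,...khd->...hgqk', q_grouped, key)  # [B, H_kv, G, T, S]
weights = jax.nn.softmax(logits, axis=-1)

# Output: contract keys, sharing each value head across its group, then flatten heads.
out = jnp.einsum('...hgqk,...khd->...qhgd', weights, value)  # [B, T, H_kv, G, D]
out = out.reshape(B, T, H_q, D)

ref = jax.nn.dot_product_attention(query, key, value)
print(jnp.allclose(out, ref, atol=1e-4))  # True
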
55 changes: 55 additions & 0 deletions tests/nnx/nn/attention_test.py
@@ -298,6 +298,61 @@ def test_varying_num_features(self):

self.assertIsNotNone(layer(x, y))

class TestGQADotProductAttention(parameterized.TestCase):

def test_gqa_shapes(self):
B, T, S = 2, 4, 5
D = 8
num_heads_q = 6
num_heads_kv = 3

k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
query = jax.random.normal(k1, (B, T, num_heads_q, D))
key = jax.random.normal(k2, (B, S, num_heads_kv, D))
value = jax.random.normal(k3, (B, S, num_heads_kv, D))

output = nnx.dot_product_attention(query, key, value)
expected_shape = (B, T, num_heads_q, D)
self.assertEqual(output.shape, expected_shape)

def test_gqa_invalid_heads(self):
B, T, D = 1, 4, 8
query = jnp.ones((B, T, 5, D))
key = jnp.ones((B, T, 2, D))
value = key

with self.assertRaisesRegex(ValueError, "must be a multiple"):
nnx.dot_product_attention(query, key, value)

def test_gqa_parity_with_jax(self):
class DummyModule(nnx.Module):
pass

dummy_module = DummyModule()

B, T, S, D = 2, 8, 8, 16
num_heads_q = 4
num_heads_kv = 2

rng = jax.random.key(42)
k1, k2, k3 = jax.random.split(rng, 3)

query = jax.random.normal(k1, (B, T, num_heads_q, D))
key = jax.random.normal(k2, (B, S, num_heads_kv, D))
value = jax.random.normal(k3, (B, S, num_heads_kv, D))

jax_out = jax.nn.dot_product_attention(query, key, value)

# NNX should handle broadcasting internally
nnx_out = nnx.dot_product_attention(
query, key, value,
module=dummy_module
)

np.testing.assert_allclose(nnx_out, jax_out, atol=1e-3, rtol=1e-3)


if __name__ == '__main__':
absltest.main()
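
Beyond the parity checks above, the grouped path composes with the existing mask argument unchanged, because the weights are flattened back to the full set of query heads before the mask is applied. A usage sketch (my own shapes; not part of the PR's tests) with GQA-shaped inputs and a causal mask:

import jax
import jax.numpy as jnp
from flax import nnx

B, T, D = 2, 6, 16
H_q, H_kv = 8, 2

k1, k2, k3 = jax.random.split(jax.random.key(1), 3)
query = jax.random.normal(k1, (B, T, H_q, D))
key = jax.random.normal(k2, (B, T, H_kv, D))
value = jax.random.normal(k3, (B, T, H_kv, D))

# Boolean causal mask, broadcast over batch and query heads: [1, 1, T, T].
mask = jnp.tril(jnp.ones((T, T), dtype=bool))[None, None, :, :]

out = nnx.dot_product_attention(query, key, value, mask=mask)
print(out.shape)  # (2, 6, 8, 16)
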

