Problem
PR #5180 adds grouped-query attention (GQA) support to the functional nnx.dot_product_attention.
However, the nnx.MultiHeadAttention module still enforces symmetric head counts, hard-coding the key/value projections to num_heads:
https://github.com/google/flax/blob/main/flax/nnx/nn/attention.py#L380
This makes it impossible to define GQA layers (e.g., Llama 3 with 32 query heads and 8 KV heads) using the standard Module API.
Note: This implementation depends on the functional changes in PR #5180 and should be addressed after that PR is merged.
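For context, a minimal sketch of the shapes the functional path is expected to accept once PR #5180 lands. The GQA broadcasting semantics and the [batch, length, heads, head_dim] shape convention are assumptions based on the existing functional API, not confirmed against the merged PR:

import jax.numpy as jnp
from flax import nnx

batch, seq_len, head_dim = 2, 16, 64
num_heads, num_kv_heads = 32, 8  # Llama-3-style query/KV split

q = jnp.ones((batch, seq_len, num_heads, head_dim))
k = jnp.ones((batch, seq_len, num_kv_heads, head_dim))
v = jnp.ones((batch, seq_len, num_kv_heads, head_dim))

# With PR #5180, each group of num_heads // num_kv_heads query heads
# would attend over the same shared key/value head.
out = nnx.dot_product_attention(q, k, v)
assert out.shape == (batch, seq_len, num_heads, head_dim)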
Proposed Solution
Update MultiHeadAttention to accept num_key_value_heads.
1. API Changes (__init__)
Add num_key_value_heads, defaulting to None for backward compatibility (when unset it falls back to num_heads).
class MultiHeadAttention(Module):
  def __init__(
    self,
    num_heads: int,
    in_features: int,
    qkv_features: int | None = None,
    num_key_value_heads: int | None = None,  # new argument
    # ... existing args
  ):
    self.num_heads = num_heads
    # Fall back to standard MHA when the new argument is not given.
    self.num_key_value_heads = (
      num_key_value_heads if num_key_value_heads is not None else num_heads
    )
    # ... feature validation ...

    # GQA validation: query heads must split evenly into KV groups.
    if self.num_heads % self.num_key_value_heads != 0:
      raise ValueError(
        f"num_heads ({self.num_heads}) must be divisible by "
        f"num_key_value_heads ({self.num_key_value_heads})"
      )

    head_dim = self.qkv_features // self.num_heads

    # Projections
    self.query = Linear(in_features, self.num_heads * head_dim, ...)
    # Key/value projections use the reduced head count.
    self.key = Linear(in_features, self.num_key_value_heads * head_dim, ...)
    self.value = Linear(in_features, self.num_key_value_heads * head_dim, ...)
    # Output projection aggregates all query heads (unchanged).
    self.out = Linear(self.num_heads * head_dim, out_features, ...)
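Hypothetical usage under the proposed API (the num_key_value_heads argument and its default are taken from the sketch above and are not merged behavior):

from flax import nnx

layer = nnx.MultiHeadAttention(
    num_heads=32,
    in_features=4096,
    qkv_features=4096,
    num_key_value_heads=8,  # proposed new argument
    rngs=nnx.Rngs(0),
)
# With head_dim = 4096 // 32 = 128, the key/value projections would each
# produce 8 * 128 = 1024 features, while the query and output projections
# keep the full 32 * 128 = 4096.

Because the default falls back to num_heads, existing checkpoints and call sites that never pass num_key_value_heads keep their current parameter shapes.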