feat(nnx): Expose GQA support (num_key_value_heads) in MultiHeadAttention #5198

@ayulockedin

Problem

PR #5180 adds grouped-query attention (GQA) support to the functional nnx.dot_product_attention.
However, nnx.MultiHeadAttention still forces queries, keys, and values to share one head count, because the key/value projections are hardcoded to num_heads:

https://github.com/google/flax/blob/main/flax/nnx/nn/attention.py#L380

This makes it impossible to define GQA layers (e.g., Llama 3 with 32 query heads and 8 KV heads) using the standard Module API.
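
To make the asymmetry concrete, here is a small sketch of the projection widths a GQA layer needs. The helper name `projection_out_features` is hypothetical and is not part of Flax, and `head_dim=128` is an assumed, illustrative value; only the 32/8 head split comes from the example above.

```python
def projection_out_features(num_heads, num_key_value_heads, head_dim):
  """Return the flattened output width of the Q, K, and V projections.

  Hypothetical helper for illustration only; not a Flax API.
  """
  if num_heads % num_key_value_heads != 0:
    raise ValueError(
      f"num_heads ({num_heads}) must be divisible by "
      f"num_key_value_heads ({num_key_value_heads})"
    )
  return {
    "query": num_heads * head_dim,
    "key": num_key_value_heads * head_dim,
    "value": num_key_value_heads * head_dim,
    # How many query heads share each key/value head.
    "queries_per_kv_head": num_heads // num_key_value_heads,
  }


# head_dim=128 is an assumption chosen for illustration.
gqa = projection_out_features(num_heads=32, num_key_value_heads=8, head_dim=128)
mha = projection_out_features(num_heads=32, num_key_value_heads=32, head_dim=128)
print(gqa)
print(mha)
```

With these numbers the key and value projections are a quarter of the query projection's width, a ratio the current module cannot express because all three projections use num_heads.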

Note: This implementation depends on the functional changes in PR #5180 and should be addressed after that PR is merged.

Proposed Solution

Update MultiHeadAttention to accept num_key_value_heads.

1. API Changes (__init__)

Add num_key_value_heads (defaulting to None for backward compatibility).

class MultiHeadAttention(Module):
  def __init__(
    self,
    num_heads: int,
    in_features: int,
    qkv_features: int | None = None,
    num_key_value_heads: int | None = None,  # New Arg
    # ... existing args
  ):
    self.num_heads = num_heads
    self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_heads

    # ... feature validation ...
    
    # GQA Validation
    if self.num_heads % self.num_key_value_heads != 0:
      raise ValueError(
        f"num_heads ({self.num_heads}) must be divisible by "
        f"num_key_value_heads ({self.num_key_value_heads})"
      )

    # Per-head width: qkv_features is the total across all query heads.
    head_dim = self.qkv_features // self.num_heads

    # Projections
    self.query = Linear(in_features, self.num_heads * head_dim, ...)
    # Key/Value utilize the reduced head count
    self.key = Linear(in_features, self.num_key_value_heads * head_dim, ...) 
    self.value = Linear(in_features, self.num_key_value_heads * head_dim, ...)
    
    # Output projection aggregates all Query heads (unchanged)
    self.out = Linear(self.num_heads * head_dim, out_features, ...)
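
For reference, the computation these projections feed could look roughly like the sketch below. It uses plain NumPy rather than nnx.dot_product_attention, so it does not depend on PR #5180. The function name `grouped_query_attention`, the unbatched `[length, heads, head_dim]` layout, and the choice to let consecutive query heads share a key/value head are all assumptions made for illustration, not the behavior of the Flax implementation.

```python
import numpy as np


def grouped_query_attention(q, k, v):
  """Unmasked GQA sketch (hypothetical helper, not a Flax API).

  q:    [length, num_heads, head_dim]
  k, v: [length, num_key_value_heads, head_dim]
  """
  num_heads, num_kv_heads = q.shape[1], k.shape[1]
  if num_heads % num_kv_heads != 0:
    raise ValueError("num_heads must be divisible by num_key_value_heads")
  group = num_heads // num_kv_heads

  # Assumed grouping: repeat each KV head so that `group` consecutive
  # query heads attend to the same keys and values.
  k = np.repeat(k, group, axis=1)
  v = np.repeat(v, group, axis=1)

  # Scaled dot-product scores, shape [num_heads, q_length, kv_length].
  scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(q.shape[-1])
  # Numerically stable softmax over the key axis.
  scores = scores - scores.max(axis=-1, keepdims=True)
  weights = np.exp(scores)
  weights = weights / weights.sum(axis=-1, keepdims=True)

  # Weighted sum of values, back to [length, num_heads, head_dim].
  return np.einsum("hqk,khd->qhd", weights, v)


rng = np.random.default_rng(0)
q = rng.normal(size=(5, 8, 16))  # 8 query heads
k = rng.normal(size=(5, 2, 16))  # 2 key/value heads
v = rng.normal(size=(5, 2, 16))
out = grouped_query_attention(q, k, v)
print(out.shape)
```

The output keeps one slice per query head, which is why the output projection in the proposal still takes num_heads * head_dim inputs.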
