Feature/efficient attention #1189

Open

divye-joshi wants to merge 3 commits into google-research:main from divye-joshi:feature/efficient-attention
Conversation

@divye-joshi

PR: Add Memory-Efficient Attention Mechanisms

Summary

Adds FlashAttention and LinearAttention to scenic/model_lib/layers/efficient_attention_layers.py. Both implementations remove the $O(N^2)$ memory bottleneck of standard attention for long-sequence tasks such as high-resolution vision and document-level NLP.

| Attention Type | Time | Memory | Use Case |
| --- | --- | --- | --- |
| Standard | $O(N^2)$ | $O(N^2)$ | Short sequences |
| FlashAttention | $O(N^2)$ | $O(N)$ | Memory-constrained (exact) |
| LinearAttention | $O(N)$ | $O(N)$ | Very long sequences (approximate) |

Key Features

  • Drop-in Replacement: API matches nn.MultiHeadDotProductAttention.
  • Purely Additive: Zero changes to existing Scenic core code or dependencies.
  • Performance: Enables training on sequences that previously triggered Out-of-Memory (OOM) errors.

Usage

```python
from scenic.model_lib.layers.efficient_attention_layers import FlashAttention

# Swap nn.MultiHeadDotProductAttention with:
attn = FlashAttention(num_heads=8, qkv_features=512)
```
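
Continuing the snippet above, a hedged end-to-end call. This assumes FlashAttention follows the standard Flax init/apply flow and the nn.MultiHeadDotProductAttention self-attention call signature; the shapes are illustrative:

```python
import jax
import jax.numpy as jnp

# Illustrative input: (batch, seq_len, features).
x = jnp.ones((2, 4096, 512))

# Assuming the usual Flax nn.Module flow and a self-attention call on x.
params = attn.init(jax.random.PRNGKey(0), x)
y = attn.apply(params, x)  # output shape matches the input: (2, 4096, 512)

# LinearAttention is assumed to be constructed the same way:
# attn = LinearAttention(num_heads=8, qkv_features=512)
```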

Implementation Details

  • FlashAttention: Uses the online softmax algorithm to compute exact attention block-wise, so the full $N \times N$ score matrix is never materialized (see the sketch after this list). Falls back to standard attention for sequences smaller than the block size (default: 512).
  • LinearAttention: Implements kernel-based attention with the feature map $\mathrm{elu}(x) + 1$ per Katharopoulos et al. (2020), also sketched after this list.
  • Limitation: FlashAttention currently bypasses dropout for sequences exceeding the block size; this is documented in the docstrings.
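
To make the online-softmax idea concrete, here is a minimal single-query sketch in JAX. It illustrates the technique only, not the PR's kernel (which tiles queries, keys, and values jointly); `scores_blocks` and `values_blocks` are hypothetical pre-split inputs:

```python
import jax.numpy as jnp

def online_softmax_attention(scores_blocks, values_blocks):
    """Streaming softmax(scores) @ values for one query, one block at a
    time, so the full score vector is never held in memory at once."""
    m = -jnp.inf  # running max of scores seen so far
    l = 0.0       # running softmax normalizer
    acc = 0.0     # running weighted sum of values
    for s, v in zip(scores_blocks, values_blocks):  # s: (B,), v: (B, d)
        m_new = jnp.maximum(m, s.max())
        # Rescale the running statistics to the new max, then add the block.
        scale = jnp.exp(m - m_new)
        p = jnp.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v
        m = m_new
    return acc / l  # equals softmax(concat(scores)) @ concat(values)
```

Likewise, a minimal single-head sketch of the LinearAttention computation with the $\mathrm{elu}(x) + 1$ feature map; shapes and names are illustrative, not the PR's API:

```python
import jax
import jax.numpy as jnp

def linear_attention(q, k, v):
    """softmax(QK^T)V replaced by phi(Q)(phi(K)^T V): O(N) in sequence
    length, since phi(K)^T V is only (d, d). q, k, v: (N, d)."""
    phi = lambda x: jax.nn.elu(x) + 1.0  # positive feature map
    q, k = phi(q), phi(k)
    kv = k.T @ v                  # (d, d) summary of all key/value pairs
    z = q @ k.sum(axis=0)         # (N,) per-query normalizer
    return (q @ kv) / z[:, None]  # (N, d)
```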

Testing & Validation

  • Coverage: 30 tests pass, covering numerical correctness (vs. standard attention), gradient flow, masking, and shape consistency.
  • Stability: Verified numerical stability with large score values and fully masked inputs.
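
As a concrete illustration of the correctness coverage, a hypothetical test: FlashAttention is exact, so its output should match flax.linen's standard attention under the same parameters (this assumes the drop-in claim extends to identical parameter trees):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from scenic.model_lib.layers.efficient_attention_layers import FlashAttention

x = jax.random.normal(jax.random.PRNGKey(0), (2, 128, 64))
flash = FlashAttention(num_heads=4, qkv_features=64)
ref = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=64)

params = flash.init(jax.random.PRNGKey(1), x)
# Reusing the parameters for the reference module assumes matching trees.
assert jnp.allclose(flash.apply(params, x), ref.apply(params, x), atol=1e-5)
```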

Checklist

  • Code follows project style
  • Comprehensive docstrings added
  • All 30 unit tests passing
  • No new external dependencies
