
Implement Selective Attention for Memory-Efficient Inference #12817

@cdutr

Description


Is your feature request related to a problem? Please describe.

Yes. It is hard to generate longer videos or process high-resolution sequences with video diffusion models (CogVideoX, Mochi, LTX, Hunyuan Video, Wan 2.2). These models frequently run out of memory on consumer GPUs because attention memory scales quadratically with sequence length. Current workarounds (sliced attention, CPU offloading) sacrifice too much speed for their memory savings.

Describe the solution you'd like.

Implement Selective Attention (ICLR 2025) as a new attention processor; the paper reports a 16-47X reduction in attention memory with no loss in quality.

Describe alternatives you've considered.

I've considered sliced attention (too slow), xFormers (CUDA-only), Flash Attention 3 (Hopper-only), and CPU offloading (extremely slow).

Additional context.

A detailed proposal follows below.


Summary

Implement Selective Attention (ICLR 2025), a parameter-free attention mechanism that achieves 16-47X memory reduction during inference with equivalent performance to standard attention. This would particularly benefit video diffusion models (CogVideoX, Mochi, LTX, Hunyuan Video, Wan) which process long sequences.

Motivation

Why Selective Attention?

  1. Parameter-free: No model retraining is required, so it can be applied to existing checkpoints
  2. Massive memory savings: 16X, 25X, and 47X less attention memory for context sizes of 512, 1024, and 2048, respectively
  3. No quality loss: Matches the perplexity of standard transformers that have ~2X more attention parameters (validated on language-modeling benchmarks in the paper)
  4. Well suited to video generation: Long temporal sequences are the primary memory bottleneck in video diffusion

Current Pain Points in Diffusers

Video generation models are severely memory-constrained:

  • Long sequences (192+ tokens for multi-image inputs, thousands for video frames)
  • Multi-head attention scales quadratically with sequence length
  • Current solutions (sliced attention, CPU offloading) sacrifice speed for memory
  • Video models like Wan 2.2 frequently OOM on consumer GPUs (#12011, #12613)

Proposed Implementation

Architecture Integration

Add a SelectiveAttnProcessor2_0 class following the existing attention processor patterns (e.g., AttnProcessor2_0):

from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention


class SelectiveAttnProcessor2_0:
    """
    Selective attention processor using PyTorch 2.0's SDPA.

    Based on "Selective Attention Improves Transformer" (ICLR 2025)
    https://arxiv.org/abs/2410.02703

    Sketch only: ``selection_threshold`` is read here as the fraction of
    key/value tokens to keep (an assumption, not the paper's exact rule).
    """

    def __init__(self, selection_threshold: float = 0.1):
        self.selection_threshold = selection_threshold

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        # Project and split heads: (batch, heads, seq_len, head_dim).
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        head_dim = key.shape[-1] // attn.heads
        query, key, value = (t.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) for t in (query, key, value))
        if attention_mask is None:
            # 1-2. Selection scores: a lightweight preview that scores keys against mean-pooled queries, keeping only the top fraction.
            scores = torch.einsum("bhd,bhkd->bhk", query.mean(dim=2), key).mean(dim=1)
            idx = scores.topk(max(1, int(key.shape[2] * self.selection_threshold)), dim=-1).indices
            idx = idx[:, None, :, None].expand(-1, attn.heads, -1, head_dim)
            key, value = key.gather(2, idx), value.gather(2, idx)
        # 3-4. Full attention over the selected keys/values, then merge heads and project out (an explicit mask bypasses selection in this sketch).
        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        return attn.to_out[1](attn.to_out[0](hidden_states))
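
For illustration, a hypothetical usage sketch: set_attn_processor is the standard diffusers API for swapping attention processors, while the checkpoint and threshold value below are placeholders.

import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
# Replace every attention layer's processor with the selective variant.
pipe.transformer.set_attn_processor(SelectiveAttnProcessor2_0(selection_threshold=0.25))
pipe.to("cuda")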

Expected Benefits

Memory Reduction Examples

For a video model processing 192 frames:

  • Standard attention: ~73,728 elements in the attention matrix
  • With selective attention (top 25%): ~18,432 elements
  • Memory saved: ~75% reduction in attention buffer

For Wan 2.2 14B processing a 1024-token sequence:

  • Standard attention: 1,048,576 elements per head
  • With selective attention: 65,536-262,144 elements (depending on threshold)
  • Enables: Longer videos on consumer GPUs
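
These per-head counts follow from keeping only a fraction of the key tokens. A back-of-envelope check (the keep fractions below are inferred from the quoted numbers, not taken from the paper):

def attn_matrix_elements(seq_len: int, keep_fraction: float = 1.0) -> int:
    # Elements in one head's attention matrix when a fraction of keys is kept.
    return seq_len * int(seq_len * keep_fraction)

print(attn_matrix_elements(1024))          # 1,048,576 (full attention)
print(attn_matrix_elements(1024, 0.0625))  # 65,536  (keep ~6% of keys)
print(attn_matrix_elements(1024, 0.25))    # 262,144 (keep 25% of keys)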

The paper demonstrates no quality loss on language modeling tasks. Comprehensive benchmarking on diffusion models (FID, LPIPS, visual quality) would be part of the implementation to validate that these results transfer to vision generation.

Use Cases

  1. Longer video generation: Generate 10-30s videos instead of 5-10s on the same hardware
  2. Higher resolution: Process more frames at higher resolution
  3. Batch size increase: Generate more samples in parallel
  4. Mobile/edge deployment: Run video models on resource-constrained devices
