[AutoDeploy]: MLA optimizations for DS-R1 #8233

@nzmora-nvidia

Description

🚀 The feature, motivation and pitch

  • Torch operators: MLA MHA-mode (no weight absorption), with and without cache.
  • FlashInfer operators: MLA MQA-mode (weight absorption) with cache, for decode and mixed decode+prefill (flattened). Uses flashinfer.mla.BatchMLAPagedAttentionWrapper and flashinfer.append_paged_mla_kv_cache.
  • Weight-absorption optimizer pass.
  • Patch for DeepSeek-R1.
  • PyTorch and FlashInfer MLA backends.
  • CUDA graph support for the FlashInfer MLA operator.
  • All correctness tests pass.
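To illustrate why the MQA-mode (weight-absorption) path is equivalent to the MHA-mode path, here is a minimal PyTorch sketch of the attention-score computation. All names and shapes (W_UK, c_kv, the head/latent dimensions) are illustrative assumptions, not the PR's actual code, and the RoPE (k_pe) component is omitted: instead of up-projecting the cached latent c_kv to per-head keys with W_UK, the same projection can be absorbed into the query, so attention runs directly against the shared latent in MQA fashion.

```python
import torch

torch.manual_seed(0)
n_heads, d_head, d_c, seq = 4, 16, 32, 8

# Hypothetical MLA pieces (no-RoPE part only):
W_UK = torch.randn(n_heads, d_head, d_c)  # up-projects latent to per-head keys
c_kv = torch.randn(seq, d_c)              # compressed KV latent (what the cache stores)
q = torch.randn(n_heads, d_head)          # per-head query for one decode token

# MHA mode: materialize per-head keys from the latent, then score.
k = torch.einsum("hdc,sc->hsd", W_UK, c_kv)       # (n_heads, seq, d_head)
scores_mha = torch.einsum("hd,hsd->hs", q, k)

# MQA mode: absorb W_UK into the query; every head then attends
# against the single shared latent c_kv.
q_abs = torch.einsum("hd,hdc->hc", q, W_UK)       # (n_heads, d_c)
scores_mqa = torch.einsum("hc,sc->hs", q_abs, c_kv)

# Both modes produce identical attention scores.
assert torch.allclose(scores_mha, scores_mqa, atol=1e-5)
```

The MQA-mode trade-off: the cache holds only the latent (d_c per token instead of n_heads * d_head), at the cost of the extra query projection per step.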

In a future task:

  • FlashInfer MLA MHA-mode (no weight absorption) with cache; ragged; prefill-only. Uses flashinfer.BatchPrefillWithRaggedKVCacheWrapper.
    This kernel performs well in prefill-only use cases.
    To support mixed decode+prefill we need to:
    1. Compute the new ckv + k_pe and append them to the (paged) cache.
    2. Read from the cache and write into a new ragged layout, since the paged cache has "holes" and the kernel consumes only kv_indptr, without per-sequence lengths.
    3. Compute the output.
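Step 2 above amounts to gathering each sequence's valid cache entries into one contiguous buffer and building the kv_indptr offsets the ragged kernel expects. A rough PyTorch sketch, with a made-up paged layout (page_table, page_size, and the tensor shapes are assumptions for illustration, not the actual cache format):

```python
import torch

# Hypothetical paged cache: fixed-size pages, where the last page of each
# sequence may be partially filled (the "holes" mentioned above).
page_size, d_c = 4, 8
seq_lens = torch.tensor([6, 3])                      # two sequences
page_table = [torch.tensor([0, 1]), torch.tensor([2])]  # pages per sequence
paged_cache = torch.randn(3, page_size, d_c)         # (num_pages, page_size, d_c)

# Gather each sequence's valid entries into one contiguous ragged buffer.
chunks = []
for pages, n in zip(page_table, seq_lens.tolist()):
    flat = paged_cache[pages].reshape(-1, d_c)  # concat this sequence's pages
    chunks.append(flat[:n])                     # drop trailing holes
ragged = torch.cat(chunks)                      # (sum(seq_lens), d_c)

# kv_indptr marks each sequence's start offset in the ragged buffer;
# a kernel that takes only indptr recovers lengths as indptr[i+1] - indptr[i].
kv_indptr = torch.cat([torch.zeros(1, dtype=torch.long), seq_lens.cumsum(0)])
assert kv_indptr.tolist() == [0, 6, 9]
```

This extra gather is the cost of reusing the ragged prefill kernel in a mixed batch; the paged MQA-mode path above avoids it by reading the paged cache directly.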



Metadata

Assignees

Labels

AutoDeploy<NV> AutoDeploy Backend

Projects

Status

Backlog

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
