
[Feature]: Enable MLA prefill-only operator #8424

@nzmora-nvidia

Description


🚀 The feature, motivation and pitch

FlashInfer MLA in MHA mode (no weight absorption), with a KV cache and ragged (variable-length, packed) inputs, for prefill only. Uses flashinfer.BatchPrefillWithRaggedKVCacheWrapper.
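In the ragged layout, all prefill tokens of a batch are packed into one buffer of shape (total_tokens, num_heads, head_dim), and sequence boundaries are described by a CSR-style offset array. FlashInfer's ragged prefill wrapper takes such arrays (qo_indptr and kv_indptr) when it is planned. The sketch below shows how these offsets are built and used. It assumes a prefill with no previously cached tokens, so the query and key offsets coincide. build_indptr and split_ragged are hypothetical helpers, not FlashInfer APIs.

```python
from itertools import accumulate


def build_indptr(seq_lens):
    """Prefix-sum sequence lengths into an offset array of length batch_size + 1."""
    return [0] + list(accumulate(seq_lens))


def split_ragged(packed, indptr):
    """Recover per-sequence rows from a packed (total_tokens, ...) buffer."""
    return [packed[indptr[i]:indptr[i + 1]] for i in range(len(indptr) - 1)]


# Three prefill requests of lengths 3, 5 and 2 packed into one buffer of 10 tokens.
seq_lens = [3, 5, 2]
qo_indptr = build_indptr(seq_lens)  # [0, 3, 8, 10]
kv_indptr = qo_indptr               # prefill-only, no cached prefix (assumption)
tokens = list(range(sum(seq_lens)))  # stand-in for the rows of Q/K/V
per_seq = split_ragged(tokens, qo_indptr)
```

Sequence i occupies rows indptr[i] to indptr[i + 1] - 1 of the packed buffer. No padding is needed, which is why the ragged form suits prefill batches with mixed prompt lengths.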

This kernel is optimized for prefill.
Its inputs are full-head Q/K/V, so it performs MLA in MHA mode (no absorption).
The operator should compute ckv (the compressed KV latent) and kpe (the decoupled RoPE key component), write them to the compressed cache, and then proceed with MHA.
This can probably be done with other FlashInfer kernels.
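The following pure-Python reference covers the intended computation and could serve as a correctness oracle for the kernel. It first computes ckv and kpe from the hidden states and appends them to a compressed cache. It then up-projects ckv into full per-head K and V (MHA mode, no absorption) and runs causal attention within each ragged sequence. The sketch makes several assumptions:

- The weight names (w_dkv, w_kr, w_uk, w_uv) follow the DeepSeek-style MLA formulation, not any specific TRT-LLM or FlashInfer API.
- RoPE on q and kpe is omitted.
- The cache is a plain dict standing in for the real compressed cache.
- There is no previously cached prefix.

All function names are hypothetical.

```python
import math


def matvec(x, w):
    """Row vector x ([in]) times matrix w ([in][out]) -> [out]."""
    return [sum(xi * w[i][j] for i, xi in enumerate(x)) for j in range(len(w[0]))]


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    e = [math.exp(s - m) for s in scores]
    z = sum(e)
    return [v / z for v in e]


def mla_prefill_mha(x, q, indptr, w_dkv, w_kr, w_uk, w_uv, cache):
    """Hypothetical reference for MLA prefill in MHA mode.

    x:      [T][d_model] packed hidden states of all sequences
    q:      [T][H][d_nope + d_rope] full-head queries (RoPE omitted)
    indptr: [B + 1] ragged sequence offsets
    w_dkv:  [d_model][d_c] down-projection to the compressed latent (ckv)
    w_kr:   [d_model][d_rope] projection to the positional key (kpe)
    w_uk:   [H][d_c][d_nope] per-head key up-projection
    w_uv:   [H][d_c][d_v] per-head value up-projection
    cache:  dict with "ckv" and "kpe" lists (stand-in for the compressed cache)
    Returns [T][H][d_v].
    """
    # 1) Compute ckv and kpe and write them to the compressed cache.
    ckv = [matvec(xt, w_dkv) for xt in x]
    kpe = [matvec(xt, w_kr) for xt in x]
    cache["ckv"].extend(ckv)
    cache["kpe"].extend(kpe)

    # 2) MHA mode: expand the latent into full per-head K and V (no absorption).
    #    K_h = [ckv @ W_UK[h], kpe]; kpe is shared by all heads.
    num_heads = len(w_uk)
    k = [[matvec(c, w_uk[h]) + p for h in range(num_heads)] for c, p in zip(ckv, kpe)]
    v = [[matvec(c, w_uv[h]) for h in range(num_heads)] for c in ckv]

    scale = 1.0 / math.sqrt(len(q[0][0]))  # 1 / sqrt(qk head dim)
    out = []
    # 3) Causal attention independently within each ragged sequence.
    for b in range(len(indptr) - 1):
        start, end = indptr[b], indptr[b + 1]
        for t in range(start, end):
            keys = range(start, t + 1)  # tokens up to t in the same sequence
            heads = []
            for h in range(num_heads):
                scores = [scale * sum(a * c for a, c in zip(q[t][h], k[j][h])) for j in keys]
                probs = softmax(scores)
                d_v = len(v[t][h])
                heads.append([sum(p * v[j][h][d] for p, j in zip(probs, keys)) for d in range(d_v)])
            out.append(heads)
    return out
```

Materializing full-head K/V from ckv costs extra memory and FLOPs compared with the absorbed form. During prefill, however, many query tokens share each key, so running a standard MHA prefill kernel on the expanded tensors is typically the more efficient choice.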

Alternatives

No response

Additional context

No response


Metadata

Assignees

Labels

AutoDeploy<NV>: AutoDeploy Backend
Customized kernels<NV>: Specialized/modified CUDA kernels in TRTLLM for LLM ops, beyond standard TRT. Dev & perf.

Type

No type

Projects

Status

Backlog

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
