Sink attention AoT #1427
base: main
Conversation
Summary of Changes
Hello @nandor, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've implemented support for Ahead-of-Time (AoT) compilation of 'sink attention' mechanisms within the FlashInfer library. This involved a significant refactoring of the JIT module generation to utilize iterators for improved efficiency, the addition of specific C++ kernel declarations to define the attention sink logic, and the integration of these new modules into the build and testing processes. The changes enable more specialized and optimized attention computations, particularly for handling long sequences in models.
Highlights
- Refactored JIT Module Generation to Use Iterators: I've refactored the JIT module generation process by changing the `gen_fa2`, `gen_fa3`, and `gen_attention` functions to return `Iterator[JitSpec]` instead of `List[JitSpec]`. This change improves memory efficiency and potentially speeds up AoT compilation by processing JIT specifications one by one (see the sketch after this list).
- Introduced Dedicated Attention Sink Kernels: I've introduced new C++ kernel declarations for the `AttentionSink` struct, specifically tailored for both FA2 and FA3 backends. These declarations define the core logic for how 'sink attention' is handled during prefill operations at a low level.
- Integrated New Attention Sink Modules for AoT Compilation: I've implemented `gen_batch_prefill_attention_sink_module` and its corresponding URI generation. This new module allows for the Ahead-of-Time (AoT) compilation of batch prefill operations that incorporate the attention sink mechanism, enabling more specialized and optimized attention computations.
- Updated Module Generation Logic for Attention Sink: The `gen_attention` function has been updated to selectively generate attention sink modules for 'OAI OSS' configurations. This leverages the newly introduced `gen_batch_prefill_attention_sink_module` to ensure that attention sink capabilities are included where needed.
- Centralized C++ Kernel Declarations and Cleaned Up Test Code: I've cleaned up the `test_attention_sink.py` file by removing redundant C++ kernel declarations. These declarations are now centrally defined in `flashinfer/jit/attention/variants.py`, promoting code reusability and maintainability.
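A minimal sketch of the iterator-based pattern described in the first highlight. Only the function name and the `Iterator[JitSpec]` return type come from the PR summary; the `JitSpec` fields, config shape, and URI format here are assumptions for illustration:

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class JitSpec:
    """Stand-in for flashinfer's real JitSpec; shown only to keep the
    sketch self-contained."""
    uri: str


def gen_attention(configs: list[dict]) -> Iterator[JitSpec]:
    """Yield specs lazily instead of materializing a full list, so AoT
    compilation can start on the first spec while later ones are produced."""
    for cfg in configs:
        yield JitSpec(uri=f"batch_prefill_{cfg['backend']}")
        # Per the summary, sink-attention modules are generated only for the
        # configurations that need them (e.g. the 'OAI OSS' models).
        if cfg.get("use_attention_sink"):
            yield JitSpec(uri=f"batch_prefill_attention_sink_{cfg['backend']}")


# Callers that still need a list can simply materialize the iterator:
specs = list(gen_attention([{"backend": "fa2", "use_attention_sink": True}]))
```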
Code Review
This pull request introduces Ahead-of-Time (AoT) compilation support for sink attention. The changes include refactoring the AoT generation code to use iterators for better memory efficiency, which is a positive improvement. The core logic for sink attention is encapsulated in a new `variants.py` file and exposed through new JIT helper functions, which is well-structured. However, I've identified a type hint error in one of the refactored generator functions. Additionally, the accompanying test file `test_attention_sink.py` has some issues: its test coverage has been significantly reduced, and there's a bug in how it initializes the JIT wrapper, which will prevent it from using the pre-warmed kernels. These issues in the test suite should be addressed to ensure the new functionality is robust and correctly validated.
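For context, this is the kind of return-type mismatch being flagged: a function whose body uses `yield` is a generator and should be annotated `Iterator[...]`, not `List[...]`. The functions below are illustrative only; the actual function in the PR may differ:

```python
from typing import Iterator, List


def gen_specs_wrong(names: List[str]) -> List[str]:
    # The body yields, so this is a generator function and the
    # List[str] annotation is incorrect (mypy will flag it).
    for name in names:
        yield name


def gen_specs_right(names: List[str]) -> Iterator[str]:
    # Same body, correctly annotated as returning an iterator.
    for name in names:
        yield name
```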
Force-pushed from `f5c6ada` to `02657dc`.
```python
    jit_specs.append(gen_vllm_comm_module())
    jit_specs.append(gen_nvshmem_module())
except ImportError:
    pass
```
I guess we should raise here: if `add_comm` is set, the caller has explicitly requested the comm modules, so a missing dependency should not be silently ignored (see the sketch below).
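A sketch of the suggested behavior. The generator functions are the ones from the diff hunk above; the import path, surrounding function shape, and error message are assumptions:

```python
# Import path is an assumption; the generators appear in the diff above.
from flashinfer.jit import gen_nvshmem_module, gen_vllm_comm_module


def gen_comm_specs(add_comm: bool) -> list:
    jit_specs = []
    if add_comm:
        try:
            jit_specs.append(gen_vllm_comm_module())
            jit_specs.append(gen_nvshmem_module())
        except ImportError as e:
            # add_comm means the caller explicitly requested these modules,
            # so surface the missing dependency instead of skipping silently.
            raise RuntimeError(
                "add_comm=True but a required comm dependency is missing"
            ) from e
    return jit_specs
```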
```python
    jit_specs.append(gen_trtllm_comm_module())
    jit_specs.append(gen_vllm_comm_module())
except ImportError:
    pass
```
Same here.
A reasonable approach might be to split `add_comm` into two flags, one per backend (see the sketch below).
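A sketch of that split. The function name, signature, and import path are hypothetical; only the two generator functions come from the diffs above:

```python
# Import path is an assumption; the generators appear in the diffs above.
from flashinfer.jit import gen_trtllm_comm_module, gen_vllm_comm_module


def gen_comm_specs(add_trtllm_comm: bool = False,
                   add_vllm_comm: bool = False) -> list:
    # One flag per backend, so an ImportError can be raised (or handled)
    # for exactly the backend the caller requested.
    jit_specs = []
    if add_trtllm_comm:
        jit_specs.append(gen_trtllm_comm_module())
    if add_vllm_comm:
        jit_specs.append(gen_vllm_comm_module())
    return jit_specs
```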
Description

Related Issues

Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

Pre-commit Checks

- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes