[FA]:Optimize FlashAttention for N_CTX <= 512 #2600
The code change LGTM.
Actually, I'm still working on figuring it out. It's puzzling.
Did you get the idea from XeTLA, i.e., is the implementation closer to XeTLA now?
It came mainly from many trials guided by profiling data. XeTLA uses a different arrangement that is suitable for all shapes.
This change (a `grid` order adjustment to improve the cache hit rate) originates from #2600. It applies to batched GEMM only and reaches ~99% of XeTLA performance for `4096x8x128x16384`.
The reason is: we use |

This change improves cache behavior for N_CTX=512. We get a 20%+ performance gain from it, but it may hurt performance for large N_CTX, so the change is restricted to N_CTX <= 512. CI data
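
To illustrate the idea of gating a grid-order change on sequence length, here is a minimal sketch. It is not the code from this PR: the function name `select_grid`, the `block_m` parameter, and the two specific orderings are assumptions chosen only to show the shape of the restriction to N_CTX <= 512.

```python
def select_grid(n_ctx, batch, heads, block_m, threshold=512):
    """Pick a launch-grid ordering based on sequence length.

    Hypothetical helper: the real kernel's grid layout may differ.
    The only point shown is that the alternative ordering is applied
    when n_ctx <= threshold and the original ordering is kept otherwise.
    """
    # Number of query tiles along the sequence dimension (ceiling division).
    num_m_blocks = (n_ctx + block_m - 1) // block_m
    if n_ctx <= threshold:
        # Assumed "short sequence" ordering: batch*heads on the first axis.
        return (batch * heads, num_m_blocks, 1)
    # Assumed default ordering for long sequences, left unchanged.
    return (num_m_blocks, batch * heads, 1)
```

Both orderings launch the same number of programs; only the mapping from program ID to tile changes, which is what can affect cache reuse between neighboring programs.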
