Conversation

@quintinwang5 (Contributor) commented Oct 31, 2024

This change improves the cache behavior for N_CTX=512: we get a 20%+ performance gain from it, but it may be harmful for large N_CTX, so the change is restricted to N_CTX <= 512. CI data:

@etiotto etiotto changed the title [FA]:Optimize FlalshAttention for N_CTX <= 512 [FA]:Optimize FlashAttention for N_CTX <= 512 Oct 31, 2024
@chengjunlu (Contributor)

The code change LGTM.
Can you explain the difference in memory footprint between the old version and the new one?

@quintinwang5 (Contributor, Author)

> The code change LGTM. Can you explain the difference in memory footprint between the old version and the new one?

Actually, I'm still working on figuring it out. It's puzzling.

@quintinwang5 quintinwang5 merged commit b8fc4b9 into main Nov 1, 2024
5 checks passed
@quintinwang5 quintinwang5 deleted the quintin/perf_n_ctx_512 branch November 1, 2024 01:12
@whitneywhtsang (Contributor)

Did you get the idea from XeTLA, i.e., is the implementation closer to XeTLA now?

@quintinwang5 (Contributor, Author)

> Did you get the idea from XeTLA, i.e., is the implementation closer to XeTLA now?

It mainly came from many experiments guided by profiling data. XeTLA uses a different arrangement that is suitable for all the shapes.

yudongsi added a commit that referenced this pull request Nov 11, 2024
This change (a `grid` order adjustment to improve cache hits) originates from #2600.
It applies to batched gemm only.
~99% of XeTLA performance for `4096x8x128x16384`.

![image](https://github.com/user-attachments/assets/ef7e9750-b3f7-4adc-aa66-5be704383e40)
@quintinwang5 (Contributor, Author)

> The code change LGTM. Can you explain the difference in memory footprint between the old version and the new one?

The reason is: we use BLOCK_M to split Q along the Y-axis. All the blocks along that axis share the same K and V, so if the workgroups that process different BLOCK_M tiles are scheduled consecutively, the cached K and V can be reused. Taking (32, 32, 512, 64) as an example with BLOCK_M = 128, N_CTX is split into 4 blocks (4 x 128 = 512). The old nd_range is {(4, 32, 32), (128, 1, 1)}, where those 4 blocks are not consecutive: the stride between them is 32 x 32 workgroups. If we change it to {(32, 32, 4), (128, 1, 1)}, the 4 blocks become consecutive.
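The scheduling effect described above can be sketched numerically. This is not the actual kernel-launch code; the helper names below are illustrative, assuming workgroups are dispatched in linear order with the first nd_range dimension varying slowest:

```python
# Map a workgroup's (dim0, dim1, dim2) coordinates to its linear dispatch
# index, assuming dim0 is the slowest-varying dimension.
def linear_id_old(m, h, z, H=32, Z=32):
    # Old nd_range {(4, 32, 32), ...}: m (the Q block index) is slowest,
    # so blocks sharing the same K/V are 32*32 = 1024 workgroups apart.
    return (m * H + h) * Z + z

def linear_id_new(h, z, m, Z=32, M=4):
    # New nd_range {(32, 32, 4), ...}: m is fastest, so the 4 blocks
    # that reuse the same K/V run back to back.
    return (h * Z + z) * M + m

# The 4 Q blocks for head h=0, batch z=0:
old_ids = [linear_id_old(m, 0, 0) for m in range(4)]
new_ids = [linear_id_new(0, 0, m) for m in range(4)]
# old_ids == [0, 1024, 2048, 3072]; new_ids == [0, 1, 2, 3]
```

With the new ordering, the 4 workgroups that load the same K and V are adjacent in dispatch order, so the K/V tiles loaded by the first are likely still resident in cache when the others run.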




Development

Successfully merging this pull request may close these issues.

[FA] Improve performance of shapes <95% on advanced path - 32x32x512, 4x32x4096, 2x32x8192
