metal : optimize FA kernels #10171
Conversation
Force-pushed from c71e0bc to d0cff71.
Force-pushed from a797e5d to f66d362.
Force-pushed from ff1b4f5 to 5464b08.
This PR should be gucci now.
The performance increase looks about the same with M3 Max.

Note that to test (micro-)batch sizes larger than the default of 2048 with `llama-bench`, `-b` has to be increased as well (e.g. `-b 16384`).
Force-pushed from a49913f to 5d1a10d.
Force-pushed from 59792ff to 1888c1f.
Thanks, I forgot about that. As a data point, running some tests as a function of the micro-batch size (`-ub`):

```
./llama-bench -m ./models/llama-3.2-3b-instruct/ggml-model-f16.gguf -fa 1 -p 1024,2048,4096,8192,16384 -b 16384 -ub 512,1024,2048,4096,8192 -n 0
```

build: 59792ff (4057)

```
./llama-bench -m ./models/qwen2.5-7b-coder/ggml-model-q8_0.gguf -fa 1 -p 1024,2048,4096,8192,16384 -b 16384 -ub 512,1024,2048,4096,8192 -n 0
```

build: 1888c1f (4057)

My guess is that the logic for skipping the computation of attention blocks when the mask is full of -INF in that block is now more efficient. I'm wondering if this optimization could be viable for the CUDA FA as well.
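For reference, the block-skipping idea can be pictured with a small host-side sketch (plain C++, not the actual Metal kernel; the tile sizes and the causal mask are illustrative assumptions):

```cpp
// Illustration of skipping attention tiles whose mask is entirely -INF:
// if every mask value for a (query block, key block) tile is -INF, the
// softmax weights of that tile are exactly zero, so the Q*K^T, softmax and
// V accumulation for it can be skipped without changing the result.
#include <cmath>
#include <cstdio>
#include <vector>

// returns true if the whole (q_block x kv_block) mask tile is -INF
static bool tile_all_neg_inf(const std::vector<float> & mask, int n_kv,
                             int q0, int q_block, int kv0, int kv_block) {
    for (int iq = q0; iq < q0 + q_block; ++iq) {
        for (int ik = kv0; ik < kv0 + kv_block; ++ik) {
            if (mask[iq*n_kv + ik] != -INFINITY) {
                return false; // at least one position in the tile is visible
            }
        }
    }
    return true;
}

int main() {
    const int n_q = 8, n_kv = 8;          // illustrative sizes
    const int q_block = 4, kv_block = 4;  // illustrative tile sizes

    // causal mask: query iq can only attend to kv positions ik <= iq
    std::vector<float> mask(n_q*n_kv, 0.0f);
    for (int iq = 0; iq < n_q; ++iq) {
        for (int ik = iq + 1; ik < n_kv; ++ik) {
            mask[iq*n_kv + ik] = -INFINITY;
        }
    }

    for (int q0 = 0; q0 < n_q; q0 += q_block) {
        for (int kv0 = 0; kv0 < n_kv; kv0 += kv_block) {
            if (tile_all_neg_inf(mask, n_kv, q0, q_block, kv0, kv_block)) {
                printf("tile (q=%d, kv=%d): skipped\n", q0, kv0);
                continue; // no Q*K^T, softmax or V accumulation needed
            }
            printf("tile (q=%d, kv=%d): computed\n", q0, kv0);
        }
    }

    return 0;
}
```

Presumably the kernel performs an analogous check on the mask block it loads and jumps ahead when the whole block is masked out, which is what the comment above refers to.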
Force-pushed from 1888c1f to bc143ec.
* ggml : add ggml_flash_attn_ext_get_prec
* metal : use F16 precision in FA kernels (ggml-ci)
* metal : minor clean-up
* metal : compile-guard bf16 FA kernels (ggml-ci)
* build : remove obsolete compile flag [no ci]
* metal : prevent int overflows [no ci]
* cuda : disable BF16 FA (ggml-ci)
* metal : fix BF16 requirement for FA kernels (ggml-ci)
* make : clean-up [no ci]
tgt #10149
rel #8439
Various optimizations for the FA kernels:

The performance should be noticeably better at larger contexts. The kernels continue to use F32 accumulators for the `Q*K*scale`, so I hope there are no floating-point range issues. Though some extra testing won't hurt.

The original idea of using full `BF16` math in the FA kernels did not produce satisfactory results. I think that `bfloat` performance is not great on Metal yet.
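As a rough illustration of the precision scheme (plain host-side C++, not the actual kernel; the head size, the activation magnitudes and the `__fp16` storage type are assumptions made for the example), the point is that the `Q*K*scale` sum is accumulated in F32 even though the inputs are stored in F16:

```cpp
// Sketch of F16 storage with an F32 accumulator for Q*K*scale: the per-element
// products and the running sum stay in F32, so the intermediate sum is not
// limited by the F16 range (~65504) before the scale brings it back down.
// Assumes a compiler with the __fp16 storage type (e.g. clang on Apple silicon).
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int   head_dim = 128;                          // illustrative head size
    const float scale    = 1.0f/sqrtf((float) head_dim); // attention scale

    // large-ish activations chosen so that the *unscaled* sum (128*32*32 =
    // 131072) would exceed the F16 range, while the scaled score would not
    std::vector<__fp16> q(head_dim, (__fp16) 32.0f);
    std::vector<__fp16> k(head_dim, (__fp16) 32.0f);

    float acc = 0.0f; // F32 accumulator
    for (int i = 0; i < head_dim; ++i) {
        acc += (float) q[i] * (float) k[i];
    }

    const float score = acc*scale; // ~11585, well within range

    printf("unscaled sum = %.1f (exceeds F16 max 65504)\n", acc);
    printf("scaled score = %.1f\n", score);

    return 0;
}
```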
Here are some benches:

Using `llama-batched-bench` to show TG speed after large prompts (`S_TG` column): `master` vs `gg/metal-fa-f16` on M1 Pro.
TODO: