Conversation

@max-krasnyansky (Collaborator) commented Oct 28, 2025

A very simple dynamic Flash Attention chunking scheme that splits the work into n_threads * 4 chunks.

This helps on platforms with a significant performance difference between CPU cores (i.e. big.LITTLE, boosted cores, etc.) and it also helps under heavy CPU load. It is very similar to what the MatMul and MatMul-ID chunking already does.
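
Conceptually the dispatch works like the existing MatMul chunking: the FA rows are cut into n_threads * 4 chunks and each thread keeps claiming the next chunk from a shared atomic counter, so faster (or less loaded) cores naturally end up processing more chunks. Here is a minimal sketch of that idea, with hypothetical names rather than the actual ggml code:

```cpp
#include <algorithm>
#include <atomic>
#include <cstdint>

// stand-in for the factored-out per-chunk FA kernel
// (flash_atten_f16_one_chunk in this PR); the real one takes the op params
static void fa_one_chunk(int64_t ir0, int64_t ir1) {
    (void) ir0; (void) ir1; // ... attention math over rows [ir0, ir1) ...
}

// each worker thread runs this; 'next_chunk' is shared between all threads
static void fa_dynamic_chunks(int64_t nrows, int n_threads, std::atomic<int> & next_chunk) {
    const int     nchunk     = n_threads * 4;                 // oversubscribe 4x
    const int64_t chunk_size = (nrows + nchunk - 1) / nchunk; // rows per chunk

    int chunk;
    while ((chunk = next_chunk.fetch_add(1)) < nchunk) {
        const int64_t ir0 = chunk * chunk_size;
        const int64_t ir1 = std::min(ir0 + chunk_size, nrows);
        if (ir0 >= ir1) {
            break; // nothing left
        }
        fa_one_chunk(ir0, ir1);
    }
}
```

With a static split each thread would get exactly 4 chunks; with the dynamic counter a big core can walk off with 10+ of them, which is exactly what the Gen5 logs below show.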

Flash Attention is a relatively small part of the overall profile, so the end-to-end token rate is not affected that much, but when I run it in isolation I see a nice bump in performance on the Gen5.

## Snapdragon Gen5 Llama 3.2 3B Q4_0 (most Ops except FA disabled)
before
```
llama_perf_context_print: prompt eval time =     258.31 ms /   205 tokens (    1.26 ms per token,   793.61 tokens per second)
llama_perf_context_print:        eval time =     499.05 ms /    63 runs   (    7.92 ms per token,   126.24 tokens per second)
```

after
```
llama_perf_context_print: prompt eval time =     216.11 ms /   205 tokens (    1.05 ms per token,   948.60 tokens per second)
llama_perf_context_print:        eval time =     477.52 ms /    63 runs   (    7.58 ms per token,   131.93 tokens per second)
```

## Snapdragon Gen5 Llama 3.2 1B Q4_0 (most Ops except FA disabled)
before
```
llama_perf_context_print: prompt eval time =     171.04 ms /   205 tokens (    0.83 ms per token,  1198.56 tokens per second)
llama_perf_context_print:        eval time =     290.58 ms /    63 runs   (    4.61 ms per token,   216.81 tokens per second)
```

after
```
llama_perf_context_print: prompt eval time =     164.80 ms /   205 tokens (    0.80 ms per token,  1243.91 tokens per second)
llama_perf_context_print:        eval time =     285.91 ms /    63 runs   (    4.54 ms per token,   220.35 tokens per second)
```

Also tested on the M4 Pro, where I don't see any performance changes on an unloaded system, but a loaded system is a different story.
Here are some more details with additional instrumentation that measures how many chunks each thread processed and how long it took.
You can see how, under load, some threads on the M4 Pro process more chunks in about the same amount of time.
On the Gen5 you can see that one of the cores crunches through many more chunks than the others.
The picture is similar on the Gen4 (8-Elite).
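
The instrumentation itself is conceptually nothing more than a per-thread chunk counter and a timer around the chunk loop, printed once per FA node. A rough sketch of that shape, with hypothetical names (the merged patch does not ship this logging):

```cpp
#include <chrono>
#include <cstdio>

// hypothetical per-thread stats used only for this measurement
struct fa_thread_stats {
    int       proc_chunks = 0; // chunks this thread claimed
    long long proc_usec   = 0; // wall time spent in the chunk loop
};

// claim_chunk() would wrap the atomic fetch_add plus the per-chunk work,
// returning false once no chunks are left
template <typename Fn>
static void fa_chunk_loop_instrumented(int ith, const char * node_name, Fn claim_chunk) {
    fa_thread_stats st;

    const auto t0 = std::chrono::steady_clock::now();
    while (claim_chunk()) {
        st.proc_chunks++;
    }
    const auto t1 = std::chrono::steady_clock::now();

    st.proc_usec = std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count();

    // same shape as the logs below
    fprintf(stderr, "thread-%d: fa %s proc-chunks %d proc-usec %lld\n",
            ith, node_name, st.proc_chunks, st.proc_usec);
}
```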

M4 Pro (GPT-OSS-20B) 6 threads
Under heavy load (compiling llama.cpp with x86-64 android-ndk)
```
thread-3: fa __fattn__-23 proc-chunks 4 proc-usec 3440
thread-4: fa __fattn__-23 proc-chunks 4 proc-usec 3518
thread-1: fa __fattn__-23 proc-chunks 4 proc-usec 3550
thread-0: fa __fattn__-23 proc-chunks 4 proc-usec 3615
thread-2: fa __fattn__-23 proc-chunks 4 proc-usec 3680
thread-5: fa __fattn__-23 proc-chunks 4 proc-usec 3891
thread-5: fa __fattn__-0 proc-chunks 4 proc-usec 3137
thread-0: fa __fattn__-0 proc-chunks 4 proc-usec 3178
thread-3: fa __fattn__-0 proc-chunks 4 proc-usec 3241
thread-4: fa __fattn__-0 proc-chunks 5 proc-usec 3857
thread-1: fa __fattn__-0 proc-chunks 5 proc-usec 3956
thread-2: fa __fattn__-0 proc-chunks 2 proc-usec 4815
thread-3: fa __fattn__-1 proc-chunks 5 proc-usec 4924
thread-5: fa __fattn__-1 proc-chunks 2 proc-usec 5611
thread-4: fa __fattn__-1 proc-chunks 3 proc-usec 5713
thread-2: fa __fattn__-1 proc-chunks 6 proc-usec 5735
thread-1: fa __fattn__-1 proc-chunks 6 proc-usec 5853
thread-0: fa __fattn__-1 proc-chunks 2 proc-usec 6049
thread-0: fa __fattn__-2 proc-chunks 4 proc-usec 3204
thread-4: fa __fattn__-2 proc-chunks 4 proc-usec 3309
thread-5: fa __fattn__-2 proc-chunks 2 proc-usec 3374
thread-2: fa __fattn__-2 proc-chunks 5 proc-usec 3915
thread-3: fa __fattn__-2 proc-chunks 5 proc-usec 3999
thread-1: fa __fattn__-2 proc-chunks 4 proc-usec 5146
thread-5: fa __fattn__-3 proc-chunks 4 proc-usec 3829
thread-2: fa __fattn__-3 proc-chunks 4 proc-usec 3973
thread-3: fa __fattn__-3 proc-chunks 5 proc-usec 4420
thread-4: fa __fattn__-3 proc-chunks 4 proc-usec 4615
thread-0: fa __fattn__-3 proc-chunks 5 proc-usec 4732
thread-1: fa __fattn__-3 proc-chunks 2 proc-usec 4775
```

Snapdragon 8E Gen5 (GPT-OSS-20B) 6 threads
```
thread-4: fa __fattn__-0 proc-chunks 2 proc-usec 4476
thread-0: fa __fattn__-0 proc-chunks 12 proc-usec 4565
thread-2: fa __fattn__-0 proc-chunks 2 proc-usec 4530
thread-5: fa __fattn__-0 proc-chunks 2 proc-usec 4720
thread-1: fa __fattn__-0 proc-chunks 3 proc-usec 6863
thread-3: fa __fattn__-0 proc-chunks 3 proc-usec 7170
thread-3: fa __fattn__-1 proc-chunks 2 proc-usec 5105
thread-0: fa __fattn__-1 proc-chunks 14 proc-usec 5242
thread-1: fa __fattn__-1 proc-chunks 2 proc-usec 5285
thread-4: fa __fattn__-1 proc-chunks 2 proc-usec 5435
thread-2: fa __fattn__-1 proc-chunks 2 proc-usec 5478
thread-5: fa __fattn__-1 proc-chunks 2 proc-usec 5593
thread-1: fa __fattn__-2 proc-chunks 2 proc-usec 4740
thread-0: fa __fattn__-2 proc-chunks 13 proc-usec 4827
thread-5: fa __fattn__-2 proc-chunks 2 proc-usec 4831
thread-4: fa __fattn__-2 proc-chunks 2 proc-usec 4894
thread-2: fa __fattn__-2 proc-chunks 2 proc-usec 5439
thread-3: fa __fattn__-2 proc-chunks 3 proc-usec 7006
thread-2: fa __fattn__-3 proc-chunks 2 proc-usec 3843
thread-5: fa __fattn__-3 proc-chunks 2 proc-usec 4030
thread-0: fa __fattn__-3 proc-chunks 11 proc-usec 4111
thread-4: fa __fattn__-3 proc-chunks 3 proc-usec 5664
thread-1: fa __fattn__-3 proc-chunks 3 proc-usec 5795
thread-3: fa __fattn__-3 proc-chunks 3 proc-usec 5820
```

Galaxy S25+ (Llama 3.2 3B) 6 threads
```
thread-0: fa __fattn__-10 proc-chunks 6 proc-usec 78
thread-5: fa __fattn__-10 proc-chunks 3 proc-usec 80
thread-2: fa __fattn__-10 proc-chunks 3 proc-usec 80
thread-4: fa __fattn__-10 proc-chunks 3 proc-usec 80
thread-1: fa __fattn__-10 proc-chunks 3 proc-usec 80
thread-3: fa __fattn__-10 proc-chunks 6 proc-usec 78
thread-0: fa __fattn__-11 proc-chunks 6 proc-usec 75
thread-5: fa __fattn__-11 proc-chunks 6 proc-usec 75
thread-1: fa __fattn__-11 proc-chunks 3 proc-usec 78
thread-3: fa __fattn__-11 proc-chunks 3 proc-usec 78
thread-2: fa __fattn__-11 proc-chunks 3 proc-usec 78
thread-4: fa __fattn__-11 proc-chunks 3 proc-usec 78
```

I'm going to submit a couple more related PRs:

  • Enabling CPU MatMul-ID chunking on ARM64
  • Introducing very similar chunking, i.e. chunk_size = nrows / (n_threads * 4), for the Repack MatMuls (a rough worked example follows this list).
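
For a rough sense of what that formula yields, here is an illustrative calculation (the values are made up, not taken from the planned PR):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // illustrative values only
    const int64_t nrows     = 4096; // rows handled by the repacked matmul
    const int     n_threads = 6;

    const int64_t chunk_size = nrows / (n_threads * 4); // 4096 / 24 = 170 rows per chunk
    // (the remainder rows would be folded into the last chunk, or the size rounded up)

    printf("chunks=%d chunk_size=%lld\n", n_threads * 4, (long long) chunk_size);
    return 0;
}
```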

Factor out the core FA loop into flash_atten_f16_one_chunk and add an outer loop
on top that handles the chunks.
@max-krasnyansky (Collaborator, Author) commented:
@slaren this one as well along with the matmul chunking

@slaren (Member) commented Oct 30, 2025

On a 13900k this also has an overall positive effect on performance when using a high number of threads.
These tests are with 30 threads.

| Model | Test | t/s master | t/s flashattn-chunking | Speedup |
| --- | --- | --- | --- | --- |
| llama 1B Q4_K_M | pp64@d1024 | 484.66 | 525.76 | 1.08 |

| Backend | GGML op | Op parameters | GFLOPS master | GFLOPS flashattn-chunking | Speedup |
| --- | --- | --- | --- | --- | --- |
| CPU | FLASH_ATTN_EXT | hsk=128,hsv=128,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 63.50 | 57.74 | 0.91 |
| CPU | FLASH_ATTN_EXT | hsk=128,hsv=128,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 69.81 | 66.95 | 0.96 |
| CPU | FLASH_ATTN_EXT | hsk=128,hsv=128,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 67.54 | 67.07 | 0.99 |
| CPU | FLASH_ATTN_EXT | hsk=128,hsv=128,nh=8,nr23=[4,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 96.79 | 167.00 | 1.73 |
| CPU | FLASH_ATTN_EXT | hsk=128,hsv=128,nh=8,nr23=[4,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 132.20 | 238.40 | 1.80 |
| CPU | FLASH_ATTN_EXT | hsk=128,hsv=128,nh=8,nr23=[4,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 129.70 | 231.80 | 1.79 |
| CPU | FLASH_ATTN_EXT | hsk=64,hsv=64,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 55.22 | 53.06 | 0.96 |
| CPU | FLASH_ATTN_EXT | hsk=64,hsv=64,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 54.46 | 50.20 | 0.92 |
| CPU | FLASH_ATTN_EXT | hsk=64,hsv=64,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 54.12 | 51.64 | 0.95 |
| CPU | FLASH_ATTN_EXT | hsk=64,hsv=64,nh=8,nr23=[4,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 103.70 | 187.30 | 1.81 |
| CPU | FLASH_ATTN_EXT | hsk=64,hsv=64,nh=8,nr23=[4,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 106.90 | 183.50 | 1.72 |
| CPU | FLASH_ATTN_EXT | hsk=64,hsv=64,nh=8,nr23=[4,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 101.40 | 182.10 | 1.80 |

@ggerganov (Member) commented:
M4 Max results:

```sh
CMAKE_OPTS="-DGGML_METAL=OFF -DGGML_BLAS=OFF" ./scripts/compare-commits.sh master pr/16829 test-backend-ops -o FLASH_ATTN_EXT -b CPU
```
| Backend | GGML op | Op parameters | GFLOPS master | GFLOPS pr/16829 | Speedup |
| --- | --- | --- | --- | --- | --- |
| CPU | FLASH_ATTN_EXT | hsk=128,hsv=128,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 92.19 | 96.66 | 1.05 |
| CPU | FLASH_ATTN_EXT | hsk=128,hsv=128,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 96.75 | 97.40 | 1.01 |
| CPU | FLASH_ATTN_EXT | hsk=128,hsv=128,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 97.04 | 108.30 | 1.12 |
| CPU | FLASH_ATTN_EXT | hsk=128,hsv=128,nh=8,nr23=[4,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 114.20 | 228.70 | 2.00 |
| CPU | FLASH_ATTN_EXT | hsk=128,hsv=128,nh=8,nr23=[4,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 98.92 | 308.20 | 3.12 |
| CPU | FLASH_ATTN_EXT | hsk=128,hsv=128,nh=8,nr23=[4,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 123.80 | 319.00 | 2.58 |
| CPU | FLASH_ATTN_EXT | hsk=64,hsv=64,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 70.83 | 98.51 | 1.39 |
| CPU | FLASH_ATTN_EXT | hsk=64,hsv=64,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 81.42 | 80.80 | 0.99 |
| CPU | FLASH_ATTN_EXT | hsk=64,hsv=64,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 68.26 | 84.16 | 1.23 |
| CPU | FLASH_ATTN_EXT | hsk=64,hsv=64,nh=8,nr23=[4,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 96.06 | 251.40 | 2.62 |
| CPU | FLASH_ATTN_EXT | hsk=64,hsv=64,nh=8,nr23=[4,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 97.40 | 228.70 | 2.35 |
| CPU | FLASH_ATTN_EXT | hsk=64,hsv=64,nh=8,nr23=[4,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3] | 81.27 | 204.40 | 2.52 |

@ggerganov merged commit dcca0d3 into ggml-org:master Oct 30, 2025
129 of 130 checks passed
@max-krasnyansky deleted the flashattn-chunking branch October 30, 2025 16:19
pockers21 pushed a commit to pockers21/llama.cpp that referenced this pull request Oct 31, 2025
Factor out the core FA loop into flash_atten_f16_one_chunk and add an outer loop
on top that handles the chunks.
Labels: ggml (changes relating to the ggml tensor library for machine learning)