Skip to content

Conversation

@ZelinMa557
Copy link

Modern LLMs (Llama3, qwen 2.5, etc) usually use group query attention, which significantly reduces memory usage caused by KV cache. Group query attention means that query rows of neighbor query heads share kv rows of the same kv head, so we can reorder the loop to:

// python style pseudo code
for group_id in (0,group_num):
    for seq_id in (0, seq_length):
        k = load_k(group_id, seq_id)
        v = load_v(group_id, seq_id)
        for head_id in (group_id * n_gqa, group_id * n_gqa +n_gqa):
              q = load_q(head_id, seq_id)
              compute(q, k, v)

to improve spatial locality of memory access. However the original implemention of cpu flash attention kernel didn't consider that, and this pr improves it.

This is my test command:

./build/bin/llama-cli -t 4 -fa --ctx-size 8192 -m models/Qwen2.5-Coder-7B-Instruct-Q2_K.gguf -f convert_lora_to_gguf.py

The mastrer branch result:

llama_perf_sampler_print:    sampling time =      45.59 ms /  4647 runs   (    0.01 ms per token, 101939.19 tokens per second)
llama_perf_context_print:        load time =     687.54 ms
llama_perf_context_print: prompt eval time =  588053.13 ms /  4412 tokens (  133.28 ms per token,     7.50 tokens per second)
llama_perf_context_print:        eval time =   71929.76 ms /   234 runs   (  307.39 ms per token,     3.25 tokens per second)
llama_perf_context_print:       total time =  660956.03 ms /  4646 tokens
Interrupted by user

With the optimization, the result is:

llama_perf_sampler_print:    sampling time =      51.55 ms /  4688 runs   (    0.01 ms per token, 90949.66 tokens per second)
llama_perf_context_print:        load time =     901.74 ms
llama_perf_context_print: prompt eval time =  531273.62 ms /  4412 tokens (  120.42 ms per token,     8.30 tokens per second)
llama_perf_context_print:        eval time =   63472.48 ms /   275 runs   (  230.81 ms per token,     4.33 tokens per second)
llama_perf_context_print:       total time =  595681.19 ms /  4687 tokens
Interrupted by user

We can see 10% speed up in prefill, and 33% speed up in decode!

Futher work:

  1. flash decoding: in this pr, when n_kv_head < thread num, and there is only one cocurrent request, this cpu kernel cannot use all the threads. we can solve this by using flash decoding.
  2. load balancing between threads: in causual attention, the computation amout is different between rows, but the current implemention dosn't take that into consideration, which slows down the multi-threaded long-context prefill speed.

My test environment:

Architecture:            x86_64
  CPU op-mode(s):        32-bit, 64-bit
  Address sizes:         39 bits physical, 48 bits virtual
  Byte Order:            Little Endian
CPU(s):                  8
  On-line CPU(s) list:   0-7
Vendor ID:               GenuineIntel
  Model name:            Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
    CPU family:          6
    Model:               142
    Thread(s) per core:  2
    Core(s) per socket:  4
    Socket(s):           1
    Stepping:            12
    BogoMIPS:            4607.99
    Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse
                          sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid pni pclm
                         ulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor la
                         hf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase bmi1 avx2
                          smep bmi2 erms invpcid rdseed adx smap clflushopt xsaveopt xsavec xgetbv1 xsaves md_clear f
                         lush_l1d arch_capabilities
Virtualization features: 
  Hypervisor vendor:     Microsoft
  Virtualization type:   full
Caches (sum of all):     
  L1d:                   128 KiB (4 instances)
  L1i:                   128 KiB (4 instances)
  L2:                    1 MiB (4 instances)
  L3:                    8 MiB (1 instance)
Vulnerabilities:         
  Itlb multihit:         KVM: Mitigation: VMX unsupported
  L1tf:                  Not affected
  Mds:                   Not affected
  Meltdown:              Not affected
  Spec store bypass:     Mitigation; Speculative Store Bypass disabled via prctl and seccomp
  Spectre v1:            Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2:            Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
  Srbds:                 Mitigation; TSX disabled
  Tsx async abort:       Not affected

@github-actions github-actions bot added the ggml changes relating to the ggml tensor library for machine learning label May 5, 2025
@ZelinMa557 ZelinMa557 closed this May 5, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ggml changes relating to the ggml tensor library for machine learning

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant