cuda : add RDNA4-specific MMVQ parameter table for bs=1 decode #19478
JohannesGaessler merged 3 commits into ggml-org:master
Conversation
IMbackK
left a comment
Also, on HIP all RDNA devices are wave32-only; wave64 mode is not available via HIP.
ggml/src/ggml-cuda/mmvq.cu (outdated)

```cpp
        default:
            return 1;
    }
}
```
Since you set `#define MMVQ_RDNA4_ROWS 1`, this entire block is pointless.
Force-pushed 879ccbb to 49a5ff4
ggml/src/ggml-cuda/mmvq.cu (outdated)

```cpp
if (table_id == MMVQ_PARAMETERS_RDNA4) {
    switch (ncols_dst) {
        case 1:
            return MMVQ_RDNA4_NWARPS;
```
Suggested change:

```diff
-            return MMVQ_RDNA4_NWARPS;
+            return 8;
```
Preferably just hard-code the value here if you're only going to use it once.
Performance
In my testing this change is not beneficial for all data types; please add a corresponding check. Also, you may want to test whether you can squeeze out a bit more performance for batch sizes 2-8 with the same change, since those are relevant for speculative decoding and batched inference.
I forgot: I tested the code like this in case you want to replicate it:

```sh
for q in q4_0 q4_1 q5_0 q5_1 q8_0 q2_k_s q3_k_s q4_k_s q5_k_s q6_k iq1_s iq2_xxs iq2_xs iq2_s iq3_xxs iq3_xs iq3_s iq3_m iq4_nl iq4_xs; do
    echo $q
    ./bench --model models/opt/${mn}-${q}.gguf -r 1 -fa 1 -n 0 -ub "1-16*2" --progress -o sql | sqlite3 llama-bench.sqlite
    sleep 10
done
py scripts/compare-llama-bench.py -s gpu_info,model_type,n_ubatch -i llama-bench.sqlite -b b7981 | tee bench.txt
```
Force-pushed 49a5ff4 to 1562bfd
ggml/src/ggml-cuda/mmvq.cu (outdated)

```cpp
// Returns true for quantization types that benefit from nwarps=8 on RDNA4.
// Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register
// pressure and lookup table contention at higher thread counts.
static constexpr __host__ __device__ bool mmvq_rdna4_nwarps8_beneficial(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
            return true;
        default:
            return false;
    }
}
```
This function is only used once, so just move this logic into calc_nwarps.
Force-pushed 1562bfd to b1f4a5f
@IMbackK Thanks for the review!
As is outlined in the contributing guidelines:
Force-pushed dad0789 to 2df360d
@JohannesGaessler
On Strix Halo the changes seem to be detrimental:
Please revert the RDNA3 changes unless you can post the data based on which you think they're beneficial.
@JohannesGaessler Thanks for pointing out the Strix Halo regression. We reproduced it on a local device: nwarps=1 is already optimal, and higher values degrade monotonically. The nwarps tuning is a latency-hiding optimization that works on discrete GPUs where memory bandwidth is underutilized. On Strix Halo with shared LPDDR5X, bandwidth is already near saturation at nwarps=1, so extra warps just add reduction overhead with no benefit. Strix Halo (RDNA3.5) should be excluded from the RDNA3 nwarps=8 table.
ggml/src/ggml-cuda/mmvq.cu (outdated)

```diff
     MMVQ_PARAMETERS_GCN,
-    MMVQ_PARAMETERS_RDNA2
+    MMVQ_PARAMETERS_RDNA2,
+    MMVQ_PARAMETERS_RDNA3,
```
Suggested change:

```diff
-    MMVQ_PARAMETERS_RDNA3,
+    MMVQ_PARAMETERS_RDNA3_0,
```
ggml/src/ggml-cuda/mmvq.cu (outdated)

```cpp
#elif defined(RDNA3) && !defined(RDNA3_5)
    return MMVQ_PARAMETERS_RDNA3;
```
Suggested change:

```diff
-#elif defined(RDNA3) && !defined(RDNA3_5)
-    return MMVQ_PARAMETERS_RDNA3;
+#elif defined(RDNA3_0)
+    return MMVQ_PARAMETERS_RDNA3_0;
```
Define an RDNA3_0 macro that can be used in the same, consistent way as the other architecture macros in host code, and use that instead.
@JohannesGaessler
As I said before, please post the data based on which you are making these changes. If you don't, I will require someone else with access to an RDNA3 GPU to benchmark these changes.
@JohannesGaessler I self-assigned this with the purpose of benchmarking it on my machine (which now contains a gfx1100 device). @JoursBleu, please also provide the benchmarks you performed.
Hi @JohannesGaessler, we tested the 20 quantized model types listed above on the W7900, and the results are shown below.
Hi @JohannesGaessler @IMbackK, just checking in to confirm: the benchmark data for the W7900 has been posted. @IMbackK, could you run the benchmarks for the gfx1100? I'd be happy to assist if needed. And let me know if there are any remaining concerns. Thanks!
Hello, it took me a day to narrow this down. On RDNA4 with 2x R9700 and 3x R9700, depending on ROCm variables, this is leading to anywhere between a 10%-17% TG reduction for Qwen 3.5 120b Q4. Tested 8354, on which TPS was back at 40.
x2 and 3x? |
3x Radeon AI Pro R9700. Not an ideal setup, but it works: the motherboard has 12x DIMM slots, so I'm limited to 3x PCIe 5.0 x16 and need to change something.
Hi @meven3000, thank you for your feedback. I will try to reproduce and fix this issue. |
…org#19478)

* mmvq: add RDNA3/RDNA4-specific parameter table (nwarps=8, rows=1)
* mmvq: add dedicated RDNA3 parameter table
* mmvq: exclude RDNA3.5 (gfx1150/1151) from RDNA3 table
Thanks for the feedback @meven3000! We did further testing and found the regression is not caused by multi-GPU parallelism, but by MoE models.

* Qwen3.5-35B-A3B Q4_K_M (MoE): single-GPU TG
* Qwen3.5-122B-A10B Q4_K_M (MoE): 4x GPU TG
* Qwen2.5-72B Q4_K_M (dense): 4x GPU TG

We have an initial fix in #20831. More comprehensive testing is in progress.
Add a dedicated `MMVQ_PARAMETERS_RDNA4` entry separate from RDNA2/RDNA3. RDNA4 (gfx1201) is wave32-only and has a different memory subsystem, so it benefits from a different MMVQ configuration than RDNA2/RDNA3.

For bs=1 decode on RDNA4, the optimal config is `nwarps=8, rows_per_block=1`:

* `blocks_per_iter = vdr * nwarps * warp_size / qi = 2 * 8 * 32 / 4 = 128`
* With `blocks_per_row = 128`, the entire K dimension is covered in a single iteration.

Benchmark (Llama 2 7B Q4_0, AMD Radeon AI PRO R9700 / gfx1201):

Correctness:

* `test-backend-ops -o MUL_MAT`: 1009/1009 OK
* `llama-perplexity` (wikitext-2, 5 chunks): 6.1736 ± 0.40149 (identical to master)