
cuda : add RDNA4-specific MMVQ parameter table for bs=1 decode #19478

Merged
JohannesGaessler merged 3 commits into ggml-org:master from JoursBleu:pr/mmvq-rdna4 on Mar 15, 2026

Conversation

@JoursBleu
Contributor

Add a dedicated MMVQ_PARAMETERS_RDNA4 entry separate from RDNA2/RDNA3. RDNA4 (gfx1201) is wave32-only and has a different memory subsystem, so it benefits from a different MMVQ configuration than RDNA2/RDNA3.

For bs=1 decode on RDNA4, optimal config is nwarps=8, rows_per_block=1:

  • 8 warps × 32 threads = 256 threads per block
  • blocks_per_iter = vdr * nwarps * warp_size / qi = 2 * 8 * 32 / 4 = 128
  • For K=4096: blocks_per_row = 128, the entire K dimension is covered in a single iteration
  • This maximizes memory-level parallelism on RDNA4's memory controller

Benchmark (Llama 2 7B Q4_0, AMD Radeon AI PRO R9700 / gfx1201):

| Metric | Master | This PR | Change |
|---|---|---|---|
| tg128 (tok/s) | 95.05 | 104.82 | +10.3% |
| pp512 (tok/s) | 1449 | 1448 | no regression |

Correctness:

  • test-backend-ops -o MUL_MAT: 1009/1009 OK
  • llama-perplexity (wikitext-2, 5 chunks): 6.1736 ± 0.40149 (identical to master)
  • Text generation (greedy decoding): bit-exact match with master

@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Feb 10, 2026
Collaborator

@IMbackK IMbackK left a comment


Also, on HIP all RDNA devices are wave32-only; wave64 mode is not available via HIP.

```cpp
    default:
        return 1;
    }
}
```
Collaborator


Since you set `#define MMVQ_RDNA4_ROWS 1`, this entire block is pointless:

```cpp
if (table_id == MMVQ_PARAMETERS_RDNA4) {
    switch (ncols_dst) {
        case 1:
            return MMVQ_RDNA4_NWARPS;
```
Contributor


Suggested change

```diff
-            return MMVQ_RDNA4_NWARPS;
+            return 8;
```

Preferably just hard-code the value here if you're only going to use it once.

@JohannesGaessler
Contributor

Performance
| GPU | Model | Microbatch size | Test | t/s b7981 | t/s 49a5ff4 | Speedup |
|---|---|---|---|---|---|---|
| RX 9060 XT | llama 8B IQ1_S - 1.5625 bpw | 1 | pp512 | 70.14 | 80.06 | 1.14 |
| RX 9060 XT | llama 8B IQ1_S - 1.5625 bpw | 2 | pp512 | 122.80 | 122.91 | 1.00 |
| RX 9060 XT | llama 8B IQ1_S - 1.5625 bpw | 4 | pp512 | 188.20 | 188.61 | 1.00 |
| RX 9060 XT | llama 8B IQ1_S - 1.5625 bpw | 8 | pp512 | 216.41 | 216.35 | 1.00 |
| RX 9060 XT | llama 8B IQ1_S - 1.5625 bpw | 16 | pp512 | 639.09 | 638.93 | 1.00 |
| RX 9060 XT | llama 8B IQ2_S - 2.5 bpw | 1 | pp512 | 58.07 | 56.31 | 0.97 |
| RX 9060 XT | llama 8B IQ2_S - 2.5 bpw | 2 | pp512 | 100.47 | 100.66 | 1.00 |
| RX 9060 XT | llama 8B IQ2_S - 2.5 bpw | 4 | pp512 | 153.34 | 153.55 | 1.00 |
| RX 9060 XT | llama 8B IQ2_S - 2.5 bpw | 8 | pp512 | 214.63 | 214.73 | 1.00 |
| RX 9060 XT | llama 8B IQ2_S - 2.5 bpw | 16 | pp512 | 352.89 | 355.22 | 1.01 |
| RX 9060 XT | llama 8B IQ2_XS - 2.3125 bpw | 1 | pp512 | 60.28 | 58.25 | 0.97 |
| RX 9060 XT | llama 8B IQ2_XS - 2.3125 bpw | 2 | pp512 | 103.50 | 103.34 | 1.00 |
| RX 9060 XT | llama 8B IQ2_XS - 2.3125 bpw | 4 | pp512 | 153.28 | 153.41 | 1.00 |
| RX 9060 XT | llama 8B IQ2_XS - 2.3125 bpw | 8 | pp512 | 215.80 | 215.21 | 1.00 |
| RX 9060 XT | llama 8B IQ2_XS - 2.3125 bpw | 16 | pp512 | 346.00 | 344.06 | 0.99 |
| RX 9060 XT | llama 8B IQ2_XXS - 2.0625 bpw | 1 | pp512 | 50.57 | 49.64 | 0.98 |
| RX 9060 XT | llama 8B IQ2_XXS - 2.0625 bpw | 2 | pp512 | 90.73 | 89.83 | 0.99 |
| RX 9060 XT | llama 8B IQ2_XXS - 2.0625 bpw | 4 | pp512 | 151.71 | 149.94 | 0.99 |
| RX 9060 XT | llama 8B IQ2_XXS - 2.0625 bpw | 8 | pp512 | 185.93 | 184.69 | 0.99 |
| RX 9060 XT | llama 8B IQ2_XXS - 2.0625 bpw | 16 | pp512 | 514.54 | 513.16 | 1.00 |
| RX 9060 XT | llama 8B IQ3_S - 3.4375 bpw | 1 | pp512 | 46.78 | 45.52 | 0.97 |
| RX 9060 XT | llama 8B IQ3_S - 3.4375 bpw | 2 | pp512 | 87.68 | 86.98 | 0.99 |
| RX 9060 XT | llama 8B IQ3_S - 3.4375 bpw | 4 | pp512 | 148.63 | 147.52 | 0.99 |
| RX 9060 XT | llama 8B IQ3_S - 3.4375 bpw | 8 | pp512 | 186.29 | 185.38 | 1.00 |
| RX 9060 XT | llama 8B IQ3_S - 3.4375 bpw | 16 | pp512 | 489.76 | 486.95 | 0.99 |
| RX 9060 XT | llama 8B IQ3_S mix - 3.66 bpw | 1 | pp512 | 46.52 | 46.71 | 1.00 |
| RX 9060 XT | llama 8B IQ3_S mix - 3.66 bpw | 2 | pp512 | 86.60 | 86.30 | 1.00 |
| RX 9060 XT | llama 8B IQ3_S mix - 3.66 bpw | 4 | pp512 | 145.24 | 145.07 | 1.00 |
| RX 9060 XT | llama 8B IQ3_S mix - 3.66 bpw | 8 | pp512 | 181.58 | 181.39 | 1.00 |
| RX 9060 XT | llama 8B IQ3_S mix - 3.66 bpw | 16 | pp512 | 494.54 | 492.95 | 1.00 |
| RX 9060 XT | llama 8B IQ3_XS - 3.3 bpw | 1 | pp512 | 52.85 | 50.10 | 0.95 |
| RX 9060 XT | llama 8B IQ3_XS - 3.3 bpw | 2 | pp512 | 95.43 | 95.00 | 1.00 |
| RX 9060 XT | llama 8B IQ3_XS - 3.3 bpw | 4 | pp512 | 150.92 | 150.29 | 1.00 |
| RX 9060 XT | llama 8B IQ3_XS - 3.3 bpw | 8 | pp512 | 188.52 | 187.62 | 1.00 |
| RX 9060 XT | llama 8B IQ3_XS - 3.3 bpw | 16 | pp512 | 522.61 | 518.89 | 0.99 |
| RX 9060 XT | llama 8B IQ3_XXS - 3.0625 bpw | 1 | pp512 | 55.67 | 54.27 | 0.97 |
| RX 9060 XT | llama 8B IQ3_XXS - 3.0625 bpw | 2 | pp512 | 97.06 | 97.00 | 1.00 |
| RX 9060 XT | llama 8B IQ3_XXS - 3.0625 bpw | 4 | pp512 | 150.35 | 149.81 | 1.00 |
| RX 9060 XT | llama 8B IQ3_XXS - 3.0625 bpw | 8 | pp512 | 191.45 | 190.47 | 0.99 |
| RX 9060 XT | llama 8B IQ3_XXS - 3.0625 bpw | 16 | pp512 | 509.86 | 507.05 | 0.99 |
| RX 9060 XT | llama 8B IQ4_NL - 4.5 bpw | 1 | pp512 | 49.49 | 64.03 | 1.29 |
| RX 9060 XT | llama 8B IQ4_NL - 4.5 bpw | 2 | pp512 | 94.63 | 94.78 | 1.00 |
| RX 9060 XT | llama 8B IQ4_NL - 4.5 bpw | 4 | pp512 | 170.06 | 170.01 | 1.00 |
| RX 9060 XT | llama 8B IQ4_NL - 4.5 bpw | 8 | pp512 | 203.66 | 203.61 | 1.00 |
| RX 9060 XT | llama 8B IQ4_NL - 4.5 bpw | 16 | pp512 | 628.08 | 626.79 | 1.00 |
| RX 9060 XT | llama 8B IQ4_XS - 4.25 bpw | 1 | pp512 | 51.49 | 67.05 | 1.30 |
| RX 9060 XT | llama 8B IQ4_XS - 4.25 bpw | 2 | pp512 | 99.37 | 99.20 | 1.00 |
| RX 9060 XT | llama 8B IQ4_XS - 4.25 bpw | 4 | pp512 | 182.30 | 182.48 | 1.00 |
| RX 9060 XT | llama 8B IQ4_XS - 4.25 bpw | 8 | pp512 | 219.00 | 218.72 | 1.00 |
| RX 9060 XT | llama 8B IQ4_XS - 4.25 bpw | 16 | pp512 | 661.87 | 660.83 | 1.00 |
| RX 9060 XT | llama 8B Q2_K_S | 1 | pp512 | 69.75 | 77.87 | 1.12 |
| RX 9060 XT | llama 8B Q2_K_S | 2 | pp512 | 107.22 | 107.19 | 1.00 |
| RX 9060 XT | llama 8B Q2_K_S | 4 | pp512 | 129.41 | 129.70 | 1.00 |
| RX 9060 XT | llama 8B Q2_K_S | 8 | pp512 | 146.56 | 146.28 | 1.00 |
| RX 9060 XT | llama 8B Q2_K_S | 16 | pp512 | 360.87 | 363.50 | 1.01 |
| RX 9060 XT | llama 8B Q3_K_S | 1 | pp512 | 51.52 | 46.11 | 0.89 |
| RX 9060 XT | llama 8B Q3_K_S | 2 | pp512 | 84.76 | 85.56 | 1.01 |
| RX 9060 XT | llama 8B Q3_K_S | 4 | pp512 | 120.35 | 121.11 | 1.01 |
| RX 9060 XT | llama 8B Q3_K_S | 8 | pp512 | 147.36 | 148.26 | 1.01 |
| RX 9060 XT | llama 8B Q3_K_S | 16 | pp512 | 485.02 | 487.35 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 1 | pp512 | 49.64 | 64.76 | 1.30 |
| RX 9060 XT | llama 8B Q4_0 | 2 | pp512 | 95.14 | 95.15 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 4 | pp512 | 170.71 | 170.92 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 8 | pp512 | 211.05 | 211.15 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 16 | pp512 | 616.15 | 616.62 | 1.00 |
| RX 9060 XT | llama 8B Q4_1 | 1 | pp512 | 46.99 | 59.73 | 1.27 |
| RX 9060 XT | llama 8B Q4_1 | 2 | pp512 | 89.59 | 89.46 | 1.00 |
| RX 9060 XT | llama 8B Q4_1 | 4 | pp512 | 162.67 | 162.74 | 1.00 |
| RX 9060 XT | llama 8B Q4_1 | 8 | pp512 | 223.31 | 223.04 | 1.00 |
| RX 9060 XT | llama 8B Q4_1 | 16 | pp512 | 610.70 | 610.28 | 1.00 |
| RX 9060 XT | llama 8B Q4_K_S | 1 | pp512 | 49.88 | 61.28 | 1.23 |
| RX 9060 XT | llama 8B Q4_K_S | 2 | pp512 | 92.23 | 92.21 | 1.00 |
| RX 9060 XT | llama 8B Q4_K_S | 4 | pp512 | 133.56 | 133.04 | 1.00 |
| RX 9060 XT | llama 8B Q4_K_S | 8 | pp512 | 153.28 | 152.79 | 1.00 |
| RX 9060 XT | llama 8B Q4_K_S | 16 | pp512 | 565.65 | 565.03 | 1.00 |
| RX 9060 XT | llama 8B Q5_0 | 1 | pp512 | 44.30 | 55.15 | 1.24 |
| RX 9060 XT | llama 8B Q5_0 | 2 | pp512 | 84.67 | 84.47 | 1.00 |
| RX 9060 XT | llama 8B Q5_0 | 4 | pp512 | 151.75 | 151.54 | 1.00 |
| RX 9060 XT | llama 8B Q5_0 | 8 | pp512 | 205.13 | 204.33 | 1.00 |
| RX 9060 XT | llama 8B Q5_0 | 16 | pp512 | 533.45 | 532.06 | 1.00 |
| RX 9060 XT | llama 8B Q5_1 | 1 | pp512 | 42.12 | 51.39 | 1.22 |
| RX 9060 XT | llama 8B Q5_1 | 2 | pp512 | 78.93 | 79.07 | 1.00 |
| RX 9060 XT | llama 8B Q5_1 | 4 | pp512 | 145.33 | 145.71 | 1.00 |
| RX 9060 XT | llama 8B Q5_1 | 8 | pp512 | 242.83 | 243.41 | 1.00 |
| RX 9060 XT | llama 8B Q5_1 | 16 | pp512 | 438.08 | 436.87 | 1.00 |
| RX 9060 XT | llama 8B Q5_K_S | 1 | pp512 | 44.36 | 53.94 | 1.22 |
| RX 9060 XT | llama 8B Q5_K_S | 2 | pp512 | 82.47 | 82.24 | 1.00 |
| RX 9060 XT | llama 8B Q5_K_S | 4 | pp512 | 129.45 | 128.58 | 0.99 |
| RX 9060 XT | llama 8B Q5_K_S | 8 | pp512 | 150.63 | 149.63 | 0.99 |
| RX 9060 XT | llama 8B Q5_K_S | 16 | pp512 | 572.24 | 571.04 | 1.00 |
| RX 9060 XT | llama 8B Q6_K | 1 | pp512 | 39.59 | 47.14 | 1.19 |
| RX 9060 XT | llama 8B Q6_K | 2 | pp512 | 75.40 | 75.56 | 1.00 |
| RX 9060 XT | llama 8B Q6_K | 4 | pp512 | 127.89 | 126.52 | 0.99 |
| RX 9060 XT | llama 8B Q6_K | 8 | pp512 | 162.67 | 161.35 | 0.99 |
| RX 9060 XT | llama 8B Q6_K | 16 | pp512 | 455.31 | 454.03 | 1.00 |
| RX 9060 XT | llama 8B Q8_0 | 1 | pp512 | 33.23 | 38.17 | 1.15 |
| RX 9060 XT | llama 8B Q8_0 | 2 | pp512 | 62.00 | 62.14 | 1.00 |
| RX 9060 XT | llama 8B Q8_0 | 4 | pp512 | 115.66 | 115.55 | 1.00 |
| RX 9060 XT | llama 8B Q8_0 | 8 | pp512 | 190.20 | 190.01 | 1.00 |
| RX 9060 XT | llama 8B Q8_0 | 16 | pp512 | 465.75 | 465.82 | 1.00 |

In my testing this change is not beneficial for all data types, please add a corresponding check. Also, you may want to test whether you can squeeze out a bit more performance for batch sizes 2-8 with the same change since those are relevant for speculative decoding and batched inference.

@JohannesGaessler
Contributor

I forgot: I tested the code like this in case you want to replicate it:

```sh
for q in q4_0 q4_1 q5_0 q5_1 q8_0 q2_k_s q3_k_s q4_k_s q5_k_s q6_k iq1_s iq2_xxs iq2_xs iq2_s iq3_xxs iq3_xs iq3_s iq3_m iq4_nl iq4_xs; do
    echo $q
    ./bench --model models/opt/${mn}-${q}.gguf -r 1 -fa 1 -n 0 -ub "1-16*2" --progress -o sql | sqlite3 llama-bench.sqlite
    sleep 10
done
py scripts/compare-llama-bench.py -s gpu_info,model_type,n_ubatch -i llama-bench.sqlite -b b7981 | tee bench.txt
```

Comment on lines +92 to +112
```cpp
// Returns true for quantization types that benefit from nwarps=8 on RDNA4.
// Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register
// pressure and lookup table contention at higher thread counts.
static constexpr __host__ __device__ bool mmvq_rdna4_nwarps8_beneficial(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
            return true;
        default:
            return false;
    }
}
```
Contributor


This function is only used once, so just move this logic into calc_nwarps.

@JoursBleu
Contributor Author

@IMbackK Thanks for the review!

  1. MMVQ_RDNA4_ROWS block removed — you were right, since rows_per_block=1 is the default return value of calc_rows_per_block(), the RDNA4-specific block was pointless. Removed both the macro and the block.

  2. Wave size — agreed, all RDNA devices are wave32-only via HIP. The current code doesn't touch wave size logic; it only changes nwarps in calc_nwarps().

@JohannesGaessler
Contributor

As is outlined in the contributing guidelines:

It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...).

@JoursBleu JoursBleu force-pushed the pr/mmvq-rdna4 branch 2 times, most recently from dad0789 to 2df360d on February 26, 2026 07:40
@JoursBleu
Contributor Author

@JohannesGaessler
We followed the suggestion: we tested performance under the different quantization types, added the corresponding checks, and moved the code into calc_nwarps.
We also added an RDNA3 optimization, with the corresponding checks, and tested it on the W7900.

@JohannesGaessler
Contributor

On Strix Halo the changes seem to be detrimental:

| GPU | Model | Microbatch size | Test | t/s master | t/s 2df360d | Speedup |
|---|---|---|---|---|---|---|
| Radeon 8060S Graphics | llama 8B IQ1_S - 1.5625 bpw | 1 | pp512 | 66.60 | 66.89 | 1.00 |
| Radeon 8060S Graphics | llama 8B IQ2_S - 2.5 bpw | 1 | pp512 | 40.91 | 40.94 | 1.00 |
| Radeon 8060S Graphics | llama 8B IQ2_XS - 2.3125 bpw | 1 | pp512 | 42.16 | 41.96 | 1.00 |
| Radeon 8060S Graphics | llama 8B IQ2_XXS - 2.0625 bpw | 1 | pp512 | 43.09 | 43.20 | 1.00 |
| Radeon 8060S Graphics | llama 8B IQ3_S - 3.4375 bpw | 1 | pp512 | 40.46 | 40.31 | 1.00 |
| Radeon 8060S Graphics | llama 8B IQ3_S mix - 3.66 bpw | 1 | pp512 | 40.17 | 39.05 | 0.97 |
| Radeon 8060S Graphics | llama 8B IQ3_XS - 3.3 bpw | 1 | pp512 | 39.86 | 39.74 | 1.00 |
| Radeon 8060S Graphics | llama 8B IQ3_XXS - 3.0625 bpw | 1 | pp512 | 40.09 | 39.87 | 0.99 |
| Radeon 8060S Graphics | llama 8B IQ4_NL - 4.5 bpw | 1 | pp512 | 44.42 | 42.42 | 0.95 |
| Radeon 8060S Graphics | llama 8B IQ4_XS - 4.25 bpw | 1 | pp512 | 47.31 | 48.13 | 1.02 |
| Radeon 8060S Graphics | llama 8B Q2_K_S | 1 | pp512 | 57.34 | 56.84 | 0.99 |
| Radeon 8060S Graphics | llama 8B Q3_K_S | 1 | pp512 | 44.00 | 44.00 | 1.00 |
| Radeon 8060S Graphics | llama 8B Q4_0 | 1 | pp512 | 46.00 | 42.57 | 0.93 |
| Radeon 8060S Graphics | llama 8B Q4_1 | 1 | pp512 | 41.82 | 40.74 | 0.97 |
| Radeon 8060S Graphics | llama 8B Q4_K_S | 1 | pp512 | 37.87 | 36.44 | 0.96 |
| Radeon 8060S Graphics | llama 8B Q5_0 | 1 | pp512 | 39.97 | 39.76 | 0.99 |
| Radeon 8060S Graphics | llama 8B Q5_1 | 1 | pp512 | 33.98 | 37.42 | 1.10 |
| Radeon 8060S Graphics | llama 8B Q5_K_S | 1 | pp512 | 33.74 | 33.74 | 1.00 |
| Radeon 8060S Graphics | llama 8B Q6_K | 1 | pp512 | 32.60 | 31.87 | 0.98 |
| Radeon 8060S Graphics | llama 8B Q8_0 | 1 | pp512 | 26.93 | 27.07 | 1.00 |

Please revert the RDNA3 changes unless you can post the data based on which you think they're beneficial.

@JoursBleu
Contributor Author

@JohannesGaessler Thanks for pointing out the Strix Halo regression. We reproduced it on a local device — nwarps=1 is already optimal, higher values degrade monotonically.

The nwarps tuning is a latency-hiding optimization that works on discrete GPUs where memory bandwidth is underutilized. On Strix Halo with shared LPDDR5X, bandwidth is already near saturation at nwarps=1, so extra warps just add reduction overhead with no benefit.

Strix Halo (RDNA3.5) should be excluded from the RDNA3 nwarps=8 table.

```diff
     MMVQ_PARAMETERS_GCN,
-    MMVQ_PARAMETERS_RDNA2
+    MMVQ_PARAMETERS_RDNA2,
+    MMVQ_PARAMETERS_RDNA3,
```
Contributor


Suggested change

```diff
-    MMVQ_PARAMETERS_RDNA3,
+    MMVQ_PARAMETERS_RDNA3_0,
```

Comment on lines +71 to +72
```cpp
#elif defined(RDNA3) && !defined(RDNA3_5)
    return MMVQ_PARAMETERS_RDNA3;
```
Contributor


Suggested change

```diff
-#elif defined(RDNA3) && !defined(RDNA3_5)
-    return MMVQ_PARAMETERS_RDNA3;
+#elif defined(RDNA3_0)
+    return MMVQ_PARAMETERS_RDNA3_0;
```

Define an RDNA3_0 macro that can be used in the same, consistent way as in host code and use that instead.

@JoursBleu
Contributor Author

@JohannesGaessler
Renamed MMVQ_PARAMETERS_RDNA3 → MMVQ_PARAMETERS_RDNA3_0 and added an RDNA3_0 device macro in hip.h (defined as RDNA3 && !RDNA3_5), consistent with the host-side GGML_CUDA_CC_IS_RDNA3_0 convention.

@IMbackK IMbackK self-requested a review March 3, 2026 09:49
@IMbackK IMbackK self-assigned this Mar 3, 2026
@JohannesGaessler
Contributor

As I said before, please post the data based on which you are making these changes. If you don't I will require someone else with access to an RDNA3 GPU to benchmark these changes.

@IMbackK
Collaborator

IMbackK commented Mar 3, 2026

@JohannesGaessler I self-assigned this with the purpose of benchmarking it on my machine (which now contains a gfx1100 device). @JoursBleu, please also provide the benchmarks you performed.

@JoursBleu
Contributor Author

Hi @JohannesGaessler, we tested the 20 quantized model types listed above on W7900, and the results are shown below.

| Group | Model Type | master t/s | pr/19478 t/s | Change |
|---|---|---|---|---|
| Whitelist | llama 7B IQ4_NL - 4.5 bpw | 110.20 | 112.14 | +1.76% |
| Whitelist | llama 7B Q4_0 | 111.87 | 113.85 | +1.77% |
| Whitelist | llama 7B Q4_1 | 104.83 | 107.08 | +2.15% |
| Whitelist | llama 7B Q4_K - Small | 92.87 | 95.42 | +2.74% |
| Whitelist | llama 7B Q5_0 | 95.86 | 98.98 | +3.26% |
| Whitelist | llama 7B Q5_1 | 80.95 | 93.99 | +16.12% |
| Whitelist | llama 7B Q6_K | 81.15 | 85.11 | +4.88% |
| Whitelist | llama 7B Q8_0 | 73.15 | 74.80 | +2.26% |
| Non-whitelist | llama 7B IQ1_S - 1.5625 bpw | 149.64 | 149.36 | -0.19% |
| Non-whitelist | llama 7B IQ2_S - 2.5 bpw | 95.12 | 94.55 | -0.61% |
| Non-whitelist | llama 7B IQ2_XS - 2.3125 bpw | 99.17 | 98.84 | -0.33% |
| Non-whitelist | llama 7B IQ2_XXS - 2.0625 bpw | 101.30 | 101.13 | -0.17% |
| Non-whitelist | llama 7B IQ3_S - 3.4375 bpw | 93.83 | 93.25 | -0.62% |
| Non-whitelist | llama 7B IQ3_S mix - 3.66 bpw | 92.29 | 92.16 | -0.14% |
| Non-whitelist | llama 7B IQ3_XS - 3.3 bpw | 93.26 | 93.04 | -0.23% |
| Non-whitelist | llama 7B IQ3_XXS - 3.0625 bpw | 93.70 | 94.05 | +0.37% |
| Non-whitelist | llama 7B IQ4_XS - 4.25 bpw | 116.61 | 115.74 | -0.75% |
| Non-whitelist | llama 7B Q2_K - Small | 133.23 | 133.25 | +0.02% |
| Non-whitelist | llama 7B Q3_K - Small | 100.38 | 100.32 | -0.06% |
| Non-whitelist | llama 7B Q5_K - Small | 87.10 | 86.57 | -0.61% |

@JoursBleu
Contributor Author

Hi @JohannesGaessler @IMbackK, just checking to confirm — the benchmark data for the W7900 has been posted.

@IMbackK, could you run the benchmarks for the gfx1100? I'd be happy to assist you if needed.

And let me know if there are any remaining concerns. Thanks!

@JohannesGaessler JohannesGaessler merged commit 617db24 into ggml-org:master Mar 15, 2026
77 of 78 checks passed
@meven3000

meven3000 commented Mar 19, 2026

Hello, it took me a day to narrow this down. On RDNA4 with 2x and 3x R9700, depending on the ROCm variables, this leads to anywhere between a 10%-17% TG reduction for Qwen 3.5 120b Q4.

Tested 8354, on which TPS was back at 40.

@IMbackK
Collaborator

IMbackK commented Mar 19, 2026

x2 and 3x?

@meven3000

> x2 and 3x?

3x Radeon AI Pro R9700, AKA not an ideal setup, but it works. Due to the motherboard having 12x DIMMs, I'm limited to 3x PCIe 5 x16 and need to change something.

@JoursBleu
Contributor Author

> Hello, it took me a day to narrow this down. On RDNA4 with 2x and 3x R9700, depending on the ROCm variables, this leads to anywhere between a 10%-17% TG reduction for Qwen 3.5 120b Q4.
>
> Tested 8354, on which TPS was back at 40.

Hi @meven3000, thank you for your feedback. I will try to reproduce and fix this issue.

Ethan-a2 pushed a commit to Ethan-a2/llama.cpp that referenced this pull request Mar 20, 2026
…org#19478)

* mmvq: add RDNA3/RDNA4-specific parameter table (nwarps=8, rows=1)

* mmvq: add dedicated RDNA3 parameter table

* mmvq: exclude RDNA3.5 (gfx1150/1151) from RDNA3 table
@JoursBleu
Contributor Author

JoursBleu commented Mar 21, 2026

Thanks for the feedback @meven3000! We did further testing and found the regression is not caused by multi-GPU parallelism, but by MoE models.

Qwen3.5-35B-A3B Q4_K_M — single GPU TG

| GPU | base (t/s) | #19478 (t/s) | regression |
|---|---|---|---|
| R9700 (RDNA4) | 76.92 | 73.76 | -4.1% |
| W7900 (RDNA3) | 77.03 | 68.05 | -11.7% |

Qwen3.5-122B-A10B Q4_K_M — 4x GPU TG

| GPU | base (t/s) | #19478 (t/s) | regression |
|---|---|---|---|
| R9700 (RDNA4) | 38.72 | 34.74 | -10.3% |
| W7900D (RDNA3) | 32.89 | 29.80 | -9.4% |

Qwen2.5-72B Q4_K_M (dense) — 4x GPU TG

| GPU | base (t/s) | #19478 (t/s) | regression |
|---|---|---|---|
| R9700 (RDNA4) | 10.42 | 10.63 | +2.0% |
| W7900D (RDNA3) | 11.54 | 11.46 | -0.7% |

We have an initial fix in #20831. More comprehensive testing is in progress.
