
Conversation

@zhang-hui-yulo commented Oct 29, 2025

Add RDNA4 support for mmf; the test-backend-ops perf log is attached.

There are some perf improvements, for example:
before:
MUL_MAT_ID(type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=4,k=2048,o=1): 994 runs - 1376.16 us/run - 100.66 MFLOP/run - 73.15 GFLOPS
MUL_MAT_ID(type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=32,k=2048,o=1): 250 runs - 4024.50 us/run - 805.31 MFLOP/run - 200.10 GFLOPS
after:
MUL_MAT_ID(type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=4,k=2048,o=1): 7952 runs - 132.58 us/run - 100.66 MFLOP/run - 759.26 GFLOPS
MUL_MAT_ID(type_a=f16,type_b=f32,n_mats=128,n_used=8,b=0,m=768,n=32,k=2048,o=1): 1750 runs - 581.82 us/run - 805.31 MFLOP/run - 1.38 TFLOPS

Risks:

  • mul_mat_f_cuda might need extra template parameters to handle the tile_A and tile_B sizes for different AMD archs; so far that is not needed, as RDNA3, RDNA4 and CDNA3 all have a 16x16x16 fp16 MMA (see the sketch after this list).
  • I'm not familiar with the shared memory padding in mul_mat_f_cuda. It looks like nbytes_shared_combine is padded to avoid shared memory bank conflicts; if so, it would also need the cc value to handle different AMD archs, since RDNA3 and RDNA4 need different padding columns. So far that is not needed either, as RDNA4 has the same bank size and bank count as NVIDIA Turing and Ampere.
  • No performance data for a real model yet, as I'm not sure which tool can measure it.
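
A minimal sketch of what the first risk could look like if it ever materializes; the helper name and the compile-time dispatch are hypothetical illustrations, not part of this PR:

// Hypothetical: select the native fp16 MMA shape per compute capability so that
// mul_mat_f_cuda could size tile_A/tile_B from an extra template parameter.
// Today a single 16x16x16 shape covers RDNA3, RDNA4 and CDNA3, so only this
// primary template would actually be used.
template <int cc>
struct mmf_mma_shape {
    static constexpr int M = 16;
    static constexpr int N = 16;
    static constexpr int K = 16;
};
// A future arch with a different native shape would add its own specialization,
// e.g. template <> struct mmf_mma_shape<SOME_FUTURE_CC> { ... };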

Could you give me the steps to measure the performance change on a real model? Thank you.

after.txt
before.txt

github-actions bot added labels Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) on Oct 29, 2025
@am17an (Collaborator) commented Oct 29, 2025

Could you give me the steps to measure the performance change on a real model? Thank you.

You can use llama-bench for that. You would need to use an f16, bf16 or f32 model to check.

I'm not familiar with the shared memory padding in mul_mat_f_cuda. It looks like nbytes_shared_combine is padded to avoid shared memory bank conflicts.

Yes, that's correct.

@JohannesGaessler (Collaborator)

Thank you for this PR. Right now I'm implementing Volta support in parallel. A lot of the issues w.r.t. tile shapes and compilation failures that I've encountered seem to be the same. I ended up using 32x8 + 8x8 -> 32x8 tiles (note that for A and B the lengths are in units of physical 4 byte registers so it's 32x16 + 8x16 -> 32x8 in terms of logical values). I think those tile sizes are also available for AMD WMMA and using them would be simpler. So I would suggest we either merge my PR first so that there is a more similar implementation on master that you can adapt or we try to adapt this PR as-is (I would be fine with either).
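
To make the physical-vs-logical distinction concrete, here is a minimal illustration in terms of the tile<I, J, T> template quoted further down in this thread; the typedef names are only placeholders, not the actual code:

// One half2 register packs two fp16 values, so a 32x8 tile of half2 registers
// holds a 32x16 tile of logical fp16 values.
typedef tile<32, 8, half2> tile_A; // 32x16 logical fp16 values
typedef tile< 8, 8, half2> tile_B; //  8x16 logical fp16 values
typedef tile<32, 8, float> tile_C; // 32x8 fp32 accumulators, one value per register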

@JohannesGaessler (Collaborator)

(Long-term we should have support for both 32x8 and 16x16 tiles since 32x8 is better for <= 8 tokens but 16x16 needs 20% less shared memory I/O).

@zhang-hui-yulo (Author)

Thank you for this PR. Right now I'm implementing Volta support in parallel. A lot of the issues w.r.t. tile shapes and compilation failures that I've encountered seem to be the same. I ended up using 32x8 + 8x8 -> 32x8 tiles (note that for A and B the lengths are in units of physical 4 byte registers so it's 32x16 + 8x16 -> 32x8 in terms of logical values). I think those tile sizes are also available for AMD WMMA and using them would be simpler. So I would suggest we either merge my PR first so that there is a more similar implementation on master that you can adapt or we try to adapt this PR as-is (I would be fine with either).

Yep, tile<32,8,half2> is available for RDNA4 and RDNA3, as all AMD WMMA instructions use a 16x16x16 layout. Honestly, RDNA doesn't have TF32 support, so I added dummy functions in tile<16,8,float> to make the compiler happy.

CDNA3 should have a 16x16x16 fp16 MFMA, but I'm not sure about the layout of its TF32 instruction; let me try to borrow an MI GPU when this PR is finished.

Which PR gets submitted first doesn't matter. It looks like my PR doesn't pass the CI on Linux; if I'm able to fix the issue this week and the performance looks good, I will suggest using my PR first. Otherwise, I will suggest using yours first.

@JohannesGaessler (Collaborator)

My PR should be almost ready, we can decide then.

@zhang-hui-yulo (Author)

My PR should be almost ready, we can decide then.

Got it, you go first, as I still need to get a Linux env with ROCm 6.1.2 to pass the CI.

@JohannesGaessler (Collaborator)

PR for Volta support: #16843. Instead of template specializations that only exist to make the compiler happy, I've made it so that the device code never tries to use them in the first place (to avoid accidental misuse).

@zhang-hui-yulo (Author)

PR for Volta support: #16843. Instead of template specializations that only exist to make the compiler happy, I've made it so that the device code never tries to use them in the first place (to avoid accidental misuse).

I got it; it looks similar to my PR. You can submit first, as I still need some time to get an MI308 and benchmark the model.

Honestly, based on my experience, most compiler errors are triggered by

template <int I_, int J_, typename T>
struct tile {
    ...
    static_assert(I == -1 && J == -1, "template specialization not implemented");
};

Because CUDA and ROCm have different MMA layouts, common code like mmf, mmq and load_generic will always run into compiler trouble.

Also, things will be more complicated when trying to implement flash attention for AMD: the C matrix layout is column-major, so it needs D = B * A + C instead of D = A * B + C to handle GEMM fusion (a column-major D holds the same bytes as a row-major D^T, and D^T = B^T * A^T, hence the swapped operand order).

I'm not sure if there is a better way to handle the MMA tiles for different hardware.
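
One possible pattern, just to sketch the "device code never instantiates the unsupported specialization" idea mentioned above; the trait and wrapper names are hypothetical and not something that exists in the codebase:

// Hypothetical compile-time guard: common kernels branch on a trait instead of
// calling dummy member functions, so an unsupported tile shape is never
// instantiated on a given arch and no static_assert fires.
template <typename tile_t>
struct tile_is_supported {
    static constexpr bool value = false;
};

template <>
struct tile_is_supported<tile<32, 8, half2>> {
    static constexpr bool value = true;
};

template <typename tile_A, typename tile_B, typename tile_C>
static __device__ void mma_if_supported(tile_C & D, const tile_A & A, const tile_B & B) {
    if constexpr (tile_is_supported<tile_A>::value) {
        // mma(D, A, B);  // real MMA path, only compiled where the shape exists
    }
    // else: this tile combination is never selected on this arch, so nothing to do
}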

@zhang-hui-yulo (Author)

Hello @JohannesGaessler

May I have the download links of the models you are evaluating in #16843? Thank you.

I just used https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf, but it looks like the performance doesn't change much. I'm not sure why; maybe fp16 mmvf handles too many ops.

Best Regards
Hui

@JohannesGaessler (Collaborator)

I used https://huggingface.co/meta-llama/Meta-Llama-3-8B and https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite which I downloaded in HuggingFace format and then converted myself.

@JohannesGaessler (Collaborator)

In any case, you should only be seeing a performance difference for an FP16 model; for a q4_0 model the MMF kernel should not be used.

@JohannesGaessler (Collaborator)

I determined performance using this command

./build/bin/llama-bench --model models/opt/${model_name}-${quantization}.gguf -r 1 -fa 1 -n 0 -ub 1-16 -o sql|sqlite3 llama-bench.sqlite

for the two commits followed by

python3 scripts/compare-llama-bench.py -s gpu_info,model_type,n_ubatch -i llama-bench.sqlite
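
For context, my reading of the flags (treat this as a hint rather than documentation): -r 1 runs a single repetition, -fa 1 enables flash attention, -n 0 skips the token-generation test, -ub selects the micro-batch sizes to sweep, and -o sql emits SQL that sqlite3 stores in llama-bench.sqlite; compare-llama-bench.py then diffs the results of the two commits.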

zhang-hui-yulo marked this pull request as draft October 31, 2025 13:41
@zhang-hui-yulo (Author)

Got it, thank you for the support. I can see the performance change on deepseek-r1-0528-qwen3-8b.f16.gguf, although the result isn't good: mmf is slower than hipBLAS. I shall spend some time investigating the reason.
