**am17an** (Collaborator) commented Oct 13, 2025

I see speedups on my 3090, but not so much on a 4090. I suspect it is due to better integer-division hardware on newer cards, but I did not find any documentation to confirm this.

on 3090:

| Model | Test | t/s master | t/s cuda_mmvf_fastdiv | Speedup |
| --- | --- | ---: | ---: | ---: |
| lfm2moe 8B.A1B BF16 | tg32 | 117.54 | 128.09 | 1.09 |
| lfm2moe 8B.A1B BF16 | tg64 | 120.75 | 127.56 | 1.06 |
| lfm2moe 8B.A1B BF16 | tg128 | 121.43 | 128.15 | 1.06 |

on 4090:

| Model | Test | t/s master | t/s cuda_mmvf_fastdiv | Speedup |
| --- | --- | ---: | ---: | ---: |
| lfm2moe 8B.A1B F16 | tg32 | 139.51 | 140.07 | 1.00 |
| lfm2moe 8B.A1B F16 | tg64 | 139.41 | 139.74 | 1.00 |
| lfm2moe 8B.A1B F16 | tg128 | 139.35 | 139.57 | 1.00 |
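
For reference, here is a minimal sketch of the fastdiv idea mentioned above: division (and modulo) by a divisor that is constant for the whole kernel launch is replaced by a precomputed multiply-high plus shift (Granlund–Montgomery), so the inner loop never issues a hardware integer divide. The names and packing below are illustrative and are not the actual ggml CUDA helpers.

```cpp
// Sketch only: replace n / d and n % d (d fixed per kernel launch) with a
// precomputed multiply-high and shift. Illustrative names, not the ggml API.
#include <cstdint>

struct fastdiv_t {
    uint32_t mp; // magic multiplier
    uint32_t L;  // shift amount, L = ceil(log2(d))
    uint32_t d;  // original divisor, kept for the modulo
};

// Host side: precompute once per divisor (d >= 1), e.g. per tensor dimension.
static fastdiv_t make_fastdiv(uint32_t d) {
    uint32_t L = 0;
    while (L < 32 && (uint32_t{1} << L) < d) {
        ++L;
    }
    const uint32_t mp = (uint32_t) (((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
    return {mp, L, d};
}

// Device side: quotient and remainder without an integer-divide instruction.
// Assumes n < 2^31, which holds for the index arithmetic in these kernels.
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, fastdiv_t fd) {
    return (__umulhi(n, fd.mp) + n) >> fd.L;
}

static __device__ __forceinline__ uint32_t fastmod(uint32_t n, fastdiv_t fd) {
    return n - fastdiv(n, fd) * fd.d;
}
```

The benefit is largest where the hardware integer-divide path is slow relative to a multiply, which would be consistent with the 3090-vs-4090 difference in the tables above.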

github-actions bot added the **Nvidia GPU** (Issues specific to Nvidia GPUs) and **ggml** (changes relating to the ggml tensor library for machine learning) labels on Oct 13, 2025.
am17an changed the title from "CUDA: use fast + ggml_cuda_mad for mmvf" to "CUDA: use fastdiv + ggml_cuda_mad for mmvf" on Oct 13, 2025.
**JohannesGaessler** (Collaborator) commented:
I can confirm a speedup, though a smaller one. Presumably it will depend on the model.

| GPU | Model | Test | t/s 477a66b | t/s 8898040 | Speedup |
| --- | --- | --- | ---: | ---: | ---: |
| MI50 | llama 1B BF16 | tg128 | 150.16 | 150.15 | 1.00 |
| MI50 | llama 1B F16 | tg128 | 149.20 | 150.19 | 1.01 |
| MI50 | llama 1B all F32 | tg128 | 93.51 | 93.58 | 1.00 |
| P40 | llama 1B BF16 | tg128 | 109.53 | 110.07 | 1.00 |
| P40 | llama 1B F16 | tg128 | 109.01 | 109.70 | 1.01 |
| P40 | llama 1B all F32 | tg128 | 59.18 | 59.29 | 1.00 |
| RTX 3090 | llama 1B BF16 | tg128 | 269.84 | 271.69 | 1.01 |
| RTX 3090 | llama 1B F16 | tg128 | 270.05 | 272.03 | 1.01 |
| RTX 3090 | llama 1B all F32 | tg128 | 152.44 | 153.25 | 1.01 |
| RTX 4090 | llama 1B BF16 | tg128 | 316.85 | 317.81 | 1.00 |
| RTX 4090 | llama 1B F16 | tg128 | 316.88 | 317.98 | 1.00 |
| RTX 4090 | llama 1B all F32 | tg128 | 174.09 | 174.31 | 1.00 |
| RX 6800 | llama 1B BF16 | tg128 | 94.65 | 96.46 | 1.02 |
| RX 6800 | llama 1B F16 | tg128 | 94.41 | 96.61 | 1.02 |
| RX 6800 | llama 1B all F32 | tg128 | 80.23 | 80.64 | 1.00 |
| RX 9060 XT | llama 1B BF16 | tg128 | 98.26 | 99.33 | 1.01 |
| RX 9060 XT | llama 1B F16 | tg128 | 99.47 | 99.96 | 1.00 |
| RX 9060 XT | llama 1B all F32 | tg128 | 57.23 | 57.74 | 1.01 |

**JohannesGaessler** (Collaborator) left a review comment:

I think mul_mat_vec_f should always pass float2, half2, or nv_bfloat162 to ggml_cuda_mad and then let that function decide how to do the calculation. For example, I think Hopper and Blackwell have mixed-precision instructions that could be used (possibly in a future PR), and there definitely are such instructions on AMD GPUs (which are already supported).
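
Roughly, the suggested shape would be for the kernel to always hand packed pairs to one accumulate helper and let type-specific overloads decide how to compute. A hedged sketch with illustrative names and signatures (not the actual ggml_cuda_mad API):

```cpp
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// The kernel only ever calls mad_pair(acc, x2, y2); each overload owns the
// type-specific math, so per-architecture instructions can be slotted in here.
static __device__ __forceinline__ void mad_pair(float & acc, const float2 x, const float2 y) {
    acc += x.x * y.x;
    acc += x.y * y.y;
}

static __device__ __forceinline__ void mad_pair(float & acc, const half2 x, const half2 y) {
    const float2 xf = __half22float2(x);
    const float2 yf = __half22float2(y);
    acc += xf.x * yf.x;
    acc += xf.y * yf.y;
}

static __device__ __forceinline__ void mad_pair(float & acc, const nv_bfloat162 x, const nv_bfloat162 y) {
    acc += __low2float(x)  * __low2float(y);
    acc += __high2float(x) * __high2float(y);
}
```

With this split, the mul_mat_vec_f loop stays type-agnostic, and mixed-precision or packed dot-product instructions can later be introduced inside the relevant overload without touching the kernel.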

am17an force-pushed the cuda_mmvf_fastdiv branch from 23f2ccc to 9d74b8f on October 14, 2025 04:55
am17an requested a review from slaren as a code owner on October 14, 2025 04:55
am17an force-pushed the cuda_mmvf_fastdiv branch 2 times, most recently from 7560a47 to ec9a51c on October 14, 2025 05:20
am17an requested a review from IMbackK on October 14, 2025 05:22
am17an force-pushed the cuda_mmvf_fastdiv branch from ec9a51c to e1afe75 on October 14, 2025 05:32
**am17an** (Collaborator, Author) commented Oct 14, 2025

Sorry, I am not able to fix the HIP builds.

**JohannesGaessler** (Collaborator) commented:

For now, keep the problematic code in mmvf.cu as-is for HIP, with a comment briefly explaining the problem.
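
Illustrative only, a sketch of that kind of guard rather than the actual mmvf.cu change; the macro usage and the exact reason given in the comment are assumptions:

```cpp
#include <cuda_fp16.h>

// Sketch: keep the previous scalar loop for HIP builds and use the packed
// path elsewhere, with a brief comment on why the fallback exists.
static __device__ __forceinline__ float dot_row(const half * x, const float * y, const int n) {
    float sum = 0.0f;
#if defined(GGML_USE_HIP)
    // HIP: the packed path does not currently build here; keep the
    // per-element loop until a follow-up PR addresses it.
    for (int i = 0; i < n; ++i) {
        sum += __half2float(x[i]) * y[i];
    }
#else
    // CUDA: process two elements per iteration (n assumed even, x 4-byte aligned).
    for (int i = 0; i < n; i += 2) {
        const half2 x2 = *(const half2 *) (x + i);
        sum += __low2float(x2)  * y[i + 0];
        sum += __high2float(x2) * y[i + 1];
    }
#endif
    return sum;
}
```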

am17an force-pushed the cuda_mmvf_fastdiv branch from 6dce339 to d6c71e9 on October 14, 2025 09:27
**IMbackK** (Collaborator) commented Oct 14, 2025

> Sorry, I am not able to fix the HIP builds.

I'll take a look.

**JohannesGaessler** (Collaborator) commented:

> I'll take a look.

That would be appreciated; otherwise I would have tried to fix it myself. My preferred approach would be to merge this PR as-is and fix the HIP issues in a follow-up PR. Is that fine with both of you?

**IMbackK** (Collaborator) commented Oct 14, 2025

Sure, yes.

JohannesGaessler merged commit 1ee9d0b into ggml-org:master on Oct 14, 2025 (65 of 70 checks passed).
ddh0 added a commit to ddh0/llama.cpp that referenced this pull request Oct 14, 2025
* cuda : remove legacy copy-op pointer indirection code (ggml-org#16485)
* CUDA: add fp kernel for larger batch size MoE (ggml-org#16512)
* CUDA: use fastdiv + ggml_cuda_mad for mmvf (ggml-org#16557)
* CUDA: enable FA for FP32 KV cache (ggml-org#16546)
* vulkan: Improve build time for MSVC (ggml-org#16545)
* vulkan: Support FA with K/V in F32 (ggml-org#16543)
* CUDA + openCL: fix bug in accessing rms_norm->src while doing fusion (ggml-org#16577)
* vulkan: Add ACC_TYPE_VEC2 implementation (ggml-org#16203)
* metal : avoid using Metal's gpuAddress property (ggml-org#16576)
yael-works pushed a commit to yael-works/llama.cpp that referenced this pull request Oct 15, 2025
* CUDA: use fastdiv + ggml_cuda_mad for mmvf

* use bf16 directly + fix formatting

* Add exception for HIP code
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Oct 15, 2025
* origin/master:
Add server-driven parameter defaults and syncing (ggml-org#16515)
metal: optimise `GGML_OP_SUM` (ggml-org#16559)
server : fix img token logs (ggml-org#16595)
llama-quant: add support for mmproj (ggml-org#16592)
CUDA: Changing the CUDA scheduling strategy to spin (ggml-org#16585)
server : fix mtmd checkpoints (ggml-org#16591)
metal : avoid using Metal's gpuAddress property (ggml-org#16576)
vulkan: Add ACC_TYPE_VEC2 implementation (ggml-org#16203)
CUDA + openCL: fix bug in accessing rms_norm->src while doing fusion (ggml-org#16577)
vulkan: Support FA with K/V in F32 (ggml-org#16543)
vulkan: Improve build time for MSVC (ggml-org#16545)
CUDA: enable FA for FP32 KV cache (ggml-org#16546)
CUDA: use fastdiv + ggml_cuda_mad for mmvf (ggml-org#16557)
CUDA: add fp kernel for larger batch size MoE (ggml-org#16512)
cuda : remove legacy copy-op pointer indirection code (ggml-org#16485)
server : dynamic token limit for prompt cache (ggml-org#16560)