Skip to content

Conversation

@am17an
Copy link
Collaborator

@am17an am17an commented Nov 4, 2025

Possibly supersede #16813.

This PR adds support to run concurrent CUDA streams on single GPU setups.
At the moment this only targets the Q, K, V branch. I feel this is the "correct" approach in case the Q, K, V tensors are of different types/not in the same place in memory. The downside is that this approach doesn't come for free and there's some complexity involved, but I'm not an expert at the ggml graph and I feel it could be simplified.

Currently this is hidden by an env variable flag. To run you can use GGML_CUDA_GRAPH_OPT=1

TG Performance gain is more than the previous PR (1-9% gain depending on the model/GPU), probably because we parallelize MUL_MAT + NORM + ROPE rather than just MUL_MAT. At the moment we leave some performance on the table where we don't fuse operations in the parallel streams themselves (e.g. MUL_MAT + BIAS, RMS_NORM + MUL etc.), I couldn't find a simple enough way to enable fusion there.

Performance details:

5090:

Model Test t/s mxfp4 t/s 0fa5630c0 Speedup
gpt-oss 20B MXFP4 MoE pp512 10554.76 10572.61 1.00
gpt-oss 20B MXFP4 MoE tg128 359.67 382.52 1.06
llama 8B Q8_0 pp512 13948.06 13903.74 1.00
llama 8B Q8_0 tg128 166.44 174.08 1.05
qwen3 8B F16 pp512 14120.38 14170.43 1.00
qwen3 8B F16 tg128 97.21 100.39 1.03
qwen3moe 30B.A3B Q4_K_M pp512 7361.89 7367.90 1.00
qwen3moe 30B.A3B Q4_K_M tg128 306.84 334.19 1.09

4090:

Model Test t/s mxfp4 t/s 0fa5630c0 Speedup
gpt-oss 20B MXFP4 MoE pp512 9551.59 9549.55 1.00
gpt-oss 20B MXFP4 MoE tg128 263.77 271.84 1.03
llama 8B Q8_0 pp512 12111.67 12105.00 1.00
llama 8B Q8_0 tg128 105.57 107.57 1.02
qwen3 8B F16 pp512 10437.09 10414.37 1.00
qwen3 8B F16 tg128 58.99 59.79 1.01
qwen3moe 30B.A3B Q4_K_M pp512 7246.49 7246.52 1.00
qwen3moe 30B.A3B Q4_K_M tg128 248.75 268.12 1.08

And just for comparison, this is without fusing ops inside a stream

5090:

Model Test t/s mxfp4 t/s e50fbde39 Speedup
gpt-oss 20B MXFP4 MoE pp512 10570.11 10548.35 1.00
gpt-oss 20B MXFP4 MoE tg128 359.37 373.34 1.04
llama 8B Q8_0 pp512 13905.82 13913.88 1.00
llama 8B Q8_0 tg128 166.30 174.15 1.05
qwen3 8B F16 pp512 14108.08 14116.40 1.00
qwen3 8B F16 tg128 97.20 99.69 1.03
qwen3moe 30B.A3B Q4_K_M pp512 7358.84 7354.03 1.00
qwen3moe 30B.A3B Q4_K_M tg128 306.36 325.89 1.06

4090:

Model Test t/s mxfp4 t/s e50fbde39 Speedup
gpt-oss 20B MXFP4 MoE pp512 9547.67 9545.65 1.00
gpt-oss 20B MXFP4 MoE tg128 263.74 270.22 1.02
llama 8B Q8_0 pp512 12071.11 12041.15 1.00
llama 8B Q8_0 tg128 105.57 107.61 1.02
qwen3 8B F16 pp512 10409.54 10385.86 1.00
qwen3 8B F16 tg128 59.00 59.64 1.01
qwen3moe 30B.A3B Q4_K_M pp512 7254.97 7240.68 1.00
qwen3moe 30B.A3B Q4_K_M tg128 248.53 263.10 1.06

TODO:

  • Enable fusion within a stream
  • Add tests?

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Nov 4, 2025
@JohannesGaessler
Copy link
Collaborator

Sorry, I wanted to tell you this but I forgot: a long time ago I tried something similar, see #4719 . There the performance did not improve, I think the reason was the lack of CUDA graphs to reduce the overhead.

@am17an
Copy link
Collaborator Author

am17an commented Nov 4, 2025

Yeah, I think CUDA graphs are essential for this to work (hence this PR only looks at batch_size=1)

@IMbackK
Copy link
Collaborator

IMbackK commented Nov 6, 2025

Minimal changes to make this work on hip:

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index d3153f430..28bafb84e 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -25,6 +25,7 @@
 #include <cfloat>
 #include <cstdio>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #if defined(GGML_USE_HIP)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f69b99b2a..f0df4a9a9 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3514,7 +3514,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
                         for (int i = 1; i <= concurrent_event->n_streams; ++i) {
                             cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
-                            CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
+                            cudaStreamWaitEvent(stream, concurrent_event->fork_event);
                         }
 
                         is_concurrent_event_active = true;
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 890c10364..b7d6edf7f 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -105,7 +105,7 @@
 #define cudaStreamNonBlocking hipStreamNonBlocking
 #define cudaStreamPerThread hipStreamPerThread
 #define cudaStreamSynchronize hipStreamSynchronize
-#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
+#define cudaStreamWaitEvent hipStreamWaitEvent
 #define cudaGraphExec_t hipGraphExec_t
 #define cudaGraphNode_t hipGraphNode_t
 #define cudaKernelNodeParams hipKernelNodeParams

If used for real, cudaStreamWaitEvent error needs to handled of course
hipEventCreateWithFlags is also nodiscard which needs to be handled

with -DGGML_HIP_GRAPHS=On and GGML_CUDA_GRAPH_OPT=1 in env this seams performance neutral on mi100:

Model Test t/s b6955 t/s pr Speedup
llama 13B Q8_0 tg128 29.07 29.06 1.00

@am17an
Copy link
Collaborator Author

am17an commented Nov 6, 2025

The almost exact same numbers make me think that this change is not launching the streams. I would expect a shift in performance either for the worse or the better.

@IMbackK
Copy link
Collaborator

IMbackK commented Nov 6, 2025

yeah ill run a trace on it later.

@am17an am17an marked this pull request as ready for review November 8, 2025 06:03
@am17an am17an requested a review from slaren as a code owner November 8, 2025 06:03
@am17an
Copy link
Collaborator Author

am17an commented Nov 9, 2025

It turns out there was a bug in finding the correct node to fork on because of RMS fusion, so what was happening was the Q mul-mat was still running on stream 0, and only after it would finish would we launch the rest of the streams.

Secondly, this is still without fusing nodes within a stream, I could make fusion work by re-ordering back the graph to it's original shape, I think the best solution here would be to signal to the ggml-graph that these nodes are required later and their memory should not be re-used, is there a way to do this?

Also as long as the Q, K, V stuff per layer lives on the same device, this should also work for the multi-GPU case too, but that is not addressed in this PR

On a 5090:

Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes

model size params backend ngl fa test t/s
qwen3moe 30B.A3B Q4_K - Medium 17.28 GiB 30.53 B CUDA 99 1 tg32 285.58 ± 2.98
qwen3moe 30B.A3B Q4_K - Medium 17.28 GiB 30.53 B CUDA 99 1 tg64 266.76 ± 0.16
qwen3moe 30B.A3B Q4_K - Medium 17.28 GiB 30.53 B CUDA 99 1 tg128 264.19 ± 0.09
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B CUDA 99 1 tg32 382.39 ± 0.16
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B CUDA 99 1 tg64 370.78 ± 0.11
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B CUDA 99 1 tg128 367.91 ± 0.18
qwen3 8B F16 15.26 GiB 8.19 B CUDA 99 1 tg32 85.41 ± 0.01
qwen3 8B F16 15.26 GiB 8.19 B CUDA 99 1 tg64 83.95 ± 0.01
qwen3 8B F16 15.26 GiB 8.19 B CUDA 99 1 tg128 83.60 ± 0.05

After:

model size params backend ngl fa test t/s
qwen3moe 30B.A3B Q4_K - Medium 17.28 GiB 30.53 B CUDA 99 1 tg32 304.13 ± 0.36
qwen3moe 30B.A3B Q4_K - Medium 17.28 GiB 30.53 B CUDA 99 1 tg64 281.68 ± 0.06
qwen3moe 30B.A3B Q4_K - Medium 17.28 GiB 30.53 B CUDA 99 1 tg128 276.83 ± 0.28
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B CUDA 99 1 tg32 391.99 ± 0.45
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B CUDA 99 1 tg64 380.56 ± 0.14
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B CUDA 99 1 tg128 375.88 ± 0.05
qwen3 8B F16 15.26 GiB 8.19 B CUDA 99 1 tg32 86.32 ± 0.03
qwen3 8B F16 15.26 GiB 8.19 B CUDA 99 1 tg64 84.86 ± 0.01
qwen3 8B F16 15.26 GiB 8.19 B CUDA 99 1 tg128 84.44 ± 0.01
Verified with nsys profile image

@am17an am17an force-pushed the fused-qkv-stream branch 3 times, most recently from 12d5f82 to d3a8d93 Compare November 10, 2025 05:17
@am17an
Copy link
Collaborator Author

am17an commented Nov 10, 2025

@ggerganov would you mind testing this on your DGX spark? I want to see if the low memory bandwidth GPUs benefit from this change

@jeffbolznv
Copy link
Collaborator

Secondly, this is still without fusing nodes within a stream, I could make fusion work by re-ordering back the graph to it's original shape, I think the best solution here would be to signal to the ggml-graph that these nodes are required later and their memory should not be re-used, is there a way to do this?

I'm not really clear on what the problem is here you're trying to solve. If the order is: MUL_MAT+ADD+MUL_MAT+ADD+MUL_MAT+ADD, then you have the nodes conveniently consecutive (for fusion), the intermediate MUL_MAT outputs aren't needed and the ADDs will all have different outputs. This is the order ggml-vulkan will use and it gets both fusion and concurrency.

@am17an
Copy link
Collaborator Author

am17an commented Nov 10, 2025

My problem is that the buffer gets reused in this case. The graph assumes serial execution, and thinks the first mul-mats buffer is no longer required. (Assume the no fusion case for now)

@jeffbolznv
Copy link
Collaborator

If you're not doing fusion, then you'd want graph_optimize to reorder these to MUL_MAT+MUL_MAT+MUL_MAT+ADD+ADD+ADD. Then the MUL_MAT results will stay live until the ADDs.

@am17an
Copy link
Collaborator Author

am17an commented Nov 10, 2025

Yeah that's what the current re-order does. But that doesn't allow for fusion. I don't want these two things to be intertwined. Ideally want something that just lets me extend the lifetime for a particular output till a certain node

@jeffbolznv
Copy link
Collaborator

IMO they are fundamentally intertwined. The code that detects fusions looks for specific sequences of operations, and graph_optimize should generate or preserve those sequences. If the backend supports fusing something, then graph_optimize should make them consecutive both to make the fusion logic simpler and to shorten the lifetime of transient allocations.

@am17an
Copy link
Collaborator Author

am17an commented Nov 10, 2025

I think it will be good to isolate these two behaviours. If you see the graph above it can launch a concurrent graph from the mul-mat till set rows. We don't have fusion for that entire sequence, and reasoning about which output stays alive would involve inspecting the graph in any case.

Secondly fusion is a common source of bugs in the cuda backend, I don't want to add another layer of complexity on top of it.

@ORippler

This comment was marked as outdated.

@am17an
Copy link
Collaborator Author

am17an commented Nov 11, 2025

Did you pass the env flag?

@ORippler
Copy link
Contributor

ORippler commented Nov 11, 2025

Did you pass the env flag?

Why of course not 🫠

image

Updated numbers

+ ./scripts/compare-llama-bench.py -b b8595b16e69e3029e06be3b8f6635f9812b2bc3f -c d3a8d93a04d4cb8cc8daa62d12e6763a2f4c0221 --tool llama-bench -i llama-bench.sqlite
Model Test t/s b8595b1 t/s d3a8d93 Speedup
gpt-oss 120B MXFP4 MoE tg32 43.42 43.86 1.01
gpt-oss 120B MXFP4 MoE tg64 49.96 50.20 1.00
gpt-oss 120B MXFP4 MoE tg128 50.95 51.13 1.00
gpt-oss 120B MXFP4 MoE tg256 50.86 51.10 1.00
gpt-oss 120B MXFP4 MoE tg512 50.44 50.86 1.01
gpt-oss 20B MXFP4 MoE tg32 80.20 81.24 1.01
gpt-oss 20B MXFP4 MoE tg64 80.81 81.45 1.01
gpt-oss 20B MXFP4 MoE tg128 80.97 81.52 1.01
gpt-oss 20B MXFP4 MoE tg256 80.98 81.55 1.01
gpt-oss 20B MXFP4 MoE tg512 80.57 81.18 1.01
llama 3B Q4_0 tg32 103.85 105.21 1.01
llama 3B Q4_0 tg64 104.79 106.39 1.02
llama 3B Q4_0 tg128 104.77 106.43 1.02
llama 3B Q4_0 tg256 104.79 106.50 1.02
llama 3B Q4_0 tg512 103.83 105.49 1.02
qwen3moe 30B.A3B Q3_K_S tg32 94.83 96.21 1.01
qwen3moe 30B.A3B Q3_K_S tg64 95.72 96.97 1.01
qwen3moe 30B.A3B Q3_K_S tg128 96.05 97.27 1.01
qwen3moe 30B.A3B Q3_K_S tg256 96.12 97.40 1.01
qwen3moe 30B.A3B Q3_K_S tg512 95.19 96.42 1.01

@ORippler
Copy link
Contributor

My problem is that the buffer gets reused in this case.

If my current understanding of ggml is correct, we should be able to get the same behavior (fusion + concurrency) on both vulkan + cuda, as Q should go out of life after flash-attention (and K + V go out of life after being inserted into the KV-cache). Have we root-caused this?

@IMbackK
Copy link
Collaborator

IMbackK commented Nov 12, 2025

@am17an please cherry-pick in bc93319 to fix the hip build. I can confirm that multiple streams are successfully launching kernels at the same time on hip, the mi100 performance is indeed just unchanged. On the other hand rx6800xt sees a small perf decrease:

Model Test t/s master t/s pr + bc93319 Speedup
llama 7B Q4_0 tg128 93.65 90.75 0.97

But that dosent matter ofc as long as the pr remains hidden behind an envvar.

@JohannesGaessler
Copy link
Collaborator

Sorry for the long radio silence, what is the current status of this PR? Is there still something to be done or do you want to move towards merging it?

@jeffbolznv
Copy link
Collaborator

If I understand correctly I would have to do an fork and join at every fusion (i.e. MUL_MAT + ADD, and then another at RMS_NORM+MUL) to keep the outputs alive. I don't agree with this approach because the graph-optimize function needs to know what fusion operations are available and it increases the synchronization (join) points, and this is to simply keep the outputs alive and not for any other reason.

I don't quite follow what you're saying here, but maybe this is what you mean. Let's say the original graph is:

attn-norm, QMul, QAdd, KMul, KAdd, VMul, VAdd, QNorm, VNorm, QRope, KRope, attn

With this order, you could fuse each of QMul+QAdd, KMul+KAdd, and VMul+VAdd, and run those three fusions on separate streams, and join before the Norms. I don't think there would be a memory allocation/aliasing problem. This is basically what ggml-vulkan will do today.

However, if you want to extend the duration of the streams so that all Q work is on a stream, all K work is on a stream, all V work is on a stream, and only do a single join right before the final attn, then yes there would be a problem with memory reuse (e.g. QAdd's destination might be reused for VNorm's destination, for example). This really has nothing to do with fusion, this is an issue with concurrency not being represented in the graph and would occur regardless of fusion.

@am17an
Copy link
Collaborator Author

am17an commented Nov 13, 2025

However, if you want to extend the duration of the streams so that all Q work is on a stream, all K work is on a stream, all V work is on a stream, and only do a single join right before the final attn, then yes there would be a problem with memory reuse (e.g. QAdd's destination might be reused for VNorm's destination, for example) This really has nothing to do with fusion, this is an issue with concurrency not being represented in the graph and would occur regardless of fusion.

Yes, that's exactly what I want to do with this PR at least as we only have to orchestrate only one event with cudaEventWait

Sorry for the long radio silence, what is the current status of this PR? Is there still something to be done or do you want to move towards merging it?

Let me clean up the code a bit and we can look at merging. Since this is behind an env flag for now it should not cause problems. When it's better tested we can look at enabling by default

@am17an

This comment was marked as outdated.

@ggerganov
Copy link
Member

@am17an Should I compare master vs this branch + GGML_CUDA_GRAPH_OPT=1?

@am17an
Copy link
Collaborator Author

am17an commented Nov 13, 2025

Yes with the latest change it should be different from what @ORippler tested earlier

@ggerganov

This comment has been minimized.

@am17an

This comment has been minimized.

@am17an
Copy link
Collaborator Author

am17an commented Nov 13, 2025

I'm sorry, I don't know how the PPL reported ok, there was a bug in launching streams. I fixed it now and results are more believable. For this run I manually tested llama-cli on a few prompts and they seem ok.

5090:

Model Test t/s mxfp4 t/s 0fa5630c0 Speedup
gpt-oss 20B MXFP4 MoE pp512 10554.76 10572.61 1.00
gpt-oss 20B MXFP4 MoE tg128 359.67 382.52 1.06
llama 8B Q8_0 pp512 13948.06 13903.74 1.00
llama 8B Q8_0 tg128 166.44 174.08 1.05
qwen3 8B F16 pp512 14120.38 14170.43 1.00
qwen3 8B F16 tg128 97.21 100.39 1.03
qwen3moe 30B.A3B Q4_K_M pp512 7361.89 7367.90 1.00
qwen3moe 30B.A3B Q4_K_M tg128 306.84 334.19 1.09

4090:

Model Test t/s mxfp4 t/s 0fa5630c0 Speedup
gpt-oss 20B MXFP4 MoE pp512 9551.59 9549.55 1.00
gpt-oss 20B MXFP4 MoE tg128 263.77 271.84 1.03
llama 8B Q8_0 pp512 12111.67 12105.00 1.00
llama 8B Q8_0 tg128 105.57 107.57 1.02
qwen3 8B F16 pp512 10437.09 10414.37 1.00
qwen3 8B F16 tg128 58.99 59.79 1.01
qwen3moe 30B.A3B Q4_K_M pp512 7246.49 7246.52 1.00
qwen3moe 30B.A3B Q4_K_M tg128 248.75 268.12 1.08

And just for comparison, this is without fusing ops inside a stream

5090:

Model Test t/s mxfp4 t/s e50fbde39 Speedup
gpt-oss 20B MXFP4 MoE pp512 10570.11 10548.35 1.00
gpt-oss 20B MXFP4 MoE tg128 359.37 373.34 1.04
llama 8B Q8_0 pp512 13905.82 13913.88 1.00
llama 8B Q8_0 tg128 166.30 174.15 1.05
qwen3 8B F16 pp512 14108.08 14116.40 1.00
qwen3 8B F16 tg128 97.20 99.69 1.03
qwen3moe 30B.A3B Q4_K_M pp512 7358.84 7354.03 1.00
qwen3moe 30B.A3B Q4_K_M tg128 306.36 325.89 1.06

4090:

Model Test t/s mxfp4 t/s e50fbde39 Speedup
gpt-oss 20B MXFP4 MoE pp512 9547.67 9545.65 1.00
gpt-oss 20B MXFP4 MoE tg128 263.74 270.22 1.02
llama 8B Q8_0 pp512 12071.11 12041.15 1.00
llama 8B Q8_0 tg128 105.57 107.61 1.02
qwen3 8B F16 pp512 10409.54 10385.86 1.00
qwen3 8B F16 tg128 59.00 59.64 1.01
qwen3moe 30B.A3B Q4_K_M pp512 7254.97 7240.68 1.00
qwen3moe 30B.A3B Q4_K_M tg128 248.53 263.10 1.06

@ggerganov

This comment has been minimized.

@am17an
Copy link
Collaborator Author

am17an commented Nov 13, 2025

Not sure why PP would be affected in your case, perhaps I need to rebase?

@ggerganov
Copy link
Member

I think for now let's ignore the DGX Spark numbers that I posted. I am observing some large variance between runs (even on master) and I am not sure atm what is the root cause.

Copy link
Contributor

@ORippler ORippler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the effort that has gone in this PR so far! Left some comments, will try to finish tomorrow

Comment on lines +3675 to +3737
static bool enable_graph_optimization = [] {
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
return env != nullptr && atoi(env) == 1;
}();

if (!enable_graph_optimization) {
return;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
static bool enable_graph_optimization = [] {
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
return env != nullptr && atoi(env) == 1;
}();
if (!enable_graph_optimization) {
return;
}
static bool enable_cgraph_optimization = [] {
const char * env = getenv("GGML_CUDA_CGRAPH_OPT");
return env != nullptr && atoi(env) == 1;
}();
if (!enable_cgraph_optimization) {
return;
}

This makes it clearer we are not talking about optimizing CUDA Graphs, but rather ggml_cgraph objects

Copy link
Collaborator Author

@am17an am17an Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the correct name because the function unfortunately named ggml_backend_cuda_graph_optimize, which has nothing to do with CUDA Graphs. I don't know how to reconcile these two at the moment

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would tend to disagree, and would propose renaming ggml_backend_cuda_graph_optimize to ggml_backend_cuda_cgraph_optimize instead.

A backend's interface is defined fully in ggml-backend-impl.h, and that is what the cuda backend needs to adhere to (and does). Since the cuda backend has to handle ggml_cgraph and CUDA Graphs internally, I personally think moving forward it makes sense to explicitly encode what we are handling in which function, as this saves the look to the function declaration. Consequentially, I would also like to rename ggml_backend_cuda_graph_compute to ggml_backend_cuda_cgraph_compute. This makes the above suggestion not strictly tied to this PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about for now we keep this, and change it later when your points are addressed (possibly by yourself) in the larger ggml scheme of things?

}
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
const ggml_tensor * node = cgraph->nodes[node_idx]->src[src_idx];
//TODO: check why nrows > 1 fails, probably related to CUDA graphs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious: how does this failure manifest itself?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

garbled outputs, I'm not really sure what happens

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On my side I could not repro garbled outputs for
GGML_CUDA_GRAPH_OPT=1 ./build-x64-linux-gcc-debug/bin/llama-cli -m /mnt/share/gguf/gpt-oss-20b-mxfp4.gguf -p "What is the Capital of Sweden? Please be very elaborative and jokey in your answer."
on commit 6562b7797 when reducing the test to if (node && !is_empty(node)).

Please specify a repro so we/I can try to root-cause (I am afraid of another heuristic that we do not fully understand, akin to the batch-size heuristic for disabling CUDA Graph). If it does not repro on your side either, I would say we enable this for pre-fill phase also (if it gives perf improvements)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The repro is with qwen3-30b. I also can't repro it with gpt-oss. However, perf improvements won't come in prefill unless we can launch the streams concurrently which happens solely because of CUDA graphs at the moment. In the original graph (which we now come back to enable fusion) has the DFS form (Q things first, K next and V last).

If you can root-cause the repro, we can disable fusion within a stream in prefill phase (i.e. not revert to the orig graph), that would interleave execution. From what I notice it can speed up things about 4-5% too, but at the cost of higher peak memory usage

for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
const ggml_tensor * node = cgraph->nodes[node_idx]->src[src_idx];
//TODO: check why nrows > 1 fails, probably related to CUDA graphs
if (node && !is_empty(node) && ggml_nrows(node) <= 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (node && !is_empty(node) && ggml_nrows(node) <= 1) {
if (node && !is_empty(node) && ggml_nrows(node) == 1) {

Do we have nodes with 0 rows? Thought we always have at least 1 element so we can multiply & divide safely

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No it doesn't but this check is more like ggml_nrows <= X where X we need probably can configure after it works for nrows > 1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above

Copy link
Contributor

@ORippler ORippler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First of all, thanks again for the tremendous effort that went into this PR, which gives a lot of perf! While the main logic is solid for QKV projections (and awesome to see we managed to achieve concurrency + fusion for the CUDA backend as well), I still have some concerns revolving around maintainability. Some recommendations/what I think we should address before merging this:

  • Clean up the code (see detailed comments)
  • Add some docs that state the pattern ggml_backend_cuda_graph_optimize matches for to enable concurrency (fan_out == fan_in == 3 + no other work happens between the root/join node that does not belong to one of the 3 branches). Also I feel matching for a different pattern will be difficult with the current implementation, but that's a problem to tackle if we see other patterns where significant perf can be gained.
  • Resolve nrows repro

While I would also love to see some tests, I would not consider this blocking given the effort that has already went into this PR as is + the fact that it is hidden behind a flag for the moment (though we could also consider making it the default to find potential bugs quicker)

Comment on lines +3675 to +3737
static bool enable_graph_optimization = [] {
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
return env != nullptr && atoi(env) == 1;
}();

if (!enable_graph_optimization) {
return;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would tend to disagree, and would propose renaming ggml_backend_cuda_graph_optimize to ggml_backend_cuda_cgraph_optimize instead.

A backend's interface is defined fully in ggml-backend-impl.h, and that is what the cuda backend needs to adhere to (and does). Since the cuda backend has to handle ggml_cgraph and CUDA Graphs internally, I personally think moving forward it makes sense to explicitly encode what we are handling in which function, as this saves the look to the function declaration. Consequentially, I would also like to rename ggml_backend_cuda_graph_compute to ggml_backend_cuda_cgraph_compute. This makes the above suggestion not strictly tied to this PR.

}
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
const ggml_tensor * node = cgraph->nodes[node_idx]->src[src_idx];
//TODO: check why nrows > 1 fails, probably related to CUDA graphs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On my side I could not repro garbled outputs for
GGML_CUDA_GRAPH_OPT=1 ./build-x64-linux-gcc-debug/bin/llama-cli -m /mnt/share/gguf/gpt-oss-20b-mxfp4.gguf -p "What is the Capital of Sweden? Please be very elaborative and jokey in your answer."
on commit 6562b7797 when reducing the test to if (node && !is_empty(node)).

Please specify a repro so we/I can try to root-cause (I am afraid of another heuristic that we do not fully understand, akin to the batch-size heuristic for disabling CUDA Graph). If it does not repro on your side either, I would say we enable this for pre-fill phase also (if it gives perf improvements)

for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
const ggml_tensor * node = cgraph->nodes[node_idx]->src[src_idx];
//TODO: check why nrows > 1 fails, probably related to CUDA graphs
if (node && !is_empty(node) && ggml_nrows(node) <= 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above

Comment on lines +524 to +525
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device,
[[maybe_unused]] int stream_no) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I'm personally not a fan of adding unused function arguments (it's not only maybe_unused, but simply not used at all)

Copy link
Collaborator Author

@am17an am17an Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is actually used, but not in this function. new_pool_for_device is static inside ggml_backend_cuda_context which leads to create a new std::unique_ptr for a different stream_no

GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
}
}
prev_i = i;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
prev_i = i;
prev_i = i;
GGML_ASSERT((cuda_ctx->curr_stream_no == 0 && !is_concurrent_event_active) ||
(cuda_ctx->curr_stream_no > 0 && is_concurrent_event_active));

Let's also assert only sequential nodes that do not belong to concurrent events are placed onto the "main" stream if we reserve the "main" stream for this use

Copy link
Collaborator Author

@am17an am17an Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is asserted here, I feel it also doesn't tie stream 0 necessarily to the default stream, it is more like - if you are in a concurrent event, you need to find a stream for all your nodes in between

https://github.com/am17an/llama.cpp/blob/fused-qkv-stream/ggml/src/ggml-cuda/ggml-cuda.cu#L3242

@am17an
Copy link
Collaborator Author

am17an commented Nov 14, 2025

Clean up the code (see detailed comments)

✔️

Add some docs that state the pattern ggml_backend_cuda_graph_optimize matches for to enable concurrency (fan_out == fan_in == 3 + no other work happens between the root/join node that does not belong to one of the 3 branches). Also I feel matching for a different pattern will be difficult with the current implementation, but that's a problem to tackle if we see other patterns where significant perf can be gained.

Added some comments. The other pattern which is quite common is ffn_up + gate. That is easily detected by this pattern (though with fan_out = 2). However we already have a pretty solid fusion for that and it would probably not help here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants