From b252d1f34abcb7c1ec4dd19faaac6ba46bfc24b2 Mon Sep 17 00:00:00 2001
From: Anav Prasad
Date: Fri, 26 Sep 2025 14:10:11 +0000
Subject: [PATCH 1/2] Fix Nemotron Nano v2 9B not executing as CUDA Graph on
 NVIDIA GPUs

---
 ggml/src/ggml-cuda/cpy.cu       | 17 +++++++++++++----
 ggml/src/ggml-cuda/ggml-cuda.cu |  6 +++++-
 src/llama-model.cpp             |  4 +++-
 3 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 1b763a6289849..e3d8150ca1a53 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -318,9 +318,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
         graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
     }
-#else
+#elif !defined(GGML_CUDA_USE_GRAPHS)
     GGML_UNUSED(disable_indirection_for_this_node);
-#endif
+    // When CUDA graphs are enabled, we must use copy kernels instead of cudaMemcpyAsync even for same-type contiguous tensors. This is because cudaMemcpyAsync cannot be
+    // used with CUDA graph indirection (the mechanism that allows dynamic pointer updates in captured graphs). Using cudaMemcpyAsync would force disabling CUDA graphs entirely,
+    // causing a significant performance regression in models like Nemotron Nano v2.
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
 #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
@@ -331,7 +333,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         {
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+    } else
+#endif
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
@@ -399,9 +403,14 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 }
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+#if !defined(GGML_CUDA_USE_GRAPHS)
+    // Prioritize CUDA graph compatibility over direct memory copy optimization.
+    // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         return nullptr;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+    } else
+#endif
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 8c8647b147369..ea7c629e28abc 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2641,6 +2641,8 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
     const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
     const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
     const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
+    const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
+    const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
@@ -2669,7 +2671,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
             (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
             strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
             strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
-            strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
+            strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
+            strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
+            strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
             // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
             // by means of matching node names. See
             // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 2ae9abb4464fd..e4547d98c13c1 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -11744,6 +11744,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
         // TODO: skip computing output earlier for unused tokens
 
         y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+        cb(y, "mamba2_y_add_d", il);
         y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
 
         // grouped RMS norm
@@ -14698,6 +14699,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba {
         ggml_tensor * inpL;
 
         inpL = build_inp_embd(model.tok_embd);
+        ggml_build_forward_expand(gf, inpL);
 
         auto * inp = build_inp_mem_hybrid();
 
@@ -14729,7 +14731,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba {
 
             // add residual
             cur = ggml_add(ctx0, cur, inpSA);
-            cb(cur, "block_out", il);
+            cb(cur, "nemotron_h_block_out", il);
 
             // input for next layer
             inpL = cur;

From 207d2f085a88cdf8e43fb25a6b4dbd33632415b8 Mon Sep 17 00:00:00 2001
From: Anav Prasad
Date: Mon, 29 Sep 2025 09:33:16 +0000
Subject: [PATCH 2/2] fix to ensure test-backend-ops check passes

---
 ggml/src/ggml-cuda/cpy.cu | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index e3d8150ca1a53..746f43966b84c 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -318,11 +318,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
         graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
     }
-#elif !defined(GGML_CUDA_USE_GRAPHS)
+#else
     GGML_UNUSED(disable_indirection_for_this_node);
-    // When CUDA graphs are enabled, we must use copy kernels instead of cudaMemcpyAsync even for same-type contiguous tensors. This is because cudaMemcpyAsync cannot be
-    // used with CUDA graph indirection (the mechanism that allows dynamic pointer updates in captured graphs). Using cudaMemcpyAsync would force disabling CUDA graphs entirely,
-    // causing a significant performance regression in models like Nemotron Nano v2.
+#endif
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
 #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
@@ -331,11 +329,13 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         } else
 #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
         {
-            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+            if (src0->type == GGML_TYPE_F32) {
+                ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+            } else {
+                CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+            }
         }
-    } else
-#endif
-    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
@@ -403,14 +403,15 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 }
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
-#if !defined(GGML_CUDA_USE_GRAPHS)
-    // Prioritize CUDA graph compatibility over direct memory copy optimization.
-    // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
-        return nullptr;
-    } else
-#endif
-    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        // Prioritize CUDA graph compatibility over direct memory copy optimization.
+        // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
+        if (src0->type == GGML_TYPE_F32) {
+            return (void*) cpy_flt<cpy_1_flt<float, float>>;
+        } else {
+            return nullptr;
+        }
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
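
Reviewer note (not part of the patch): the change relies on the copy-indirection mechanism in cpy.cu, where the copy kernel reads its destination from a device-side pointer table (dest_ptrs_d, indexed by graph_cpynode_index) rather than taking the pointer as a baked-in launch argument. The minimal, self-contained CUDA sketch below illustrates why such a kernel remains compatible with CUDA graph capture while a captured cudaMemcpyAsync node keeps whatever destination it was recorded with unless its node parameters are explicitly patched. All names here (copy_indirect, dest_ptrs, node_idx) are illustrative stand-ins, not the actual ggml identifiers; error checking is omitted and the 3-argument cudaGraphInstantiate form assumes CUDA 12.

// Illustrative sketch only: indirect copy kernel replayed from a captured CUDA graph.
#include <cuda_runtime.h>
#include <cstdio>

// Destination is resolved at kernel run time through a device-side pointer table,
// so a captured graph can be retargeted by rewriting the table, without re-capture.
__global__ void copy_indirect(const char * src, char ** dest_ptrs, int node_idx, size_t nbytes) {
    size_t i = (size_t) blockIdx.x * blockDim.x + threadIdx.x;
    if (i < nbytes) {
        dest_ptrs[node_idx][i] = src[i];
    }
}

int main() {
    const size_t nbytes = 1 << 20;
    char *src, *dst_a, *dst_b, **dest_ptrs;
    cudaMalloc(&src,   nbytes);
    cudaMalloc(&dst_a, nbytes);
    cudaMalloc(&dst_b, nbytes);
    cudaMalloc(&dest_ptrs, sizeof(char *));   // one-entry pointer table

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Capture a single launch of the indirect copy into a CUDA graph.
    cudaGraph_t     graph;
    cudaGraphExec_t graph_exec;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    copy_indirect<<<(unsigned)((nbytes + 255) / 256), 256, 0, stream>>>(src, dest_ptrs, 0, nbytes);
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&graph_exec, graph, 0);   // CUDA 12 signature

    // Replay the same captured graph into two different destinations: only the
    // pointer table changes; the graph itself is never rebuilt. A cudaMemcpyAsync
    // node captured here would instead keep pointing at its original destination.
    char * dsts[2] = { dst_a, dst_b };
    for (char * dst : dsts) {
        cudaMemcpy(dest_ptrs, &dst, sizeof(char *), cudaMemcpyHostToDevice);
        cudaGraphLaunch(graph_exec, stream);
        cudaStreamSynchronize(stream);
    }
    printf("replayed the captured copy into two destinations\n");
    return 0;
}

This mirrors why the patch routes same-type contiguous F32 copies through ggml_cpy_flt_cuda<float, float> (which honors dest_ptrs_d / graph_cpynode_index) instead of cudaMemcpyAsync whenever CUDA graphs are in use.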