Conversation

anavp-nvidia
Contributor

In PR 16328, CUDA Graph support for the Nemotron Nano v2 (NemotronH) model was enabled by replacing the use of cudaMemcpyAsync with an existing CUDA copy kernel for copies of contiguous tensors. However, that kernel is optimized for non-contiguous tensors.

This PR introduces a CUDA copy kernel for contiguous GGML tensors, which provides a performance improvement of ~3.7% for Nemotron Nano v2 on RTX 5090.

Results (RTX 5090):

Weights: bartowski/nvidia_NVIDIA-Nemotron-Nano-9B-v2-GGUF
Quantization: Q4_K_M

Performance before:

  Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| nemotron_h 9B Q4_K - Medium    |   6.07 GiB |     8.89 B | CUDA       |  99 |  1 |    tg200 @ d100 |        165.50 ± 0.19 |
| nemotron_h 9B Q4_K - Medium    |   6.07 GiB |     8.89 B | CUDA       |  99 |  1 | pp100+tg200 @ d100 |        174.14 ± 2.02 |

Performance after:

  Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| nemotron_h 9B Q4_K - Medium    |   6.07 GiB |     8.89 B | CUDA       |  99 |  1 |    tg200 @ d100 |        171.66 ± 0.08 |
| nemotron_h 9B Q4_K - Medium    |   6.07 GiB |     8.89 B | CUDA       |  99 |  1 | pp100+tg200 @ d100 |        180.92 ± 0.38 |


const int elements_per_thread = 4;
const int threads_needed = (ne_elements + elements_per_thread - 1) / elements_per_thread;
const int num_blocks = max(1, min(65535, (threads_needed + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE));
Collaborator

I think this won't work if ne_elements is larger than a certain amount (I think 16726016). We can add an assert here, or check whether the num_blocks limit can be raised above 65535.
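
Something like this could serve as the guard (a rough sketch; GGML_ASSERT is ggml's existing assert macro, and the bound follows from the launch parameters above):

// With elements_per_thread = 4 and at most 65535 blocks of CUDA_CPY_BLOCK_SIZE
// threads, this launch can cover at most 65535*CUDA_CPY_BLOCK_SIZE*4 elements.
GGML_ASSERT(ne_elements <= 65535LL*CUDA_CPY_BLOCK_SIZE*elements_per_thread);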

const int64_t remaining = ne_elements - base_idx;

if (remaining >= elements_per_thread) {
    if (base_idx % 4 == 0) {
Collaborator

From what I understand, base_idx is always a multiple of elements_per_thread, which is 4, so this check is not necessary?


T * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index] : cdst_direct;

const int elements_per_thread = 4;
Collaborator

This is declared both where the kernel is launched and here; perhaps make it a constant like CUDA_CPY_BLOCK_SIZE.
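
For example (the name is only a suggestion), next to where CUDA_CPY_BLOCK_SIZE is defined:

// Hypothetical constant shared by the kernel and its launch code, so the two cannot drift apart.
#define CUDA_CPY_ELEMENTS_PER_THREAD 4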

        }
    }
} else {
    for (int j = 0; j < remaining; ++j) {
Collaborator

Since remaining is < 4 here, we can do an unroll like the one below, but I doubt it will have any effect on performance:

        #pragma unroll
        for (int j = 0; j < 4; ++j) {
            size_t i = base_idx + (size_t)j;
            if (i < ne_elements) cdst[i] = cx[i];
        }

Collaborator

@JohannesGaessler JohannesGaessler left a comment

As it is, this kernel does not properly check for memory alignment. When you copy a float4 this is done as a single, 16-byte transfer. However, if the pointer is not aligned to 16 bytes this will result in a crash.

I would suggest you look at

// Maximum number of bytes that can be copied in a single instruction.
static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {
#ifdef GGML_USE_HIP
    return 16;
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
    return 16;
#else
    return 8;
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // GGML_USE_HIP
}

and

// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
template <int nbytes, int alignment = 0>
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
    if constexpr (alignment != 0) {
        static_assert(nbytes % alignment == 0, "bad alignment");
    }
    constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;

#pragma unroll
    for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
        if constexpr (nb_per_cpy == 1) {
            ((char *) dst)[i] = ((const char *) src)[i];
        } else if constexpr (nb_per_cpy == 2) {
            ((short *) dst)[i] = ((const short *) src)[i];
        } else if constexpr (nb_per_cpy == 4) {
            ((int *) dst)[i] = ((const int *) src)[i];
        } else if constexpr (nb_per_cpy == 8) {
            ((int2 *) dst)[i] = ((const int2 *) src)[i];
        } else if constexpr (nb_per_cpy == 16) {
            ((int4 *) dst)[i] = ((const int4 *) src)[i];
        } else {
            static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
        }
    }
}
Using those functions, I would do the implementation like this:

  1. Add a template parameter for the alignment of the copy.
  2. At runtime, check the alignment of the tensors and run the template specialization with the maximum memory alignment supported by the hardware (I think this will need a utility function to fetch this property in host code).
  3. In the kernel use ggml_cuda_memcpy_1 exactly once per thread to get optimal memory alignment. Avoid using the function with an alignment smaller than the copy size since this will result in suboptimal memory bandwidth.
  4. (Maybe change the kernel to use char * to make it more generally applicable.)
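
Putting steps 1-3 together, a minimal sketch could look like the following (names are hypothetical, the CUDA graph pointer indirection is omitted, and it assumes the helpers quoted above are in scope; a real implementation would cap the copy width with ggml_cuda_get_max_cpy_bytes() instead of hard-coding 16):

template <int nbytes_per_thread>
static __global__ void cpy_contiguous_bytes(const char * __restrict__ cx, char * __restrict__ cdst, const int64_t nbytes) {
    const int64_t offset = (blockIdx.x*(int64_t) blockDim.x + threadIdx.x) * nbytes_per_thread;

    if (offset + nbytes_per_thread <= nbytes) {
        // Common case: exactly one aligned copy of nbytes_per_thread bytes per thread (step 3).
        ggml_cuda_memcpy_1<nbytes_per_thread>(cdst + offset, cx + offset);
    } else {
        // Last, partial chunk: copy the remaining tail byte by byte.
        for (int64_t i = offset; i < nbytes; ++i) {
            cdst[i] = cx[i];
        }
    }
}

template <int nbytes_per_thread>
static void launch_cpy_contiguous_bytes(const char * cx, char * cdst, const int64_t nbytes, cudaStream_t stream) {
    const int64_t nchunks    = (nbytes + nbytes_per_thread - 1) / nbytes_per_thread;
    const int64_t num_blocks = (nchunks + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_contiguous_bytes<nbytes_per_thread><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>(cx, cdst, nbytes);
}

static void cpy_contiguous_bytes_cuda(const char * cx, char * cdst, const int64_t nbytes, cudaStream_t stream) {
    // Step 2: check the alignment of both pointers at runtime and dispatch to the
    // template specialization (step 1) with the widest copy the data allows.
    const uint64_t addrs = (uint64_t) (uintptr_t) cx | (uint64_t) (uintptr_t) cdst;
    if (addrs % 16 == 0) {
        launch_cpy_contiguous_bytes<16>(cx, cdst, nbytes, stream);
    } else if (addrs % 8 == 0) {
        launch_cpy_contiguous_bytes<8>(cx, cdst, nbytes, stream);
    } else if (addrs % 4 == 0) {
        launch_cpy_contiguous_bytes<4>(cx, cdst, nbytes, stream);
    } else {
        launch_cpy_contiguous_bytes<1>(cx, cdst, nbytes, stream);
    }
}

The char * interface corresponds to step 4, so the same kernel could serve any contiguous tensor type regardless of element size.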

@slaren
Member

slaren commented Oct 9, 2025

Is this kernel faster than the previous cudaMemcpyAsync? If the only goal here is to have something to return from ggml_cuda_cpy_fn, this might not be necessary. This function was used in the first implementation of the CUDA graphs support, but I don't think it is used now, and it should be possible to remove it entirely.

@CISC
Collaborator

CISC commented Oct 9, 2025

> Is this kernel faster than the previous cudaMemcpyAsync? If the only goal here is to have something to return from ggml_cuda_cpy_fn, this might not be necessary. This function was used in the first implementation of the CUDA graphs support, but I don't think it is used now, and it should be possible to remove it entirely.

It's still used:

void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
if (!ptr) {
    use_cuda_graph = false;
#ifndef NDEBUG
    GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
}

@slaren
Member

slaren commented Oct 9, 2025

> It's still used:

The return value is completely ignored. Even if it wasn't, the reason it was necessary in the first place is that we used ggml_cpy to update the KV cache, but we no longer do that; we use ggml_set_rows now.

@CISC
Collaborator

CISC commented Oct 9, 2025

> It's still used:
>
> The return value is completely ignored. Even if it wasn't, the reason it was necessary in the first place is that we used ggml_cpy to update the KV cache, but we no longer do that; we use ggml_set_rows now.

Right, you meant removing the function call...

@JohannesGaessler
Collaborator

If the in-code comment regarding CUDA graph support is outdated then my opinion is that we should simply use cudaMemcpyAsync.

@anavp-nvidia
Contributor Author

While the KV cache update may no longer use ggml_cpy, the Mamba2 layers in the Nemotron Nano v2 model still make use of this operation.

Specifically, it is invoked within the Mamba2 layer around the ssm_scan operation, as shown in the screenshot below.
[Screenshot 2025-10-09 163322: ggml_cpy invoked around the ssm_scan operation in the Mamba2 layer]

My understanding is that when ggml_cpy is called with contiguous tensors and involves pointer indirection for CUDA graphs, we need to perform the copy through a CUDA kernel rather than cudaMemcpyAsync to avoid disabling CUDA Graph.

Please let me know if there's a more suitable or preferred approach to handle this case.

@slaren
Member

slaren commented Oct 9, 2025

The pointers in mamba should be the same on every token, so I don't think the indirection is necessary.

@anavp-nvidia
Contributor Author

Thanks! I wasn't aware that pointer indirection wasn't required here; appreciate the insight.

I tested this locally by deleting the following section:

if (node->op == GGML_OP_CPY) {
    // Store the pointers which are updated for each token, such that these can be sent
    // to the device and accessed using indirection from CUDA graph
    cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
    // store a pointer to each copy op CUDA kernel to identify it later
    void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
    if (!ptr) {
        use_cuda_graph = false;
#ifndef NDEBUG
        GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
    }

and modifying these lines to always use cudaMemcpyAsync:

if (src0->type == GGML_TYPE_F32) {
    ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
    CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
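
In other words, the modified version boiled down to something like this (a rough sketch of the local change, using the same variables as the snippet above):

// Local experiment only: drop the type check and always route the contiguous copy
// through cudaMemcpyAsync instead of dispatching to ggml_cpy_flt_cuda for F32.
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));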

With those changes, CUDA Graph execution ran without any issues, and performance (for Nemotron Nano v2) was as follows:

  Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| nemotron_h 9B Q4_K - Medium    |   6.07 GiB |     8.89 B | CUDA       | 9999 |  1 |    tg200 @ d100 |        172.05 ± 0.11 |
| nemotron_h 9B Q4_K - Medium    |   6.07 GiB |     8.89 B | CUDA       | 9999 |  1 | pp100+tg200 @ d100 |        181.24 ± 0.81 |

Testing as per contribution guidelines also didn't raise any new issues.

Based on this, it seems safe to remove the copy op pointer indirection code and revert to using cudaMemcpyAsync for contiguous tensor copies.
If you agree with this assessment, I'll close this PR and open a new one reflecting the above changes.
