@@ -38,6 +38,25 @@ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
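+// Element-wise copy kernel for fully contiguous src/dst tensors of the same shape;
+// when cdst_indirect is set (CUDA graph capture), the destination pointer is looked up per copy node.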
+template <typename src_t, typename dst_t>
+static __global__ void cpy_flt_contiguous(const char * cx, char * cdst_direct, const int ne,
+                                          char ** cdst_indirect, int graph_cpynode_index) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    auto dst = (cdst_indirect != nullptr) ? (dst_t *)cdst_indirect[graph_cpynode_index] : (dst_t *)cdst_direct;
+    auto src = (const src_t *)cx;
+
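+    // float -> bf16 goes through the explicit intrinsic; a plain cast covers the remaining type pairs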
+    if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+        dst[i] = __float2bfloat16(src[i]);
+    } else {
+        dst[i] = (dst_t)src[i];
+    }
+}
+
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
     float * cdstf = (float *)(cdsti);
 
@@ -163,6 +182,16 @@ static void ggml_cpy_flt_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
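+// Host-side launcher for the contiguous fast path: one thread per element; the caller's
+// graph copy-node index is post-incremented, matching the behavior of ggml_cpy_flt_cuda.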
+template <typename src_t, typename dst_t>
+static void ggml_cpy_flt_contiguous_cuda(
+    const char * cx, char * cdst, const int ne,
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, cdst_indirect, graph_cpynode_index++);
+}
+
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -404,6 +433,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
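+    // Take the simpler 1D copy kernel when both tensors are contiguous and have the same shape.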
+    bool fast_cpy = ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_are_same_shape(src0, src1);
+
     char ** dest_ptrs_d = nullptr;
     int graph_cpynode_index = -1;
 #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
@@ -429,11 +460,23 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             }
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if (fast_cpy) {
+            ggml_cpy_flt_contiguous_cuda<float, float>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if (fast_cpy) {
+            ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if (fast_cpy) {
+            ggml_cpy_flt_contiguous_cuda<float, half>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -505,6 +548,7 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 }
 
 void * ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
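+    // Mirror the contiguous fast-path decision made in ggml_cuda_cpy when choosing which kernel pointer to report.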
+    bool fast_cpy = ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_are_same_shape(src0, src1);
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         // Prioritize CUDA graph compatibility over direct memory copy optimization.
         // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
@@ -514,11 +558,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
             return nullptr;
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void *) cpy_flt<cpy_1_flt<float, float>>;
+        return fast_cpy ? (void *) cpy_flt_contiguous<float, float> : (void *) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void *) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
+        return fast_cpy ? (void *) cpy_flt_contiguous<float, nv_bfloat16> : (void *) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void *) cpy_flt<cpy_1_flt<float, half>>;
+        return fast_cpy ? (void *) cpy_flt_contiguous<float, half> : (void *) cpy_flt<cpy_1_flt<float, half>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void *) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {