@@ -112,6 +112,30 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
112112 cpy_blck (cx + x_offset, cdst + dst_offset);
113113}
114114
// Element-wise copy with type conversion (src_t -> dst_t via ggml_cuda_cast)
// for the fast path where src and dst are both contiguous and have the same
// shape: a flat 1:1 index mapping, one thread per element.
// Expects a 1-D grid of 1-D blocks covering at least ne threads.
template <typename src_t, typename dst_t>
static __global__ void cpy_flt_cont_shape(const char * cx, char * cdst, const int ne) {
    // Widen to 64-bit BEFORE the multiply: blockDim.x and blockIdx.x are
    // 32-bit unsigned, and their product may wrap on very large grids if
    // computed in 32-bit arithmetic first.
    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    // Grid-tail guard: the last block may extend past the element count.
    if (i >= ne) {
        return;
    }

    const src_t * x = (const src_t *) cx;
    dst_t * dst = (dst_t *) cdst;

    dst[i] = ggml_cuda_cast<dst_t>(x[i]);
}
128+
// Host-side launcher for cpy_flt_cont_shape: converts and copies ne elements
// from cx to cdst on the given stream. Both pointers must refer to contiguous
// device buffers holding at least ne elements of src_t / dst_t respectively.
template <typename src_t, typename dst_t>
static void ggml_cpy_flt_cont_shape_cuda(
    const char * cx, char * cdst, const int ne,
    cudaStream_t stream) {

    // Nothing to do for an empty tensor; also avoids launching with a
    // zero-sized grid, which the runtime rejects as an invalid configuration.
    if (ne <= 0) {
        return;
    }

    // Overflow-safe ceiling division: the usual (ne + BLOCK - 1) form can
    // wrap past INT_MAX when ne is close to it, so divide first, then
    // round up if there is a remainder.
    const int num_blocks = ne / CUDA_CPY_BLOCK_SIZE + (ne % CUDA_CPY_BLOCK_SIZE != 0);
    cpy_flt_cont_shape<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne);
}
138+
115139template <typename src_t , typename dst_t >
116140static void ggml_cpy_flt_cuda (
117141 const char * cx, char * cdst, const int ne,
@@ -285,7 +309,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
285309 char * src0_ddc = (char *) src0->data ;
286310 char * src1_ddc = (char *) src1->data ;
287311
288- if (src0->type == src1->type && ggml_is_contiguous (src0) && ggml_is_contiguous (src1)) {
312+ const bool contiguous_srcs = ggml_is_contiguous (src0) && ggml_is_contiguous (src1);
313+ const bool cont_shape_srcs = contiguous_srcs && ggml_are_same_shape (src0, src1);
314+
315+ if (src0->type == src1->type && contiguous_srcs) {
289316 GGML_ASSERT (ggml_nbytes (src0) == ggml_nbytes (src1));
290317#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
291318 if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
@@ -296,11 +323,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
296323 CUDA_CHECK (cudaMemcpyAsync (src1_ddc, src0_ddc, ggml_nbytes (src0), cudaMemcpyDeviceToDevice, main_stream));
297324 }
298325 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
299- ggml_cpy_flt_cuda<float , float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
326+ ggml_cpy_flt_cuda<float , float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
300327 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
301- ggml_cpy_flt_cuda<float , nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
328+ if (cont_shape_srcs) {
329+ ggml_cpy_flt_cont_shape_cuda<float , nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
330+ } else {
331+ ggml_cpy_flt_cuda<float , nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
332+ }
302333 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
303- ggml_cpy_flt_cuda<float , half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
334+ if (cont_shape_srcs) {
335+ ggml_cpy_flt_cont_shape_cuda<float , half> (src0_ddc, src1_ddc, ne, main_stream);
336+ } else {
337+ ggml_cpy_flt_cuda<float , half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
338+ }
304339 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
305340 ggml_cpy_f32_q8_0_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
306341 } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -327,21 +362,45 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
327362 } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
328363 ggml_cpy_q5_1_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
329364 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
330- ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
365+ ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
331366 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
332- ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
367+ if (cont_shape_srcs) {
368+ ggml_cpy_flt_cont_shape_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
369+ } else {
370+ ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
371+ }
333372 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
334- ggml_cpy_flt_cuda<half, float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
373+ if (cont_shape_srcs) {
374+ ggml_cpy_flt_cont_shape_cuda<half, float > (src0_ddc, src1_ddc, ne, main_stream);
375+ } else {
376+ ggml_cpy_flt_cuda<half, float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
377+ }
335378 } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
336379 ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
337380 } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
338- ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
381+ if (cont_shape_srcs) {
382+ ggml_cpy_flt_cont_shape_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
383+ } else {
384+ ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
385+ }
339386 } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
340- ggml_cpy_flt_cuda<nv_bfloat16, float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
387+ if (cont_shape_srcs) {
388+ ggml_cpy_flt_cont_shape_cuda<nv_bfloat16, float > (src0_ddc, src1_ddc, ne, main_stream);
389+ } else {
390+ ggml_cpy_flt_cuda<nv_bfloat16, float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
391+ }
341392 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
342- ggml_cpy_flt_cuda<float , int32_t > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
393+ if (cont_shape_srcs) {
394+ ggml_cpy_flt_cont_shape_cuda<float , int32_t > (src0_ddc, src1_ddc, ne, main_stream);
395+ } else {
396+ ggml_cpy_flt_cuda<float , int32_t > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
397+ }
343398 } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
344- ggml_cpy_flt_cuda<int32_t , float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
399+ if (cont_shape_srcs) {
400+ ggml_cpy_flt_cont_shape_cuda<int32_t , float > (src0_ddc, src1_ddc, ne, main_stream);
401+ } else {
402+ ggml_cpy_flt_cuda<int32_t , float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
403+ }
345404 } else {
346405 GGML_ABORT (" %s: unsupported type combination (%s to %s)\n " , __func__,
347406 ggml_type_name (src0->type ), ggml_type_name (src1->type ));
0 commit comments