@@ -112,6 +112,30 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
+template <typename src_t, typename dst_t>
+static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const src_t * x = (const src_t *) cx;
+    dst_t *     dst = (dst_t *) cdst;
+
+    dst[i] = ggml_cuda_cast<dst_t>(x[i]);
+}
+
+template <typename src_t, typename dst_t>
+static void ggml_cpy_flt_contiguous_cuda(
+    const char * cx, char * cdst, const int64_t ne,
+    cudaStream_t stream) {
+
+    const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne);
+}
+
 template <typename src_t, typename dst_t>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
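Aside: the new kernel is intentionally trivial. Because both tensors are contiguous, one thread per element with a single flat index suffices; the general ggml_cpy_flt_cuda path has to unravel the flat index into per-dimension coordinates and re-linearize through the nb strides on both sides. A minimal standalone sketch of the same flat copy-cast pattern, runnable outside ggml, might look as follows; copy_cast, BLOCK_SIZE, and the plain C-style cast standing in for ggml_cuda_cast are illustrative assumptions, not code from this commit.

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#define BLOCK_SIZE 256  // stand-in for CUDA_CPY_BLOCK_SIZE

// Same pattern as cpy_flt_contiguous: one thread per element, one flat index,
// no per-dimension stride arithmetic -- valid only because both buffers are
// contiguous.
template <typename src_t, typename dst_t>
static __global__ void copy_cast(const src_t * x, dst_t * dst, const int64_t ne) {
    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= ne) {
        return;
    }
    dst[i] = (dst_t) x[i];  // plain cast standing in for ggml_cuda_cast<dst_t>
}

int main() {
    const int64_t ne = 1 << 20;
    float * x;
    half  * dst;
    cudaMalloc(&x,   ne*sizeof(float));
    cudaMalloc(&dst, ne*sizeof(half));
    cudaMemset(x, 0, ne*sizeof(float));

    // launched the same way ggml_cpy_flt_contiguous_cuda launches its kernel
    const int64_t num_blocks = (ne + BLOCK_SIZE - 1) / BLOCK_SIZE;
    copy_cast<float, half><<<num_blocks, BLOCK_SIZE>>>(x, dst, ne);
    cudaDeviceSynchronize();
    printf("status: %s\n", cudaGetErrorString(cudaGetLastError()));

    cudaFree(x);
    cudaFree(dst);
    return 0;
}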
@@ -285,7 +309,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
-    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+    const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
+
+    if (src0->type == src1->type && contiguous_srcs) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
 #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
         if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
@@ -296,11 +322,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_flt_cuda<float, float>            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<float, half>        (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, half>        (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -327,21 +361,45 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_flt_cuda<half, half>                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16>  (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<half, nv_bfloat16>    (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<half, float>        (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<half, float>          (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half>  (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<nv_bfloat16, half>    (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<nv_bfloat16, float>   (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<float, int32_t>     (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, int32_t>       (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<int32_t, float>     (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<int32_t, float>       (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
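The payoff of the dedicated contiguous path is skipping the per-element index unravelling on data where it is a no-op. A rough standalone harness along the following lines can make that difference visible; the kernel names, the 3D shape, and the cudaEvent timing loop are all illustrative assumptions, not code from this commit.

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#define BLOCK_SIZE 256

// General-path stand-in: unravel the flat index into 3D coordinates and
// re-linearize, mimicking the stride arithmetic the strided kernel has to do
// for possibly non-contiguous tensors (here both layouts are contiguous, so
// the result is identical to the flat copy).
static __global__ void cpy_strided(const float * x, half * dst, const int64_t ne,
                                   const int64_t ne0, const int64_t ne1) {
    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= ne) {
        return;
    }
    const int64_t i2 =  i / (ne0*ne1);
    const int64_t i1 = (i - i2*ne0*ne1) / ne0;
    const int64_t i0 =  i - i2*ne0*ne1 - i1*ne0;
    dst[i2*ne0*ne1 + i1*ne0 + i0] = (half) x[i];
}

// Fast-path stand-in: the flat index used by cpy_flt_contiguous.
static __global__ void cpy_flat(const float * x, half * dst, const int64_t ne) {
    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= ne) {
        return;
    }
    dst[i] = (half) x[i];
}

// average time per launch over 100 iterations, measured with CUDA events
template <typename F>
static float time_ms(F launch) {
    cudaEvent_t beg, end;
    cudaEventCreate(&beg);
    cudaEventCreate(&end);
    launch();  // warm-up
    cudaEventRecord(beg);
    for (int it = 0; it < 100; ++it) {
        launch();
    }
    cudaEventRecord(end);
    cudaEventSynchronize(end);
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, beg, end);
    cudaEventDestroy(beg);
    cudaEventDestroy(end);
    return ms/100.0f;
}

int main() {
    const int64_t ne0 = 4096, ne1 = 64, ne2 = 32;
    const int64_t ne = ne0*ne1*ne2;
    float * x;
    half  * dst;
    cudaMalloc(&x,   ne*sizeof(float));
    cudaMalloc(&dst, ne*sizeof(half));
    cudaMemset(x, 0, ne*sizeof(float));
    const int64_t num_blocks = (ne + BLOCK_SIZE - 1) / BLOCK_SIZE;

    printf("strided: %.3f ms\n", time_ms([&]{ cpy_strided<<<num_blocks, BLOCK_SIZE>>>(x, dst, ne, ne0, ne1); }));
    printf("flat:    %.3f ms\n", time_ms([&]{ cpy_flat<<<num_blocks, BLOCK_SIZE>>>(x, dst, ne); }));

    cudaFree(x);
    cudaFree(dst);
    return 0;
}

Compiled with nvcc, this prints an average per-launch time for each variant. A pure copy is usually bandwidth-bound, so the flat path's gain comes from dropping the integer divisions and extra index math; it tends to be most visible on small and medium tensors, where per-thread overhead is a larger fraction of the total.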