@@ -428,7 +428,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
428428 char * src0_ddc = (char *) src0->data ;
429429 char * src1_ddc = (char *) src1->data ;
430430
431- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
431+ if (src0->type == src1->type && ggml_is_contiguous (src0) && ggml_is_contiguous (src1)) {
432+ GGML_ASSERT (ggml_nbytes (src0) == ggml_nbytes (src1));
433+ CUDA_CHECK (cudaMemcpyAsync (src1_ddc, src0_ddc, ggml_nbytes (src0), cudaMemcpyDeviceToDevice, main_stream));
434+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
432435 ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
433436 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
434437 ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -449,9 +452,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
449452 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
450453 ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
451454 } else {
452- fprintf (stderr, " %s: unsupported type combination (%s to %s)\n " , __func__,
455+ GGML_ABORT ( " %s: unsupported type combination (%s to %s)\n " , __func__,
453456 ggml_type_name (src0->type ), ggml_type_name (src1->type ));
454- GGML_ABORT (" fatal error" );
455457 }
456458}
457459
@@ -461,29 +463,30 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
461463}
462464
463465void * ggml_cuda_cpy_fn (const ggml_tensor * src0, ggml_tensor * src1) {
464- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
465- return (void *) cpy_f32_f16<cpy_1_f32_f32>;
466+ if (src0->type == src1->type && ggml_is_contiguous (src0) && ggml_is_contiguous (src1)) {
467+ return nullptr ;
468+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
469+ return (void *) cpy_f32_f16<cpy_1_f32_f32>;
466470 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
467- return (void *) cpy_f32_f16<cpy_1_f32_f16>;
471+ return (void *) cpy_f32_f16<cpy_1_f32_f16>;
468472 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
469- return (void *) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
473+ return (void *) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
470474 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
471- return (void *) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
475+ return (void *) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
472476 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
473- return (void *) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
477+ return (void *) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
474478 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
475- return (void *) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
479+ return (void *) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
476480 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
477- return (void *) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
481+ return (void *) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
478482 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
479- return (void *) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
483+ return (void *) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
480484 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
481- return (void *) cpy_f32_f16<cpy_1_f32_f16>;
485+ return (void *) cpy_f32_f16<cpy_1_f32_f16>;
482486 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
483- return (void *) cpy_f32_f16<cpy_1_f16_f32>;
487+ return (void *) cpy_f32_f16<cpy_1_f16_f32>;
484488 } else {
485- fprintf (stderr, " %s: unsupported type combination (%s to %s)\n " , __func__,
489+ GGML_ABORT ( " %s: unsupported type combination (%s to %s)\n " , __func__,
486490 ggml_type_name (src0->type ), ggml_type_name (src1->type ));
487- GGML_ABORT (" fatal error" );
488491 }
489492}
0 commit comments