@@ -9,9 +9,9 @@ typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
-                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                               const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                               const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
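Note: the `cpy_1` template parameter of `cpy_flt` is a `cpy_kernel_t` functor instantiated once per type pair; the generic `cpy_1_flt<src_t, dst_t>` used in the hunks below replaces the removed `cpy_1_to_f16<src_t>` and `cpy_1_f16_f16` specializations. A minimal sketch of the shape of such a functor (not the verbatim file contents; the real code may route the conversion through a dedicated cast helper):

```cuda
// Sketch only: a generic per-element copy/convert functor matching the
// cpy_kernel_t signature that cpy_flt takes as its template argument.
template <typename src_t, typename dst_t>
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
    const src_t * xi   = (const src_t *) cxi;
    dst_t       * dsti = (dst_t       *) cdsti;

    // Convert through float so all pairs of {float, half, nv_bfloat16} work.
    *dsti = (dst_t) (float) *xi;
}
```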
@@ -150,17 +150,6 @@ static void ggml_cpy_flt_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
-template <typename src_t>
-static void ggml_cpy_to_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_flt<cpy_1_to_f16<src_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
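For reference, the generic launcher these deletions fold into is the `ggml_cpy_flt_cuda` whose tail is visible at the top of this hunk. A sketch reconstructed from the removed specializations (the actual declaration sits just above old line 150 and may differ in layout):

```cuda
// Reconstruction: the removed ggml_cpy_to_f16_cuda and ggml_cpy_f16_f16_cuda
// were specializations of this templated launcher with src_t/dst_t fixed.
template <typename src_t, typename dst_t>
static void ggml_cpy_flt_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
    const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}
```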
@@ -289,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
-static void ggml_cpy_f16_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_flt<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -358,7 +337,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_to_f16_cuda<float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -385,16 +364,15 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        // Pure copy, doesn't need its own BF16 function
-        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_to_f16_cuda<nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
@@ -425,7 +403,7 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_to_f16<float>>;
+        return (void*) cpy_flt<cpy_1_flt<float, half>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -449,15 +427,15 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_f16_f16>;
+        return (void*) cpy_flt<cpy_1_flt<half, half>>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
         return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_flt<cpy_1_flt<half, float>>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_to_f16<nv_bfloat16>>;
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_f16_f16>;
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
     } else {