 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
+static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
     const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
@@ -139,53 +139,25 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
 #endif
 }

-static void ggml_cpy_f16_f32_cuda(
+template <typename src_t, typename dst_t>
+static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }

-static void ggml_cpy_bf16_f32_cuda(
+template <typename src_t>
+static void ggml_cpy_to_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_bf16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_bf16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+    cpy_flt<cpy_1_to_f16<src_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }

@@ -323,27 +295,7 @@ static void ggml_cpy_f16_f16_cuda(
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f16_bf16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f16_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_bf16_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_bf16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+    cpy_flt<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }

@@ -402,11 +354,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_to_f16_cuda<float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -435,16 +387,16 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_f16_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f16_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
         // Pure copy, doesn't need its own BF16 function
         ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_bf16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_to_f16_cuda<nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_bf16_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
             ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -469,11 +421,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         return nullptr;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+        return (void*) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
+        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void*) cpy_flt<cpy_1_to_f16<float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -497,17 +449,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f16_f16>;
+        return (void*) cpy_flt<cpy_1_f16_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_f32_f16<cpy_1_f16_bf16>;
+        return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+        return (void*) cpy_flt<cpy_1_flt<half, float>>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_bf16_f16>;
+        return (void*) cpy_flt<cpy_1_to_f16<nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_f32_f16<cpy_1_f16_f16>;
+        return (void*) cpy_flt<cpy_1_f16_f16>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_bf16_f32>;
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
             ggml_type_name(src0->type), ggml_type_name(src1->type));
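For context: the per-element converters cpy_1_flt<src_t, dst_t> and cpy_1_to_f16<src_t> that the new cpy_flt kernel is instantiated with are defined earlier in cpy.cu and are not shown in these hunks. The sketch below is only an illustration of the usual reinterpret-and-convert pattern such converters follow; the actual definitions, and the reason a separate F16 helper exists at all, may differ.

    // Sketch only -- not the upstream definitions.
    #include <cuda_fp16.h>
    #include <cuda_bf16.h>

    // Copy one element, converting from src_t to dst_t. The raw byte pointers come
    // from the strided cpy_flt kernel; casting through float keeps every pairing
    // used above (float/half/nv_bfloat16 in either direction) compilable.
    template <typename src_t, typename dst_t>
    static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
        const src_t * xi = (const src_t *) cxi;
        dst_t * dsti     = (dst_t *) cdsti;

        *dsti = (dst_t) (float) *xi;
    }

    // Same idea with the destination fixed to half; assumed here to exist so that
    // float -> F16 and BF16 -> F16 copies share one instantiation path.
    template <typename src_t>
    static __device__ void cpy_1_to_f16(const char * cxi, char * cdsti) {
        const src_t * xi = (const src_t *) cxi;
        half * dsti      = (half *) cdsti;

        *dsti = __float2half((float) *xi);
    }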