 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

     if (i >= ne) {
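The remainder of the kernel body falls outside this hunk. A sketch of what it presumably contains, following the pattern of the other copy kernels in this file (an assumption, not part of the diff): select the destination pointer (direct, or taken from the CUDA graph indirection table), decompose the flat thread index into 4-D coordinates, turn those coordinates into byte offsets via the nb* strides, and hand one element to the cpy_1 functor.

    // assumed continuation of cpy_flt (sketch, not shown in this hunk)
    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index] : cdst_direct;

    // source coordinates and byte offset
    const int64_t i03 = i / (ne00*ne01*ne02);
    const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
    const int64_t i00 =  i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
    const int64_t x_offset   = i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;

    // destination coordinates and byte offset
    const int64_t i13 = i / (ne10*ne11*ne12);
    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
    const int64_t i10 =  i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

    cpy_1(cx + x_offset, cdst + dst_offset);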
@@ -139,43 +139,14 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
 #endif
 }

-static void ggml_cpy_f16_f32_cuda(
+template <typename src_t, typename dst_t>
+static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_bf16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }

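All four removed per-type launchers collapse into this single template. A minimal sketch of the per-element functor it is instantiated with (cpy_1_flt is defined elsewhere in the file and is not shown in this diff; routing the conversion through float is an assumption, chosen so that any float/half/nv_bfloat16 pairing compiles):

template <typename src_t, typename dst_t>
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
    const src_t * xi   = (const src_t *) cxi;
    dst_t       * dsti = (dst_t *) cdsti;

    *dsti = (dst_t) (float) *xi;   // one element: src_t -> float -> dst_t
}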
@@ -307,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }

-static void ggml_cpy_f16_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -372,11 +333,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -403,9 +364,17 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_flt_cuda<half, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f16_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_flt_cuda<nv_bfloat16, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_flt_cuda<nv_bfloat16, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
             ggml_type_name(src0->type), ggml_type_name(src1->type));
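For reference, a hypothetical call to the templated launcher for a contiguous 1-D copy (src_f32_d, dst_f16_d, stream and the sizes below are assumptions, not taken from the diff). ggml passes element counts followed by byte strides, with nb00/nb10 being the element sizes; graph_cpynode_index is passed by reference and advanced inside the launcher.

    // copy ne contiguous floats to half precision on the given stream
    const int ne = 4096;
    int graph_cpynode_index = 0;
    ggml_cpy_flt_cuda<float, half>(
        (const char *) src_f32_d, (char *) dst_f16_d, ne,
        ne, 1, 1,                                                             // ne00..ne02
        sizeof(float), ne*sizeof(float), ne*sizeof(float), ne*sizeof(float),  // nb00..nb03
        ne, 1, 1,                                                             // ne10..ne12
        sizeof(half), ne*sizeof(half), ne*sizeof(half), ne*sizeof(half),      // nb10..nb13
        stream, /*cdst_indirect=*/nullptr, graph_cpynode_index);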
@@ -430,11 +399,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         return nullptr;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+        return (void*) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
+        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void*) cpy_flt<cpy_1_flt<float, half>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -458,9 +427,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void*) cpy_flt<cpy_1_flt<half, half>>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+        return (void*) cpy_flt<cpy_1_flt<half, float>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
             ggml_type_name(src0->type), ggml_type_name(src1->type));