@@ -149,6 +149,16 @@ static void ggml_cpy_f16_f32_cuda(
149149 (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
150150}
151151
// Host-side launcher: copy a (possibly non-contiguous) BF16 tensor into an F32
// destination on `stream`, one thread per element. `ne` is the total element
// count; ne0x/nb0x and ne1x/nb1x are the src/dst dims and byte strides.
// NOTE(review): `cdst_indirect` and the post-incremented `graph_cpynode_index`
// appear to support CUDA-graph node indirection — confirm against cpy_f32_f16.
static void ggml_cpy_bf16_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

    const int n_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; // ceil-div over elements
    cpy_f32_f16<cpy_1_bf16_f32><<<n_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>(
        cx, cdst, ne,
        ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13,
        cdst_indirect, graph_cpynode_index++);
}
161+
152162static void ggml_cpy_f32_f32_cuda (
153163 const char * cx, char * cdst, const int ne,
154164 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -317,6 +327,26 @@ static void ggml_cpy_f16_f16_cuda(
317327 (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
318328}
319329
// Host-side launcher: copy a (possibly non-contiguous) F16 tensor into a BF16
// destination on `stream`, one thread per element. `ne` is the total element
// count; ne0x/nb0x and ne1x/nb1x are the src/dst dims and byte strides.
// NOTE(review): `cdst_indirect` and the post-incremented `graph_cpynode_index`
// appear to support CUDA-graph node indirection — confirm against cpy_f32_f16.
static void ggml_cpy_f16_bf16_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

    const int n_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; // ceil-div over elements
    cpy_f32_f16<cpy_1_f16_bf16><<<n_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>(
        cx, cdst, ne,
        ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13,
        cdst_indirect, graph_cpynode_index++);
}
339+
// Host-side launcher: copy a (possibly non-contiguous) BF16 tensor into an F16
// destination on `stream`, one thread per element. `ne` is the total element
// count; ne0x/nb0x and ne1x/nb1x are the src/dst dims and byte strides.
// NOTE(review): `cdst_indirect` and the post-incremented `graph_cpynode_index`
// appear to support CUDA-graph node indirection — confirm against cpy_f32_f16.
static void ggml_cpy_bf16_f16_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

    const int n_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; // ceil-div over elements
    cpy_f32_f16<cpy_1_bf16_f16><<<n_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>(
        cx, cdst, ne,
        ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13,
        cdst_indirect, graph_cpynode_index++);
}
349+
320350void ggml_cuda_cpy (ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
321351 const int64_t ne = ggml_nelements (src0);
322352 GGML_ASSERT (ne == ggml_nelements (src1));
@@ -404,8 +434,17 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
404434 ggml_cpy_q5_1_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
405435 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
406436 ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
437+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
438+ ggml_cpy_f16_bf16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
407439 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
408440 ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
441+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
442+ // Pure copy, doesn't need its own BF16 function
443+ ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
444+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
445+ ggml_cpy_bf16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
446+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
447+ ggml_cpy_bf16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
409448 } else {
410449 GGML_ABORT (" %s: unsupported type combination (%s to %s)\n " , __func__,
411450 ggml_type_name (src0->type ), ggml_type_name (src1->type ));
@@ -458,9 +497,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
458497 } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
459498 return (void *) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
460499 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
461- return (void *) cpy_f32_f16<cpy_1_f32_f16>;
500+ return (void *) cpy_f32_f16<cpy_1_f16_f16>;
501+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
502+ return (void *) cpy_f32_f16<cpy_1_f16_bf16>;
462503 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
463504 return (void *) cpy_f32_f16<cpy_1_f16_f32>;
505+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
506+ return (void *) cpy_f32_f16<cpy_1_bf16_f16>;
507+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
508+ return (void *) cpy_f32_f16<cpy_1_f16_f16>;
509+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
510+ return (void *) cpy_f32_f16<cpy_1_bf16_f32>;
464511 } else {
465512 GGML_ABORT (" %s: unsupported type combination (%s to %s)\n " , __func__,
466513 ggml_type_name (src0->type ), ggml_type_name (src1->type ));
0 commit comments