@@ -10,6 +10,13 @@ static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
     *dsti = *xi;
 }
 
+static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
+
+    *dsti = *xi;
+}
+
 static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     half * dsti = (half *) cdsti;
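The new `cpy_1_f32_bf16` functor is the per-element conversion: the assignment `*dsti = *xi;` invokes `nv_bfloat16`'s converting constructor from `float` (declared in `<cuda_bf16.h>`, which the surrounding headers are assumed to pull in already), and that conversion rounds to nearest-even rather than simply truncating the low 16 mantissa bits. A minimal host-side sketch of the equivalent bit manipulation, using the hypothetical helper name `f32_to_bf16_bits` and omitting NaN handling:

// Sketch only: round-to-nearest-even float -> bf16 bit pattern, mirroring
// what the nv_bfloat16(float) conversion does on the device. NaN handling omitted.
#include <cstdint>
#include <cstring>

static uint16_t f32_to_bf16_bits(float f) {                 // hypothetical helper
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));                    // bit-copy, avoids aliasing UB
    const uint32_t rounding = 0x7FFF + ((bits >> 16) & 1);   // ties round to even
    return (uint16_t)((bits + rounding) >> 16);              // keep sign, exponent, top 7 mantissa bits
}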
@@ -386,6 +393,16 @@ static void ggml_cpy_f32_f32_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+static void ggml_cpy_f32_bf16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
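The host wrapper mirrors the existing `ggml_cpy_f32_f16_cuda` below it: one thread per element, with the grid rounded up by a ceiling division so that a partial final block still covers the tail. Worked numbers, assuming for illustration that `CUDA_CPY_BLOCK_SIZE` is 64 (the actual constant is defined elsewhere in this file):

// Illustration of the launch-size arithmetic above (block size of 64 assumed):
// ne = 1000 elements
// num_blocks = (1000 + 64 - 1) / 64 = 16
// 16 blocks * 64 threads = 1024 threads; the kernel masks off the last 24.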
@@ -581,6 +598,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
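At the ggml API level, this branch is reached by any graph node that copies an F32 tensor into a BF16 tensor. A minimal sketch of such a node, assuming a `ggml_context * ctx` has already been created and with error checking omitted:

// Sketch: graph node that exercises the new F32 -> BF16 branch on the CUDA backend.
struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  4096);
struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_BF16, 4096);
struct ggml_tensor * cpy = ggml_cpy(ctx, src, dst);   // lowered to ggml_cpy_f32_bf16_cuda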
@@ -634,6 +653,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return nullptr;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
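`ggml_cuda_cpy_fn` must be kept in sync with the dispatch above: it reports which kernel a given type combination launches, and the F16 and BF16 paths reuse the same `cpy_f32_f16` kernel template, specialized only by the per-element functor, so each functor yields a distinct kernel pointer. A simplified sketch of that pattern (strides, extra dims, and the CUDA-graph indirection arguments dropped for brevity; not the file's actual kernel):

// Simplified sketch of the template dispatch used by cpy.cu.
template <void (*cpy_1)(const char *, char *), size_t src_size, size_t dst_size>
static __global__ void cpy_generic(const char * cx, char * cdst, const int ne) {
    const int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i >= ne) {
        return;                          // threads past ne do nothing
    }
    cpy_1(cx + i*src_size, cdst + i*dst_size);
}
// cpy_generic<cpy_1_f32_bf16, 4, 2><<<blocks, threads>>>(...) selects the BF16
// functor; it is a different instantiation from the F16 one, hence the new entry.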