
Commit 4162ffe

implement bf16 cpy ops and enable bf16 cont
1 parent bf9087f commit 4162ffe

3 files changed: 85 additions, 2 deletions

ggml/src/ggml-cuda/cpy-utils.cuh

Lines changed: 24 additions & 0 deletions
@@ -18,10 +18,22 @@ static __device__ __forceinline__ void convert_f16_f16(const half * src, half *
     *dst = *src;
 }
 
+static __device__ __forceinline__ void convert_f16_bf16(const half * src, nv_bfloat16 * dst) {
+    *dst = float(*src);
+}
+
 static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) {
     *dst = *src;
 }
 
+static __device__ __forceinline__ void convert_bf16_f16(const nv_bfloat16 * src, half * dst) {
+    *dst = __float2half(*src);
+}
+
+static __device__ __forceinline__ void convert_bf16_f32(const nv_bfloat16 * src, float * dst) {
+    *dst = *src;
+}
+
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
     if (x >= val[n-1]) return n-1;
@@ -246,6 +258,18 @@ static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
     convert_f16_f16((const half *)cxi, (half *)cdsti);
 }
 
+static __device__ void cpy_1_f16_bf16(const char * cxi, char * cdsti) {
+    convert_f16_bf16((const half *)cxi, (nv_bfloat16 *)cdsti);
+}
+
 static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
     convert_f16_f32((const half *)cxi, (float *)cdsti);
 }
+
+static __device__ void cpy_1_bf16_f16(const char * cxi, char * cdsti) {
+    convert_bf16_f16((const nv_bfloat16 *)cxi, (half *)cdsti);
+}
+
+static __device__ void cpy_1_bf16_f32(const char * cxi, char * cdsti) {
+    convert_bf16_f32((const nv_bfloat16 *)cxi, (float *)cdsti);
+}
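
The new helpers all convert through float: widening F16 or BF16 to F32 is exact, while F16→BF16 drops mantissa bits and BF16→F16 can lose exponent range, so the cross conversions round. The following standalone sketch (not part of the commit; the kernel name, launch shape, and test values are made up for illustration) shows that widen-then-narrow idiom:

// Standalone sketch of the conversion idiom used by the new helpers:
// widen the source to float, then narrow to the destination type.
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdio>

__global__ void f16_to_bf16(const half * src, nv_bfloat16 * dst, int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        // same path as convert_f16_bf16: half -> float -> bfloat16 (rounds the mantissa)
        dst[i] = __float2bfloat16(__half2float(src[i]));
    }
}

int main() {
    const int n = 4;
    half src_h[n];
    for (int i = 0; i < n; ++i) {
        src_h[i] = __float2half(0.5f*(i + 1));
    }

    half        * src_d = nullptr;
    nv_bfloat16 * dst_d = nullptr;
    cudaMalloc(&src_d, n*sizeof(half));
    cudaMalloc(&dst_d, n*sizeof(nv_bfloat16));
    cudaMemcpy(src_d, src_h, n*sizeof(half), cudaMemcpyHostToDevice);

    f16_to_bf16<<<1, 32>>>(src_d, dst_d, n);

    nv_bfloat16 dst_h[n];
    cudaMemcpy(dst_h, dst_d, n*sizeof(nv_bfloat16), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) {
        printf("%d: %.6f\n", i, __bfloat162float(dst_h[i]));
    }

    cudaFree(src_d);
    cudaFree(dst_d);
    return 0;
}

Compiled with nvcc (e.g. nvcc -o f16_to_bf16 f16_to_bf16.cu), it prints 0.500000, 1.000000, 1.500000, 2.000000, since those values are exactly representable in both formats.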

ggml/src/ggml-cuda/cpy.cu

Lines changed: 48 additions & 1 deletion
@@ -149,6 +149,16 @@ static void ggml_cpy_f16_f32_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+static void ggml_cpy_bf16_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_bf16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
 static void ggml_cpy_f32_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -317,6 +327,26 @@ static void ggml_cpy_f16_f16_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+static void ggml_cpy_f16_bf16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
+static void ggml_cpy_bf16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_bf16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -404,8 +434,17 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_f16_bf16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+        // Pure copy, doesn't need its own BF16 function
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_bf16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_bf16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                    ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -458,9 +497,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void*) cpy_f32_f16<cpy_1_f16_f16>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_f32_f16<cpy_1_f16_bf16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_f32_f16<cpy_1_bf16_f16>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_f32_f16<cpy_1_f16_f16>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_f32_f16<cpy_1_bf16_f32>;
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                    ggml_type_name(src0->type), ggml_type_name(src1->type));
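
Each host wrapper above follows the same pattern: compute a ceil-divided block count and instantiate the generic per-element copy kernel cpy_f32_f16<...> with the matching conversion functor; the BF16→BF16 case reuses the F16→F16 path because a same-width copy preserves the 16 bits regardless of how they are interpreted. A condensed sketch of that dispatch pattern, with hypothetical names (copy_elems, copy_1_bf16_f32, copy_bf16_f32, COPY_BLOCK_SIZE) standing in for the real cpy_f32_f16, cpy_1_bf16_f32, ggml_cpy_bf16_f32_cuda and CUDA_CPY_BLOCK_SIZE, and handling only the contiguous case, might look like this:

// Condensed sketch of the dispatch pattern used by the wrappers above (hypothetical names).
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#define COPY_BLOCK_SIZE 64

typedef void (*copy_elem_t)(const char * src, char * dst);

// one thread per element; the conversion functor is baked in at compile time
template <copy_elem_t copy_elem>
static __global__ void copy_elems(const char * src, char * dst, const int ne,
                                  const int src_size, const int dst_size) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < ne) {
        copy_elem(src + i*src_size, dst + i*dst_size);
    }
}

static __device__ void copy_1_bf16_f32(const char * src, char * dst) {
    *(float *) dst = __bfloat162float(*(const nv_bfloat16 *) src);
}

// host-side wrapper: ceil-divide the element count into blocks and launch
static void copy_bf16_f32(const char * src, char * dst, const int ne, cudaStream_t stream) {
    const int num_blocks = (ne + COPY_BLOCK_SIZE - 1) / COPY_BLOCK_SIZE;
    copy_elems<copy_1_bf16_f32><<<num_blocks, COPY_BLOCK_SIZE, 0, stream>>>
        (src, dst, ne, sizeof(nv_bfloat16), sizeof(float));
}

The real kernel additionally takes the ne0x/nb0x/ne1x/nb1x shape and stride arguments so it can copy between arbitrarily strided tensors, and the graph_cpynode_index / cdst_indirect arguments for CUDA-graph indirection.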

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 13 additions & 1 deletion
@@ -3287,9 +3287,21 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_BF16) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
                     return true;
                 }
@@ -3370,7 +3382,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return op->src[0]->ne[1] % 128 == 0;
             }
         case GGML_OP_CONT:
-            return op->src[0]->type != GGML_TYPE_BF16;
+            return true;
         case GGML_OP_DIAG_MASK_INF:
             return true;
        case GGML_OP_SOFT_MAX:
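
With GGML_OP_CONT now reporting support unconditionally, a non-contiguous BF16 tensor (for example, a transposed view) can be made contiguous on the CUDA backend instead of falling back to the CPU. A hedged host-side sketch using the public ggml/ggml-backend API follows; it is not from the commit, and the context size and setup details are illustrative and may need adjustment for a given ggml revision:

// Host-side sketch: build a tiny graph that runs GGML_OP_CONT on a BF16 view.
#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,          // tensor data will live in the backend buffer
    };
    struct ggml_context * ctx = ggml_init(params);

    // a 4x8 BF16 matrix; transposing it yields a non-contiguous view
    struct ggml_tensor * a  = ggml_new_tensor_2d(ctx, GGML_TYPE_BF16, 4, 8);
    struct ggml_tensor * at = ggml_cont(ctx, ggml_transpose(ctx, a));  // GGML_OP_CONT on BF16

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, at);

    ggml_backend_t        backend = ggml_backend_cuda_init(0 /* device */);
    ggml_backend_buffer_t buf     = ggml_backend_alloc_ctx_tensors(ctx, backend);

    // ... fill `a` with ggml_backend_tensor_set(), then:
    ggml_backend_graph_compute(backend, gf);

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}

The same graph could instead end in ggml_cpy into an F32 or F16 destination, which now maps onto the new BF16 copy paths added in cpy.cu.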
