
Commit 9cbb916

further deduplication
1 parent 1860cf9 commit 9cbb916

3 files changed: 17 additions, 69 deletions


ggml/src/ggml-cuda/cpy-utils.cuh

Lines changed: 5 additions & 19 deletions
@@ -4,16 +4,11 @@

 template<typename src_t, typename dst_t>
 static __device__ __forceinline__ void convert_to_flt(const src_t * src, dst_t * dst) {
-    *dst = float(*src);
-}
-
-template<typename src_t>
-static __device__ __forceinline__ void convert_to_f16(const src_t * src, half * dst) {
-    *dst = __float2half(*src);
-}
-
-static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) {
-    *dst = *src;
+    if constexpr (std::is_same_v<src_t, dst_t>) {
+        *dst = *src;
+    } else {
+        *dst = float(*src);
+    }
 }

 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
@@ -228,12 +223,3 @@ template<typename src_t, typename dst_t>
 static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
     convert_to_flt((const src_t *)cxi, (dst_t *)cdsti);
 }
-
-template<typename src_t>
-static __device__ void cpy_1_to_f16(const char * cxi, char * cdsti) {
-    convert_to_f16((const src_t *)cxi, (half *)cdsti);
-}
-
-static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
-    convert_f16_f16((const half *)cxi, (half *)cdsti);
-}
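The heart of the change is the `if constexpr (std::is_same_v<src_t, dst_t>)` branch in `convert_to_flt`: one template now covers both the identity copy that `convert_f16_f16` used to perform and the float round-trip that `convert_to_f16` handled, so the two extra helpers and their `cpy_1_*` wrappers can be deleted. Below is a minimal host-side sketch of that pattern, assuming plain C++17 and using `double` as a stand-in for the CUDA 16-bit types; it is illustrative only, not the real `__device__` code.

```cpp
#include <cassert>
#include <type_traits>

// Host-side stand-in for the deduplicated helper; in ggml this is a __device__
// function and dst_t can be half or nv_bfloat16.
template <typename src_t, typename dst_t>
void convert_to_flt(const src_t * src, dst_t * dst) {
    if constexpr (std::is_same_v<src_t, dst_t>) {
        *dst = *src;            // same type: the old convert_f16_f16 path, a plain copy
    } else {
        *dst = float(*src);     // mixed types: the old convert_to_f16 path, via float
    }
}

int main() {
    const double h = 1.5;       // double stands in for half
    double h_copy  = 0.0;
    float  f_conv  = 0.0f;

    convert_to_flt(&h, &h_copy);   // takes the identity branch
    convert_to_flt(&h, &f_conv);   // takes the converting branch

    assert(h_copy == 1.5);
    assert(f_conv == 1.5f);
    return 0;
}
```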

ggml/src/ggml-cuda/cpy.cu

Lines changed: 11 additions & 33 deletions
@@ -9,9 +9,9 @@ typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                               const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

     if (i >= ne) {
@@ -150,17 +150,6 @@ static void ggml_cpy_flt_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }

-template<typename src_t>
-static void ggml_cpy_to_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_flt<cpy_1_to_f16<src_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -289,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }

-static void ggml_cpy_f16_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_flt<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -358,7 +337,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_to_f16_cuda<float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -385,16 +364,15 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        // Pure copy, doesn't need its own BF16 function
-        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_to_f16_cuda<nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
@@ -425,7 +403,7 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_to_f16<float>>;
+        return (void*) cpy_flt<cpy_1_flt<float, half>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -449,15 +427,15 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_f16_f16>;
+        return (void*) cpy_flt<cpy_1_flt<half, half>>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
         return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_flt<cpy_1_flt<half, float>>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_to_f16<nv_bfloat16>>;
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_f16_f16>;
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
     } else {
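On the dispatch side, `ggml_cuda_cpy` and `ggml_cuda_cpy_fn` now route every float-like (src, dst) pair through the same `ggml_cpy_flt_cuda<src_t, dst_t>` launcher and `cpy_flt<cpy_1_flt<src_t, dst_t>>` kernel instead of the dedicated F16/BF16 wrappers. The following is a rough stand-alone C++17 sketch of that function-pointer dispatch; the mini `type` enum, `cpy_1`, and `pick` are hypothetical stand-ins for `ggml_type`, `cpy_1_flt`, and `ggml_cuda_cpy_fn`, and `double` stands in for `half` so the snippet compiles without CUDA.

```cpp
#include <cstdio>
#include <type_traits>

// Hypothetical mini type tag; the real code switches on ggml_type.
enum class type { F32, F16 };

// Same shape as cpy_1_flt: reinterpret the bytes, then copy or convert.
template <typename src_t, typename dst_t>
void cpy_1(const void * cxi, void * cdsti) {
    const src_t * src = static_cast<const src_t *>(cxi);
    dst_t * dst = static_cast<dst_t *>(cdsti);
    if constexpr (std::is_same_v<src_t, dst_t>) {
        *dst = *src;
    } else {
        *dst = float(*src);
    }
}

using cpy_fn = void (*)(const void *, void *);

// Mirrors the else-if chain after the change: every pair returns an instantiation
// of the one generic template, so dedicated x->F16 / F16->F16 entries are gone.
cpy_fn pick(type src, type dst) {
    if (src == type::F32 && dst == type::F32) { return cpy_1<float, float>; }
    if (src == type::F32 && dst == type::F16) { return cpy_1<float, double>; }   // double stands in for half
    if (src == type::F16 && dst == type::F16) { return cpy_1<double, double>; }
    return nullptr;
}

int main() {
    const float f = 2.5f;
    double h = 0.0;                       // "half" destination
    pick(type::F32, type::F16)(&f, &h);   // dispatch, then convert one element
    printf("%f\n", h);
    return 0;
}
```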

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 1 addition & 17 deletions
@@ -4,23 +4,7 @@
 typedef void (*set_rows_kernel_t)(const char * src, char * dst);

 template<typename src_t, typename dst_t>
-__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
-    GGML_UNUSED(src_f);
-    GGML_UNUSED(dst_f);
-}
-
-template<>
-__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
-    convert_to_f16(src_f, dst_h);
-}
-
-template<>
-__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
-    convert_to_flt(src_f, dst_b);
-}
-
-template<>
-__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
+__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
     convert_to_flt(src_f, dst_f);
 }

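The same idea carries over here: because `convert_to_flt` now handles the same-type case itself, the do-nothing primary template and the three explicit `set_rows_1` specializations collapse into a single forwarding template. A small stand-alone C++17 sketch of that collapse, again with `double` standing in for `half`/`nv_bfloat16` and host functions in place of the real `__device__` ones:

```cpp
#include <cassert>
#include <type_traits>

// Host-side stand-in for the shared conversion helper from cpy-utils.cuh.
template <typename src_t, typename dst_t>
void convert_to_flt(const src_t * src, dst_t * dst) {
    if constexpr (std::is_same_v<src_t, dst_t>) {
        *dst = *src;
    } else {
        *dst = float(*src);
    }
}

// After the change: one primary template forwards to convert_to_flt for every
// (src, dst) pair, replacing the empty primary template plus three specializations.
template <typename src_t, typename dst_t>
void set_rows_1(const src_t * src_f, dst_t * dst_f) {
    convert_to_flt(src_f, dst_f);
}

int main() {
    const float src = 3.0f;
    float  dst_f32 = 0.0f;
    double dst_f16 = 0.0;          // double stands in for the half destination

    set_rows_1(&src, &dst_f32);    // previously the <float, float> specialization
    set_rows_1(&src, &dst_f16);    // previously the <float, half> specialization

    assert(dst_f32 == 3.0f);
    assert(dst_f16 == 3.0);
    return 0;
}
```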