Commit bd562fe

cuda : use fast copy when src and dst are of different type and contiguous (ggml-org#16789)

* use fast copy when src and dst are contiguous and same shape
* use int64_t ne and ignore shape

1 parent bbac6a2 commit bd562fe
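The idea behind the fast path, in isolation: when both tensors are contiguous, the copy needs no per-dimension strides at all; a single flat index over the total element count is enough, with a per-element type cast. Below is a minimal self-contained sketch of that pattern, not code from this commit; the names convert_copy and BLOCK_SIZE are illustrative (ggml's kernel uses ggml_cuda_cast and CUDA_CPY_BLOCK_SIZE, as the diff shows).

    #include <cuda_fp16.h>
    #include <cstdint>
    #include <cstdio>

    #define BLOCK_SIZE 256

    // flat 1-D convert-and-copy over contiguous buffers of ne elements
    template <typename src_t, typename dst_t>
    __global__ void convert_copy(const src_t * x, dst_t * dst, const int64_t ne) {
        const int64_t i = (int64_t) blockDim.x*blockIdx.x + threadIdx.x;
        if (i >= ne) {
            return;
        }
        dst[i] = (dst_t) x[i]; // per-element cast, e.g. float -> half
    }

    int main() {
        const int64_t ne = 1 << 20;
        float  * src = nullptr;
        __half * dst = nullptr;
        cudaMalloc(&src, ne*sizeof(float));
        cudaMalloc(&dst, ne*sizeof(__half));

        const int64_t num_blocks = (ne + BLOCK_SIZE - 1) / BLOCK_SIZE;
        convert_copy<<<num_blocks, BLOCK_SIZE>>>(src, dst, ne);
        cudaDeviceSynchronize();
        printf("converted %lld elements\n", (long long) ne);

        cudaFree(src);
        cudaFree(dst);
        return 0;
    }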

File tree: 1 file changed (+69, -11)

ggml/src/ggml-cuda/cpy.cu

Lines changed: 69 additions & 11 deletions
@@ -112,6 +112,30 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
+template<typename src_t, typename dst_t>
+static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const src_t * x = (const src_t *) cx;
+    dst_t * dst = (dst_t *) cdst;
+
+    dst[i] = ggml_cuda_cast<dst_t>(x[i]);
+}
+
+template<typename src_t, typename dst_t>
+static void ggml_cpy_flt_contiguous_cuda(
+    const char * cx, char * cdst, const int64_t ne,
+    cudaStream_t stream) {
+
+    const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne);
+}
+
 template<typename src_t, typename dst_t>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
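Note that the new launcher takes only the raw byte pointers and the flat element count, and it widens the count to int64_t (the general ggml_cpy_flt_cuda above still takes const int ne): once both tensors are contiguous, shape and strides drop out entirely, which is what the commit note "use int64_t ne and ignore shape" refers to. A call site therefore reduces to something like the line below, a hypothetical example matching the launcher signature; ne here is the total element count, which the existing dispatch already has on hand.

    // hypothetical call site matching the launcher signature above
    ggml_cpy_flt_contiguous_cuda<float, half>(src0_ddc, src1_ddc, ne, main_stream);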
@@ -285,7 +309,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
-    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+    const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
+
+    if (src0->type == src1->type && contiguous_srcs) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
 #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
         if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
@@ -296,11 +322,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -327,21 +361,45 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
             ggml_type_name(src0->type), ggml_type_name(src1->type));
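In the dispatch above, the fast path is wired into every mixed-type branch: F32<->F16, F32<->BF16, F16<->BF16, and F32<->I32, in both directions. The same-type branches (F32->F32, F16->F16, BF16->BF16) stay on the general strided kernel, because when those pairs are contiguous they never reach this code; the src0->type == src1->type case is already served by the cudaMemcpyAsync branch earlier in the function.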
