Commit f3fb5de

authored
use fast copy when src and dst are contiguous and same shape
1 parent bbac6a2 commit f3fb5de

File tree

1 file changed: +70 -11 lines changed


ggml/src/ggml-cuda/cpy.cu

Lines changed: 70 additions & 11 deletions
@@ -112,6 +112,30 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
+template<typename src_t, typename dst_t>
+static __global__ void cpy_flt_cont_shape(const char * cx, char * cdst, const int ne) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const src_t * x = (const src_t *) cx;
+    dst_t * dst = (dst_t *) cdst;
+
+    dst[i] = ggml_cuda_cast<dst_t>(x[i]);
+}
+
+template<typename src_t, typename dst_t>
+static void ggml_cpy_flt_cont_shape_cuda(
+    const char * cx, char * cdst, const int ne,
+    cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_flt_cont_shape<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne);
+}
+
 template<typename src_t, typename dst_t>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
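Why the fast path helps: the existing strided path (`ggml_cpy_flt_cuda`, continued above) receives the full set of extents and byte strides (`ne00`…`nb13`), so every thread must decompose its flat index into 4-D coordinates and rebuild byte offsets for both tensors. When source and destination are contiguous and have the same shape, element `i` of the source simply lands at element `i` of the destination, and the new kernel collapses all of that arithmetic into one indexed load, cast, and store. Below is a minimal self-contained sketch of the same pattern; it is an illustration, not the ggml source: the block size of 256 and the `__float2half` cast are stand-ins for `CUDA_CPY_BLOCK_SIZE` (whose value is not shown in this diff) and `ggml_cuda_cast<dst_t>`.

// flat_copy_demo.cu — standalone illustration of the flat-copy pattern
// used by the new kernel; all names here are hypothetical.
#include <cuda_fp16.h>
#include <cstdio>

static __global__ void copy_f32_to_f16_flat(const float * x, __half * dst, const int64_t ne) {
    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= ne) {
        return;
    }
    // one coalesced load + cast + store per thread, no stride arithmetic
    dst[i] = __float2half(x[i]);
}

int main() {
    const int64_t ne = 1 << 20; // 1M elements
    float  * src = nullptr;
    __half * dst = nullptr;
    cudaMalloc(&src, ne*sizeof(float));
    cudaMalloc(&dst, ne*sizeof(__half));

    const int block_size = 256; // assumed; ggml uses CUDA_CPY_BLOCK_SIZE
    const int num_blocks = (int) ((ne + block_size - 1) / block_size); // 4096 here
    copy_f32_to_f16_flat<<<num_blocks, block_size>>>(src, dst, ne);
    cudaDeviceSynchronize();
    printf("launched %d blocks of %d threads\n", num_blocks, block_size);

    cudaFree(src);
    cudaFree(dst);
    return 0;
}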
@@ -285,7 +309,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
-    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+    const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
+    const bool cont_shape_srcs = contiguous_srcs && ggml_are_same_shape(src0, src1);
+
+    if (src0->type == src1->type && contiguous_srcs) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
 #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
         if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
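The two flags are computed once before the dispatch chain: `contiguous_srcs` replaces the inline contiguity checks of the old same-type branch, and `cont_shape_srcs` additionally requires matching shapes via `ggml_are_same_shape` before the flat kernel is chosen in the converting branches below. As a small C++ sketch against the public ggml API (sizes are arbitrary examples), here is which tensors would pass each test:

// predicate_demo.cpp — which tensors take the fast path (illustrative only)
#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { /*.mem_size   =*/ 16*1024*1024,
                                       /*.mem_buffer =*/ NULL,
                                       /*.no_alloc   =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // freshly allocated tensors are contiguous and the shapes match:
    // contiguous_srcs == true, cont_shape_srcs == true -> flat kernel
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 4, 3);

    // a 2x3 view into a keeps a's row stride, so it is not contiguous:
    // contiguous_srcs == false -> strided ggml_cpy_flt_cuda fallback
    struct ggml_tensor * v = ggml_view_2d(ctx, a, 2, 3, a->nb[1], 0);

    printf("a contiguous:   %d\n", ggml_is_contiguous(a));     // 1
    printf("v contiguous:   %d\n", ggml_is_contiguous(v));     // 0
    printf("a,b same shape: %d\n", ggml_are_same_shape(a, b)); // 1

    ggml_free(ctx);
    return 0;
}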
@@ -296,11 +323,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (cont_shape_srcs) {
+            ggml_cpy_flt_cont_shape_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (cont_shape_srcs) {
+            ggml_cpy_flt_cont_shape_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -327,21 +362,45 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (cont_shape_srcs) {
+            ggml_cpy_flt_cont_shape_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (cont_shape_srcs) {
+            ggml_cpy_flt_cont_shape_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (cont_shape_srcs) {
+            ggml_cpy_flt_cont_shape_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (cont_shape_srcs) {
+            ggml_cpy_flt_cont_shape_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (cont_shape_srcs) {
+            ggml_cpy_flt_cont_shape_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (cont_shape_srcs) {
+            ggml_cpy_flt_cont_shape_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                    ggml_type_name(src0->type), ggml_type_name(src1->type));
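Note on coverage: the same-type F32→F32, F16→F16, and BF16→BF16 branches keep only the strided fallback because contiguous same-type copies are already served by the `cudaMemcpyAsync` branch at the top of the chain, so the new flat kernel only ever runs for type-converting copies. Tensors that fail either the contiguity or the shape-equality test still go through the fully strided `ggml_cpy_flt_cuda`.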
