Skip to content

Commit 442076f

Browse files
authored
Allow ggml_cuda_cpy to copy contiguous F32 tensors > INT_MAX
1 parent 2f68ce7 commit 442076f

File tree

1 file changed

+9 -4 lines changed

1 file changed

+9 -4 lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -279,9 +279,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
279279
const int64_t ne = ggml_nelements(src0);
280280
GGML_ASSERT(ne == ggml_nelements(src1));
281281

282-
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
283-
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
284-
285282
const int64_t ne00 = src0->ne[0];
286283
const int64_t ne01 = src0->ne[1];
287284
const int64_t ne02 = src0->ne[2];
@@ -321,7 +318,15 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
321318
{
322319
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
323320
}
324-
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
321+
return;
322+
}
323+
324+
// cudaMemcpyAsync takes a size_t count, so the INT_MAX asserts are only needed after the memcpy path above
325+
// see: https://github.com/ggml-org/llama.cpp/issues/15049
326+
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
327+
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
328+
329+
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
325330
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
326331
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
327332
if (contiguous_srcs) {

0 commit comments

Comments (0)