Skip to content

Commit 96ac5a2

Browse files
authored
cuda : support non-contiguous i32 to i32 copy (ggml-org#17326)
* support non-contiguous i32 to i32 copy
* add tests
* rename cpy_flt to cpy_scalar and reindent params
1 parent bc809e9 commit 96ac5a2

File tree

4 files changed

+92
-49
lines changed

4 files changed

+92
-49
lines changed

ggml/src/ggml-cuda/cpy-utils.cuh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,6 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
212212
}
213213

214214
template<typename src_t, typename dst_t>
215-
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
215+
static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
216216
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
217217
}

ggml/src/ggml-cuda/cpy.cu

Lines changed: 85 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,10 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
1212
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
1313

1414
template <cpy_kernel_t cpy_1>
15-
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
16-
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
17-
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
18-
const int nb12, const int nb13) {
15+
static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
16+
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
17+
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
18+
const int nb12, const int nb13) {
1919
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
2020

2121
if (i >= ne) {
@@ -40,7 +40,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
4040
}
4141

4242
template <typename T>
43-
static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
43+
static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
4444
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
4545
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
4646
const int nb12, const int nb13) {
@@ -166,7 +166,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
166166
}
167167

168168
template<typename src_t, typename dst_t>
169-
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
169+
static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
170170
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
171171

172172
if (i >= ne) {
@@ -180,17 +180,17 @@ static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const in
180180
}
181181

182182
template<typename src_t, typename dst_t>
183-
static void ggml_cpy_flt_contiguous_cuda(
183+
static void ggml_cpy_scalar_contiguous_cuda(
184184
const char * cx, char * cdst, const int64_t ne,
185185
cudaStream_t stream) {
186186

187187
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
188-
cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
188+
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
189189
(cx, cdst, ne);
190190
}
191191

192192
template<typename src_t, typename dst_t, bool transposed = false>
193-
static void ggml_cpy_flt_cuda(
193+
static void ggml_cpy_scalar_cuda(
194194
const char * cx, char * cdst, const int ne,
195195
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
196196
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
@@ -212,11 +212,11 @@ static void ggml_cpy_flt_cuda(
212212
(ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
213213
(ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
214214
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
215-
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
215+
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
216216
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
217217
} else {
218218
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
219-
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
219+
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
220220
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
221221
}
222222
}
@@ -399,94 +399,132 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
399399
}
400400
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
401401
if (can_be_transposed) {
402-
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
402+
ggml_cpy_scalar_cuda<float, float, true>
403+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
403404
} else {
404-
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
405+
ggml_cpy_scalar_cuda<float, float>
406+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
405407
}
406408
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
407409
if (contiguous_srcs) {
408-
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
410+
ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16>
411+
(src0_ddc, src1_ddc, ne, main_stream);
409412
} else {
410-
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
413+
ggml_cpy_scalar_cuda<float, nv_bfloat16>
414+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
411415
}
412416
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
413417
if (contiguous_srcs) {
414-
ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
418+
ggml_cpy_scalar_contiguous_cuda<float, half>
419+
(src0_ddc, src1_ddc, ne, main_stream);
415420
} else {
416-
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
421+
ggml_cpy_scalar_cuda<float, half>
422+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
417423
}
418424
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
419-
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
425+
ggml_cpy_f32_q8_0_cuda
426+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
420427
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
421-
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
428+
ggml_cpy_q8_0_f32_cuda
429+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
422430
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
423-
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
431+
ggml_cpy_f32_q4_0_cuda
432+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
424433
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
425-
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
426-
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
434+
ggml_cpy_q4_0_f32_cuda
435+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
427436
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
428-
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
437+
ggml_cpy_f32_q4_1_cuda
438+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
429439
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
430-
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
431-
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
440+
ggml_cpy_q4_1_f32_cuda
441+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
432442
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
433-
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
443+
ggml_cpy_f32_q5_0_cuda
444+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
434445
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
435-
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
436-
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
446+
ggml_cpy_q5_0_f32_cuda
447+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
437448
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
438-
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
449+
ggml_cpy_f32_iq4_nl_cuda
450+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
439451
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
440-
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
452+
ggml_cpy_f32_q5_1_cuda
453+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
441454
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
442-
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
455+
ggml_cpy_q5_1_f32_cuda
456+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
443457
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
444458
if (can_be_transposed) {
445-
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
459+
ggml_cpy_scalar_cuda<half, half, true>
460+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
446461
} else {
447-
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
462+
ggml_cpy_scalar_cuda<half, half>
463+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
448464
}
449465
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
450466
if (contiguous_srcs) {
451-
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
467+
ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16>
468+
(src0_ddc, src1_ddc, ne, main_stream);
452469
} else {
453-
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
470+
ggml_cpy_scalar_cuda<half, nv_bfloat16>
471+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
454472
}
455473
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
456474
if (contiguous_srcs) {
457-
ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
475+
ggml_cpy_scalar_contiguous_cuda<half, float>
476+
(src0_ddc, src1_ddc, ne, main_stream);
458477
} else {
459-
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
478+
ggml_cpy_scalar_cuda<half, float>
479+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
460480
}
461481
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
462482
if (can_be_transposed) {
463-
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
483+
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
484+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
464485
} else {
465-
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
486+
ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16>
487+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
466488
}
467489
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
468490
if (contiguous_srcs) {
469-
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
491+
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half>
492+
(src0_ddc, src1_ddc, ne, main_stream);
470493
} else {
471-
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
494+
ggml_cpy_scalar_cuda<nv_bfloat16, half>
495+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
472496
}
473497
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
474498
if (contiguous_srcs) {
475-
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
499+
ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float>
500+
(src0_ddc, src1_ddc, ne, main_stream);
501+
} else {
502+
ggml_cpy_scalar_cuda<nv_bfloat16, float>
503+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
504+
}
505+
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
506+
if (can_be_transposed) {
507+
ggml_cpy_scalar_cuda<int32_t, int32_t, true>
508+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
476509
} else {
477-
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
510+
ggml_cpy_scalar_cuda<int32_t, int32_t>
511+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
478512
}
479513
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
480514
if (contiguous_srcs) {
481-
ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
515+
ggml_cpy_scalar_contiguous_cuda<float, int32_t>
516+
(src0_ddc, src1_ddc, ne, main_stream);
482517
} else {
483-
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
518+
ggml_cpy_scalar_cuda<float, int32_t>
519+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
484520
}
485521
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
486522
if (contiguous_srcs) {
487-
ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
523+
ggml_cpy_scalar_contiguous_cuda<int32_t, float>
524+
(src0_ddc, src1_ddc, ne, main_stream);
488525
} else {
489-
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
526+
ggml_cpy_scalar_cuda<int32_t, float>
527+
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
490528
}
491529
} else {
492530
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4115,6 +4115,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
41154115
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
41164116
return true;
41174117
}
4118+
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) {
4119+
return true;
4120+
}
41184121
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
41194122
return true;
41204123
}

tests/test-backend-ops.cpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6953,9 +6953,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
69536953
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
69546954
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
69556955
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6956+
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
6957+
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
69566958
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
69576959

6958-
for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
6960+
for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
69596961
for (bool use_view_slice : { true, false }) {
69606962
for (std::array<int64_t, 4> ne : std::initializer_list<std::array<int64_t, 4>>{ {2, 1, 1, 1}, {2, 1, 3, 5},
69616963
{2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) {

0 commit comments

Comments (0)