Skip to content

Commit 851553e

Browse files
YaelGitAccountCISC and co-authors authored
cuda: add SET operation support (ggml-org#16804)
* feat(cuda): add GGML_OP_SET support — implement a CUDA kernel for the SET operation with F32 support; all tests passing (14598/14598).
* cuda(set): add I32 support; keep F32.
* refactor(cuda): use ggml_cuda_cpy to unify the SET operator logic and remove code duplication.
* Update ggml/src/ggml-cuda/ggml-cuda.cu (co-authored by Sigbjørn Skjæret <[email protected]>).
* Update ggml/src/ggml-cuda/set.cu (co-authored by Sigbjørn Skjæret <[email protected]>).

Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 85a7d86 commit 851553e

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
#include "ggml-cuda/upscale.cuh"
5151
#include "ggml-cuda/wkv.cuh"
5252
#include "ggml-cuda/gla.cuh"
53+
#include "ggml-cuda/set.cuh"
5354
#include "ggml-cuda/set-rows.cuh"
5455
#include "ggml-cuda/pad_reflect_1d.cuh"
5556
#include "ggml.h"
@@ -2416,6 +2417,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
24162417
case GGML_OP_SET_ROWS:
24172418
ggml_cuda_op_set_rows(ctx, dst);
24182419
break;
2420+
case GGML_OP_SET:
2421+
ggml_cuda_op_set(ctx, dst);
2422+
break;
24192423
case GGML_OP_DUP:
24202424
ggml_cuda_dup(ctx, dst);
24212425
break;
@@ -3842,6 +3846,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
38423846
op->src[0]->type == GGML_TYPE_F32 &&
38433847
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
38443848
} break;
3849+
case GGML_OP_SET:
3850+
{
3851+
const ggml_type t = op->type;
3852+
return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
3853+
t == op->src[0]->type &&
3854+
t == op->src[1]->type;
3855+
} break;
38453856
case GGML_OP_CPY:
38463857
{
38473858
ggml_type src0_type = op->src[0]->type;

ggml/src/ggml-cuda/set.cu

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "set.cuh"
2+
#include "cpy.cuh"
3+
4+
// Computes GGML_OP_SET on the CUDA backend: writes src1 into a strided view
// of dst located at a byte offset taken from dst->op_params. When the op is
// not in-place, dst is first seeded with a copy of src0. All element movement
// is delegated to ggml_cuda_cpy, so no dedicated SET kernel is needed.
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    // Supported types: F32 and I32; src0, src1 and dst must all agree.
    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
    GGML_ASSERT(src1->type == src0->type);
    GGML_ASSERT(dst ->type == src0->type);

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    // op_params layout: [0..2] = nb1/nb2/nb3 strides of the destination view,
    // [3] = byte offset into dst, [4] = in-place flag.
    const int32_t * params = (const int32_t *) dst->op_params;

    const size_t nb1     = params[0];
    const size_t nb2     = params[1];
    const size_t nb3     = params[2];
    const size_t offset  = params[3];
    const bool   inplace = params[4] != 0;

    if (!inplace) {
        // Non-inplace SET: dst must start out as a copy of src0.
        ggml_cuda_cpy(ctx, src0, dst);
    }

    // Describe the destination region as a tensor view: src1's shape, the
    // strides from op_params, and data shifted by the requested offset.
    ggml_tensor dst_view = *dst;

    dst_view.data = (void *) ((char *) dst->data + offset);

    for (int i = 0; i < 4; ++i) {
        dst_view.ne[i] = src1->ne[i];
    }

    dst_view.nb[0] = ggml_element_size(dst);
    dst_view.nb[1] = nb1;
    dst_view.nb[2] = nb2;
    dst_view.nb[3] = nb3;

    // Scatter src1 into the view.
    ggml_cuda_cpy(ctx, src1, &dst_view);
}

ggml/src/ggml-cuda/set.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
#include "common.cuh"
4+
5+
#define CUDA_SET_BLOCK_SIZE 256
6+
7+
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

0 commit comments

Comments (0)