From 4c8098c3cf134d1a1958d4944310e470e794db6f Mon Sep 17 00:00:00 2001
From: YaelGitAccount
Date: Sat, 25 Oct 2025 20:22:52 +0300
Subject: [PATCH 1/5] feat(cuda): add GGML_OP_SET support

Implement CUDA kernel for SET operation with f32 support.
All tests passing (14598/14598).
---
 ggml/src/ggml-cuda/ggml-cuda.cu |  8 +++++++
 ggml/src/ggml-cuda/set.cu       | 39 +++++++++++++++++++++++++++++++++
 ggml/src/ggml-cuda/set.cuh      |  7 ++++++
 3 files changed, 54 insertions(+)
 create mode 100644 ggml/src/ggml-cuda/set.cu
 create mode 100644 ggml/src/ggml-cuda/set.cuh

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 5a9e54721e463..2cec6f24f5827 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -50,6 +50,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
 #include "ggml.h"
@@ -2259,6 +2260,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SET_ROWS:
             ggml_cuda_op_set_rows(ctx, dst);
             break;
+        case GGML_OP_SET:
+            ggml_cuda_op_set(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cuda_dup(ctx, dst);
             break;
@@ -3484,6 +3488,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                    op->src[0]->type == GGML_TYPE_F32 &&
                    (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
         } break;
+        case GGML_OP_SET:
+            {
+                return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
diff --git a/ggml/src/ggml-cuda/set.cu b/ggml/src/ggml-cuda/set.cu
new file mode 100644
index 0000000000000..8e42ff77c7f56
--- /dev/null
+++ b/ggml/src/ggml-cuda/set.cu
@@ -0,0 +1,39 @@
+#include "set.cuh"
+__global__ static void set_f32(const float * x, float * dst, const int ne,
+                               const size_t offset) {
+    const int i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= ne) return;
+    dst[offset + i] = x[i];
+}
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    // get the source tensors
+    const ggml_tensor * src0 = dst->src[0]; // the base tensor
+    const ggml_tensor * src1 = dst->src[1]; // the new values
+
+    // type checks
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+
+    const size_t nb1    = ((const int32_t *) dst->op_params)[0];
+    const size_t nb2    = ((const int32_t *) dst->op_params)[1];
+    const size_t nb3    = ((const int32_t *) dst->op_params)[2];
+    const size_t offset = ((const int32_t *) dst->op_params)[3];
+
+
+    // get the device pointers
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+
+    // copy src0 to dst
+    cudaMemcpyAsync(dst_d, src0_d, ggml_nbytes(dst),
+                    cudaMemcpyDeviceToDevice, ctx.stream());
+
+    // launch the kernel to write the new values
+    const int ne = ggml_nelements(src1);
+    const int num_blocks = (ne + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;
+    set_f32<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, ctx.stream()>>>(
+        src1_d, dst_d, ne, offset / sizeof(float));
+}
\ No newline at end of file
diff --git a/ggml/src/ggml-cuda/set.cuh b/ggml/src/ggml-cuda/set.cuh
new file mode 100644
index 0000000000000..40756bb9d541b
--- /dev/null
+++ b/ggml/src/ggml-cuda/set.cuh
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "common.cuh"
+
+#define CUDA_SET_BLOCK_SIZE 256
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
\ No newline at end of file
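For reference, GGML_OP_SET copies the base tensor src0 into dst and then overwrites a sub-region with src1, where op_params carries the destination byte strides nb1/nb2/nb3 and the starting byte offset. The first version of the kernel above only covers the case where the written region is a single contiguous run of floats starting at offset. A minimal CPU sketch of the general semantics, with a hypothetical helper name and f32 only, would look like this:

    // Hypothetical CPU reference for the SET semantics: dst = src0, then src1 is
    // written into dst at byte strides nb1/nb2/nb3 starting at byte offset.
    #include <cstdint>
    #include <cstring>

    static void set_f32_reference(const float * src0, const float * src1, float * dst,
                                  const int64_t ne1[4],               // dims of src1
                                  size_t nb1, size_t nb2, size_t nb3, size_t offset,
                                  size_t n_dst_elements) {
        memcpy(dst, src0, n_dst_elements * sizeof(float));            // base copy
        for (int64_t i3 = 0; i3 < ne1[3]; ++i3) {
            for (int64_t i2 = 0; i2 < ne1[2]; ++i2) {
                for (int64_t i1 = 0; i1 < ne1[1]; ++i1) {
                    for (int64_t i0 = 0; i0 < ne1[0]; ++i0) {
                        // byte address of the destination element inside the written window
                        char * d = (char *) dst + offset + i1*nb1 + i2*nb2 + i3*nb3 + i0*sizeof(float);
                        const float * s = src1 + ((i3*ne1[2] + i2)*ne1[1] + i1)*ne1[0] + i0;
                        *(float *) d = *s;
                    }
                }
            }
        }
    }

The second patch below generalizes the CUDA kernel toward these strided semantics and adds I32 support.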
From 955c06feeb5f4c8045faac5e88ec7eade55ebe58 Mon Sep 17 00:00:00 2001
From: YaelGitAccount
Date: Sat, 25 Oct 2025 19:56:25 +0300
Subject: [PATCH 2/5] cuda(set): add I32 support; keep F32

---
 ggml/src/ggml-cuda/ggml-cuda.cu |   7 +-
 ggml/src/ggml-cuda/set.cu       | 127 +++++++++++++++++++++++---------
 ggml/src/ggml-cuda/set.cuh      |   2 +-
 3 files changed, 98 insertions(+), 38 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 2cec6f24f5827..9d67b14a90e2f 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3488,9 +3488,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                    op->src[0]->type == GGML_TYPE_F32 &&
                    (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
         } break;
-        case GGML_OP_SET:
+        case GGML_OP_SET: 
             {
-                return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
+                const ggml_type t = op->type;
+                return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
+                       t == op->src[0]->type &&
+                       t == op->src[1]->type;
             } break;
         case GGML_OP_CPY:
             {
diff --git a/ggml/src/ggml-cuda/set.cu b/ggml/src/ggml-cuda/set.cu
index 8e42ff77c7f56..7cc9c0acbeed6 100644
--- a/ggml/src/ggml-cuda/set.cu
+++ b/ggml/src/ggml-cuda/set.cu
@@ -1,39 +1,96 @@
 #include "set.cuh"
-__global__ static void set_f32(const float * x, float * dst, const int ne,
-                               const size_t offset) {
-    const int i = blockDim.x * blockIdx.x + threadIdx.x;
-    if (i >= ne) return;
-    dst[offset + i] = x[i];
+
+template <typename T>
+static __global__ void k_set(const T * x, const T * y, T * dst, const int64_t ne,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
+
+    const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    int64_t src1_idx = i - offset;
+
+    int64_t tmp = src1_idx;
+    const int64_t i13 = tmp / s13;
+    tmp -= i13 * s13;
+    const int64_t i12 = tmp / s12;
+    tmp -= i12 * s12;
+    const int64_t i11 = tmp / s11;
+    tmp -= i11 * s11;
+    const int64_t i10 = tmp;
+
+    T val = x[i];
+    if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
+        val = y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
+    }
+    dst[i] = val;
+}
+
+template <typename T>
+static void set_cuda_impl(const T * x, const T * y, T * dst, const int64_t n_elements,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) {

+    const int num_blocks = (n_elements + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;
+
+    k_set<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements,
+            ne10, ne11, ne12, ne13,
+            s1, s2, s3, offset);
+
 }
 
 void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    // get the source tensors
-    const ggml_tensor * src0 = dst->src[0]; // the base tensor
-    const ggml_tensor * src1 = dst->src[1]; // the new values
-
-    // type checks
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-
-    const size_t nb1    = ((const int32_t *) dst->op_params)[0];
-    const size_t nb2    = ((const int32_t *) dst->op_params)[1];
-    const size_t nb3    = ((const int32_t *) dst->op_params)[2];
-    const size_t offset = ((const int32_t *) dst->op_params)[3];
-
-
-    // get the device pointers
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
-
-    // copy src0 to dst
-    cudaMemcpyAsync(dst_d, src0_d, ggml_nbytes(dst),
-                    cudaMemcpyDeviceToDevice, ctx.stream());
-
-    // launch the kernel to write the new values
-    const int ne = ggml_nelements(src1);
-    const int num_blocks = (ne + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;
-    set_f32<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, ctx.stream()>>>(
-        src1_d, dst_d, ne, offset / sizeof(float));
-}
\ No newline at end of file
+
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float * dst_d = (float *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
+    GGML_ASSERT(src1->type == src0->type);
+    GGML_ASSERT( dst->type == src0->type);
+
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
+    GGML_ASSERT(ggml_is_contiguously_allocated(dst));
+
+    const int64_t s1     = dst->op_params[0] / ggml_element_size(dst);
+    const int64_t s2     = dst->op_params[1] / ggml_element_size(dst);
+    const int64_t s3     = dst->op_params[2] / ggml_element_size(dst);
+    const int64_t offset = dst->op_params[3] / ggml_element_size(dst);
+    const bool inplace   = (bool) dst->op_params[4];
+
+
+    // If not inplace, copy src0 to dst first
+    if (!inplace) {
+
+        CUDA_CHECK(cudaMemcpyAsync(dst->data, src0->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, stream));
+    }
+
+    const int64_t n = ggml_nelements(dst);
+
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            set_cuda_impl(src0_d, src1_d, dst_d, ggml_nelements(dst),
+                    src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+                    s1, s2, s3, offset, stream);
+            break;
+        case GGML_TYPE_I32:
+            set_cuda_impl((const int32_t*)src0_d, (const int32_t*)src1_d, (int32_t*)dst_d, ggml_nelements(dst),
+                    src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+                    s1, s2, s3, offset, stream);
+            break;
+        default:
+            GGML_ABORT("ggml_cuda_op_set: unsupported src0 type %s", ggml_type_name(src0->type));
+            break;
+    }
+}
diff --git a/ggml/src/ggml-cuda/set.cuh b/ggml/src/ggml-cuda/set.cuh
index 40756bb9d541b..dd09529f3e42b 100644
--- a/ggml/src/ggml-cuda/set.cuh
+++ b/ggml/src/ggml-cuda/set.cuh
@@ -4,4 +4,4 @@
 
 #define CUDA_SET_BLOCK_SIZE 256
 
-void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
\ No newline at end of file
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
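The index math in k_set inverts the destination layout: each destination element i is tested against the window written by src1 by subtracting the element offset and decomposing the remainder with the destination element strides s11/s12/s13 (the op_params byte strides divided by the element size). A small host-side sketch of that same decomposition, using hypothetical helper names, shows how a flat dst index maps back to src1 coordinates:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical host-side mirror of the per-thread decomposition done in k_set.
    // Returns true if flat destination index i falls inside the src1 window.
    static bool decompose_dst_index(int64_t i, int64_t offset,
                                    int64_t s11, int64_t s12, int64_t s13,
                                    int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13,
                                    int64_t idx[4]) {
        const int64_t src1_idx = i - offset;
        if (src1_idx < 0) {
            return false;
        }
        int64_t tmp = src1_idx;
        idx[3] = tmp / s13; tmp -= idx[3] * s13;   // i13
        idx[2] = tmp / s12; tmp -= idx[2] * s12;   // i12
        idx[1] = tmp / s11; tmp -= idx[1] * s11;   // i11
        idx[0] = tmp;                              // i10
        return idx[0] < ne10 && idx[1] < ne11 && idx[2] < ne12 && idx[3] < ne13;
    }

    int main() {
        // Example: a 2x3 f32 src1 written into the top-left corner of a 4-row, 5-column dst,
        // so the dst row stride is 5 elements (s11 = 5) and the plane stride is 20.
        int64_t idx[4];
        const bool inside = decompose_dst_index(/*i=*/7, /*offset=*/0,
                                                /*s11=*/5, /*s12=*/20, /*s13=*/20,
                                                /*ne10=*/3, /*ne11=*/2, /*ne12=*/1, /*ne13=*/1, idx);
        // dst element 7 is row 1, column 2, which lies inside the 2x3 window.
        printf("inside=%d i10=%lld i11=%lld\n", inside, (long long) idx[0], (long long) idx[1]);
        return 0;
    }

The next patch removes this custom kernel again and reuses the existing copy machinery instead.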
From 59fb14c708319f959faebfb1a83a8b115ad0882b Mon Sep 17 00:00:00 2001
From: YaelGitAccount
Date: Mon, 27 Oct 2025 15:43:53 +0200
Subject: [PATCH 3/5] refactor(cuda): use ggml_cuda_cpy to unify SET operator logic and remove code duplication

---
 ggml/src/ggml-cuda/set.cu | 105 +++++++++-----------------------------
 1 file changed, 24 insertions(+), 81 deletions(-)

diff --git a/ggml/src/ggml-cuda/set.cu b/ggml/src/ggml-cuda/set.cu
index 7cc9c0acbeed6..03a36a8883079 100644
--- a/ggml/src/ggml-cuda/set.cu
+++ b/ggml/src/ggml-cuda/set.cu
@@ -1,96 +1,39 @@
 #include "set.cuh"
-
-template <typename T>
-static __global__ void k_set(const T * x, const T * y, T * dst, const int64_t ne,
-        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
-        const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
-
-    const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
-
-    if (i >= ne) {
-        return;
-    }
-
-    int64_t src1_idx = i - offset;
-
-    int64_t tmp = src1_idx;
-    const int64_t i13 = tmp / s13;
-    tmp -= i13 * s13;
-    const int64_t i12 = tmp / s12;
-    tmp -= i12 * s12;
-    const int64_t i11 = tmp / s11;
-    tmp -= i11 * s11;
-    const int64_t i10 = tmp;
-
-    T val = x[i];
-    if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
-        val = y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
-    }
-    dst[i] = val;
-}
-
-template <typename T>
-static void set_cuda_impl(const T * x, const T * y, T * dst, const int64_t n_elements,
-        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
-        const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) {
-
-    const int num_blocks = (n_elements + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;
-
-    k_set<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements,
-            ne10, ne11, ne12, ne13,
-            s1, s2, s3, offset);
-
-}
+#include "cpy.cuh"
 
 void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-
-
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    const float * src0_d = (const float *) src0->data;
-    const float * src1_d = (const float *) src1->data;
-    float * dst_d = (float *) dst->data;
-
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
+    GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
     GGML_ASSERT(src1->type == src0->type);
-    GGML_ASSERT( dst->type == src0->type);
+    GGML_ASSERT(dst ->type == src0->type);
 
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
-    GGML_ASSERT(ggml_is_contiguously_allocated(dst));
 
-    const int64_t s1 = dst->op_params[0] / ggml_element_size(dst);
-    const int64_t s2 = dst->op_params[1] / ggml_element_size(dst);
-    const int64_t s3 = dst->op_params[2] / ggml_element_size(dst);
-    const int64_t offset = dst->op_params[3] / ggml_element_size(dst);
-    const bool inplace = (bool) dst->op_params[4];
+    const size_t nb1     = ((int32_t *) dst->op_params)[0];
+    const size_t nb2     = ((int32_t *) dst->op_params)[1];
+    const size_t nb3     = ((int32_t *) dst->op_params)[2];
+    const size_t offset  = ((int32_t *) dst->op_params)[3];
+    const bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 
-
-    // If not inplace, copy src0 to dst first
     if (!inplace) {
-
-        CUDA_CHECK(cudaMemcpyAsync(dst->data, src0->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, stream));
+        ggml_cuda_cpy(ctx, src0, dst);
     }
 
-    const int64_t n = ggml_nelements(dst);
-
+    ggml_tensor dst_view = *dst;
+    dst_view.data  = (void *)((char *)dst->data + offset);
+    dst_view.ne[0] = src1->ne[0];
+    dst_view.ne[1] = src1->ne[1];
+    dst_view.ne[2] = src1->ne[2];
+    dst_view.ne[3] = src1->ne[3];
 
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            set_cuda_impl(src0_d, src1_d, dst_d, ggml_nelements(dst),
-                    src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
-                    s1, s2, s3, offset, stream);
-            break;
-        case GGML_TYPE_I32:
-            set_cuda_impl((const int32_t*)src0_d, (const int32_t*)src1_d, (int32_t*)dst_d, ggml_nelements(dst),
-                    src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
-                    s1, s2, s3, offset, stream);
-            break;
-        default:
-            GGML_ABORT("ggml_cuda_op_set: unsupported src0 type %s", ggml_type_name(src0->type));
-            break;
-    }
-}
+    dst_view.nb[0] = ggml_element_size(dst);
+    dst_view.nb[1] = nb1;
+    dst_view.nb[2] = nb2;
+    dst_view.nb[3] = nb3;
+
+    ggml_cuda_cpy(ctx, src1, &dst_view, true);
+}
\ No newline at end of file
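After this refactor, the operator no longer owns a kernel: src0 is copied into dst with ggml_cuda_cpy, and src1 is copied into a non-contiguous view of dst whose ne/nb come from op_params. On the graph side, these parameters are produced by ggml_set. A minimal construction-only sketch, assuming the standard ggml.h API (ggml_set, ggml_new_tensor_2d) and illustrative sizes, shows what nb1/nb2/nb3/offset mean:

    #include "ggml.h"

    // Minimal sketch (graph construction only): write a 2x3 block of src1 into dst
    // starting at the second row. nb1 is the destination row stride in bytes and
    // offset is the byte offset of the first overwritten element.
    int main() {
        struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * base = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 5, 4); // 5 columns, 4 rows
        struct ggml_tensor * vals = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2); // 3 columns, 2 rows

        struct ggml_tensor * out = ggml_set(ctx, base, vals,
                                            /*nb1   =*/ 5*sizeof(float),     // one dst row
                                            /*nb2   =*/ 5*4*sizeof(float),   // one dst plane
                                            /*nb3   =*/ 5*4*sizeof(float),
                                            /*offset=*/ 5*sizeof(float));    // start at row 1, column 0

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        // Running this graph on the CUDA backend would dispatch to ggml_cuda_op_set above.

        ggml_free(ctx);
        return 0;
    }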
From d25b5bcaa929d350474584ed121e8c330d14bfc1 Mon Sep 17 00:00:00 2001
From: YaelGitAccount <38328157276@mby.co.il>
Date: Mon, 27 Oct 2025 23:04:30 +0200
Subject: [PATCH 4/5] Update ggml/src/ggml-cuda/ggml-cuda.cu
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sigbjørn Skjæret
---
 ggml/src/ggml-cuda/ggml-cuda.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 9d67b14a90e2f..ba9be08c0a8b2 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3488,7 +3488,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                    op->src[0]->type == GGML_TYPE_F32 &&
                    (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
         } break;
-        case GGML_OP_SET: 
+        case GGML_OP_SET:
            {
                const ggml_type t = op->type;
From 15f7dc7a74267c4e404d6fb279d252f5cdacffd4 Mon Sep 17 00:00:00 2001
From: YaelGitAccount <38328157276@mby.co.il>
Date: Mon, 27 Oct 2025 23:04:49 +0200
Subject: [PATCH 5/5] Update ggml/src/ggml-cuda/set.cu
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sigbjørn Skjæret
---
 ggml/src/ggml-cuda/set.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda/set.cu b/ggml/src/ggml-cuda/set.cu
index 03a36a8883079..04bfe07ba0336 100644
--- a/ggml/src/ggml-cuda/set.cu
+++ b/ggml/src/ggml-cuda/set.cu
@@ -35,5 +35,5 @@ void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     dst_view.nb[2] = nb2;
     dst_view.nb[3] = nb3;
 
-    ggml_cuda_cpy(ctx, src1, &dst_view, true);
-}
\ No newline at end of file
+    ggml_cuda_cpy(ctx, src1, &dst_view);
+}
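With the final version, correctness rests on dst_view's ne/nb describing exactly the addresses the earlier kernels computed: the view reuses dst's storage shifted by offset, takes src1's dimensions, and takes its strides from op_params. A short sketch of the resulting address of each src1 element (i0, i1, i2, i3), using a hypothetical helper that mirrors how ggml addresses tensor elements, makes the equivalence explicit:

    #include <cstddef>
    #include <cstdint>

    // Hypothetical mirror of the addressing implied by dst_view in set.cu:
    // dst_view.data = dst.data + offset, dst_view.nb = { elem_size, nb1, nb2, nb3 }.
    static inline void * dst_view_addr(void * dst_data, size_t offset, size_t elem_size,
                                       size_t nb1, size_t nb2, size_t nb3,
                                       int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
        return (char *) dst_data + offset      // start of the written window
             + i0*elem_size                    // contiguous along dim 0
             + i1*nb1 + i2*nb2 + i3*nb3;       // strides taken from op_params
    }

Since src1 is asserted to be contiguous, copying it into this view with ggml_cuda_cpy touches the same destination addresses as the dedicated set_f32/k_set kernels from the earlier patches, which is what allowed the custom kernel to be removed.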