From 9a53f404e0b71183bdfff4ad53c2927c41d1cbe2 Mon Sep 17 00:00:00 2001
From: Jeemzz
Date: Wed, 30 Jul 2025 17:11:14 +0800
Subject: [PATCH 1/3] draft: set cuda

---
 ggml/src/ggml-cuda/ggml-cuda.cu |   8 +++
 ggml/src/ggml-cuda/set.cu       | 110 ++++++++++++++++++++++++++++++++
 ggml/src/ggml-cuda/set.cuh      |   7 ++
 ggml/src/ggml-cuda/set1.cu      |  52 +++++++++++++++
 4 files changed, 177 insertions(+)
 create mode 100644 ggml/src/ggml-cuda/set.cu
 create mode 100644 ggml/src/ggml-cuda/set.cuh
 create mode 100644 ggml/src/ggml-cuda/set1.cu

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 1f785796014bd..0a132c2a8d205 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -44,6 +44,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml.h"
 
@@ -2233,6 +2234,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GET_ROWS_BACK:
             ggml_cuda_op_get_rows_back(ctx, dst);
             break;
+        case GGML_OP_SET:
+            ggml_cuda_op_set(ctx, dst);
+            break;
         case GGML_OP_SET_ROWS:
             ggml_cuda_op_set_rows(ctx, dst);
             break;
@@ -3275,6 +3279,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             {
                 return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
             } break;
+        case GGML_OP_SET:
+            {
+                return op->type == GGML_TYPE_F32;
+            } break;
         case GGML_OP_SET_ROWS:
             {
                 return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
diff --git a/ggml/src/ggml-cuda/set.cu b/ggml/src/ggml-cuda/set.cu
new file mode 100644
index 0000000000000..9ba9ba90fa2dc
--- /dev/null
+++ b/ggml/src/ggml-cuda/set.cu
@@ -0,0 +1,110 @@
+#include "ggml-cuda/common.cuh"
+#include "set.cuh"
+
+static __global__ void set_f32_cuda_copy(const float * __restrict__ src1,
+                                         float * __restrict__ dst,
+                                         const size_t ne0,
+                                         const size_t ne1,
+                                         const size_t ne2,
+                                         const size_t ne3,
+                                         const int offset, // element‐offset
+                                         const int nb1,    // stride in elements along dim1
+                                         const int nb2,    // stride in elements along dim2
+                                         const int nb3     // stride in elements along dim3
+) {
+    const size_t total = ne0 * ne1 * ne2 * ne3;
+    const size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
+    if (gid >= total) {
+        return;
+    }
+
+    // unravel into 4D indices (i0 fastest, then i1, i2, i3):
+    size_t tmp = gid;
+    const size_t i0 = tmp % ne0;
+    tmp /= ne0;
+    const size_t i1 = tmp % ne1;
+    tmp /= ne1;
+    const size_t i2 = tmp % ne2;
+    tmp /= ne2;
+    const size_t i3 = tmp; // < ne3
+
+    // compute flat positions with strides + offset
+    const size_t pos = offset + i0 + i1 * (size_t) nb1 + i2 * (size_t) nb2 + i3 * (size_t) nb3;
+
+    dst[pos] = src1[pos];
+}
+
+static __global__ void set_f32_cuda(const float * __restrict__ src0,
+                                    float * __restrict__ dst,
+                                    const size_t ne0,
+                                    const size_t ne1,
+                                    const size_t ne2,
+                                    const size_t ne3,
+                                    const int offset, // element‐offset into dst
+                                    const int nb1,    // stride in elements along dim1
+                                    const int nb2,    // stride in elements along dim2
+                                    const int nb3     // stride in elements along dim3
+) {
+    // src0 is contiguous over ne0*ne1*ne2*ne3 elements
+    const size_t total = ne0 * ne1 * ne2 * ne3;
+    const size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
+    if (gid >= total) {
+        return;
+    }
+
+    // unravel gid to 4D (same as copy)
+    size_t tmp = gid;
+    const size_t i0 = tmp % ne0;
+    tmp /= ne0;
+    const size_t i1 = tmp % ne1;
+    tmp /= ne1;
+    const size_t i2 = tmp % ne2;
+    tmp /= ne2;
+    const size_t i3 = tmp;
+
+    // dst position has the same formula:
+    const size_t pos = offset + i0 + i1 * (size_t) nb1 + i2 * (size_t) nb2 + i3 * (size_t) nb3;
+
+    // src0 is contiguous: flat index = gid
+    dst[pos] = src0[gid];
+}
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const int32_t nb1 = dst->op_params[0];
+    const int32_t nb2 = dst->op_params[1];
+    const int32_t nb3 = dst->op_params[2];
+    const int32_t offset = dst->op_params[3];
+    const bool inplace = dst->op_params[4];
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    // dims
+    const size_t ne0 = dst->ne[0];
+    const size_t ne1 = dst->ne[1];
+    const size_t ne2 = dst->ne[2];
+    const size_t ne3 = dst->ne[3];
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float * dst_d = (float *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const size_t total = ne0 * ne1 * ne2 * ne3;
+    const int threads = 256;
+    const int blocks = (total + threads - 1) / threads;
+
+    if (!inplace) {
+        // copy whole src1→dst
+        set_f32_cuda_copy<<<blocks, threads, 0, stream>>>(src1_d, dst_d, ne0, ne1, ne2, ne3, offset, nb1, nb2, nb3);
+    }
+
+    // then overwrite from src0→dst at same offsets/strides
+    set_f32_cuda<<<blocks, threads, 0, stream>>>(src0_d, dst_d, ne0, ne1, ne2, ne3, offset, nb1, nb2, nb3);
+}
diff --git a/ggml/src/ggml-cuda/set.cuh b/ggml/src/ggml-cuda/set.cuh
new file mode 100644
index 0000000000000..dd09529f3e42b
--- /dev/null
+++ b/ggml/src/ggml-cuda/set.cuh
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "common.cuh"
+
+#define CUDA_SET_BLOCK_SIZE 256
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/set1.cu b/ggml/src/ggml-cuda/set1.cu
new file mode 100644
index 0000000000000..9711e6a5e37bf
--- /dev/null
+++ b/ggml/src/ggml-cuda/set1.cu
@@ -0,0 +1,52 @@
+#include "ggml-cuda/common.cuh"
+#include "set.cuh"
+
+static __global__ void set_f32_cuda_copy( ...) {}
+
+
+
+static __global__ void set_f32_cuda( ...) {}
+
+
+
+
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    const int32_t nb1 = dst->op_params[0];
+    const int32_t nb2 = dst->op_params[1];
+    const int32_t nb3 = dst->op_params[2];
+    const int32_t offset = dst->op_params[3];
+    const bool inplace = dst->op_params[4];
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    // TODO: support more dtypes.
+    GGML_ASSERT(src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(src[1]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float * dst_d = (float *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    if (!inplace) {
+        // copy: src1 -> dst.
+        set_f32_cuda_copy
+    }
+
+    // set: src0 -> dst
+    // set_f32_cuda
+
+
+
+
+
+}
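For context on the parameters this launcher unpacks: GGML_OP_SET writes src1 into a strided view of the result (which otherwise equals src0), and ggml stores {nb1, nb2, nb3, offset, inplace} in op_params. These are byte strides and a byte offset taken from tensor->nb[], which is why the follow-up patch moves the kernels from element indexing to byte-based addressing. A minimal host-side sketch of how such a node is built with the public ggml API (the context, shapes and names are illustrative, not taken from the patch):

    // Overwrite row 3 of `a` with the contents of `b` (both F32); `ctx` is an
    // existing ggml_context. ggml_set() records {nb1, nb2, nb3, offset, inplace = false}
    // in the result's op_params, which ggml_cuda_op_set() reads back as op_params[0..4].
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 16);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 1);
    struct ggml_tensor * r = ggml_set(ctx, a, b,
                                      a->nb[1], a->nb[2], a->nb[3], // byte strides of the destination view
                                      3 * a->nb[1]);                // byte offset of row 3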
From e38e8573d6b754d14ec30d69031ae6146a1bba8a Mon Sep 17 00:00:00 2001
From: Jeemzz
Date: Thu, 31 Jul 2025 11:43:37 +0800
Subject: [PATCH 2/3] draft: cuda set op

---
 ggml/src/ggml-cuda/set.cu  | 107 ++++++++++++++++++++-----------------
 ggml/src/ggml-cuda/set1.cu |  52 ------------------
 2 files changed, 58 insertions(+), 101 deletions(-)
 delete mode 100644 ggml/src/ggml-cuda/set1.cu

diff --git a/ggml/src/ggml-cuda/set.cu b/ggml/src/ggml-cuda/set.cu
index 9ba9ba90fa2dc..32d934c9c63b4 100644
--- a/ggml/src/ggml-cuda/set.cu
+++ b/ggml/src/ggml-cuda/set.cu
@@ -1,75 +1,78 @@
 #include "ggml-cuda/common.cuh"
 #include "set.cuh"
 
-static __global__ void set_f32_cuda_copy(const float * __restrict__ src1,
+static __global__ void set_f32_cuda_copy(const float * __restrict__ src0,
                                          float * __restrict__ dst,
                                          const size_t ne0,
                                          const size_t ne1,
                                          const size_t ne2,
                                          const size_t ne3,
-                                         const int offset, // element‐offset
-                                         const int nb1,    // stride in elements along dim1
-                                         const int nb2,    // stride in elements along dim2
-                                         const int nb3     // stride in elements along dim3
-) {
+                                         const size_t nb0,
+                                         const size_t nb1,
+                                         const size_t nb2,
+                                         const size_t nb3) {
     const size_t total = ne0 * ne1 * ne2 * ne3;
     const size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
     if (gid >= total) {
         return;
     }
 
-    // unravel into 4D indices (i0 fastest, then i1, i2, i3):
-    size_t tmp = gid;
-    const size_t i0 = tmp % ne0;
+    size_t tmp = gid;
+
+    const size_t i0 = tmp % ne0;
     tmp /= ne0;
     const size_t i1 = tmp % ne1;
     tmp /= ne1;
     const size_t i2 = tmp % ne2;
     tmp /= ne2;
-    const size_t i3 = tmp; // < ne3
+    const size_t i3 = tmp;
 
-    // compute flat positions with strides + offset
-    const size_t pos = offset + i0 + i1 * (size_t) nb1 + i2 * (size_t) nb2 + i3 * (size_t) nb3;
+    const size_t pos = (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
 
-    dst[pos] = src1[pos];
+    *((float *) ((char *) dst + pos)) = *((const float *) ((const char *) src0 + pos));
 }
 
-static __global__ void set_f32_cuda(const float * __restrict__ src0,
+static __global__ void set_f32_cuda(const float * __restrict__ src1,
                                     float * __restrict__ dst,
-                                    const size_t ne0,
-                                    const size_t ne1,
-                                    const size_t ne2,
-                                    const size_t ne3,
-                                    const int offset, // element‐offset into dst
-                                    const int nb1,    // stride in elements along dim1
-                                    const int nb2,    // stride in elements along dim2
-                                    const int nb3     // stride in elements along dim3
+                                    const size_t ne10,
+                                    const size_t ne11,
+                                    const size_t ne12,
+                                    const size_t ne13,
+                                    const size_t nb10,
+                                    const size_t nb11,
+                                    const size_t nb12,
+                                    const size_t nb13,
+                                    const size_t nb0,
+                                    const size_t nb1,
+                                    const size_t nb2,
+                                    const size_t nb3,
+                                    const size_t offset
+
 ) {
-    // src0 is contiguous over ne0*ne1*ne2*ne3 elements
-    const size_t total = ne0 * ne1 * ne2 * ne3;
+    const size_t total = ne10 * ne11 * ne12 * ne13;
     const size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
     if (gid >= total) {
        return;
     }
 
-    // unravel gid to 4D (same as copy)
-    size_t tmp = gid;
-    const size_t i0 = tmp % ne0;
-    tmp /= ne0;
-    const size_t i1 = tmp % ne1;
-    tmp /= ne1;
-    const size_t i2 = tmp % ne2;
-    tmp /= ne2;
+    size_t tmp = gid;
+
+    const size_t i0 = tmp % ne10;
+    tmp /= ne10;
+    const size_t i1 = tmp % ne11;
+    tmp /= ne11;
+    const size_t i2 = tmp % ne12;
+    tmp /= ne12;
     const size_t i3 = tmp;
 
-    // dst position has the same formula:
-    const size_t pos = offset + i0 + i1 * (size_t) nb1 + i2 * (size_t) nb2 + i3 * (size_t) nb3;
+    size_t dst_offset = offset + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3;
+    size_t src1_offset = i0 * nb10 + i1 * nb11 + i2 * nb12 + i3 * nb13;
 
-    // src0 is contiguous: flat index = gid
-    dst[pos] = src0[gid];
+    *((float *) ((char *) dst + dst_offset)) = *((const float *) ((const char *) src1 + src1_offset));
 }
 
 void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    // nb0 is implicitly element_size because src0 and dst are contiguous
     const int32_t nb1 = dst->op_params[0];
     const int32_t nb2 = dst->op_params[1];
     const int32_t nb3 = dst->op_params[2];
@@ -80,15 +83,14 @@ void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src1 = dst->src[1];
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    // TODO: support more dtypes.
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    // dims
-    const size_t ne0 = dst->ne[0];
-    const size_t ne1 = dst->ne[1];
-    const size_t ne2 = dst->ne[2];
-    const size_t ne3 = dst->ne[3];
+    GGML_TENSOR_BINARY_OP_LOCALS01;
+    const int nb0 = ggml_element_size(dst);
 
     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
@@ -96,15 +98,22 @@ void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     cudaStream_t stream = ctx.stream();
 
-    const size_t total = ne0 * ne1 * ne2 * ne3;
-    const int threads = 256;
-    const int blocks = (total + threads - 1) / threads;
-
     if (!inplace) {
-        // copy whole src1→dst
-        set_f32_cuda_copy<<<blocks, threads, 0, stream>>>(src1_d, dst_d, ne0, ne1, ne2, ne3, offset, nb1, nb2, nb3);
+        // copy whole src0 -> dst.
+        const size_t total = ne00 * ne01 * ne02 * ne03;
+
+        const int num_blocks = (total + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;
+
+        set_f32_cuda_copy<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>(
+            src0_d, dst_d, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03);
     }
 
-    // then overwrite from src0→dst at same offsets/strides
-    set_f32_cuda<<<blocks, threads, 0, stream>>>(src0_d, dst_d, ne0, ne1, ne2, ne3, offset, nb1, nb2, nb3);
+    // set: src1 -> dst
+    // set_f32_cuda
+
+    const size_t total = ne10 * ne11 * ne12 * ne13;
+    const size_t num_blocks = (total + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;
+
+    set_f32_cuda<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>(
+        src1_d, dst_d, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, offset);
 }
diff --git a/ggml/src/ggml-cuda/set1.cu b/ggml/src/ggml-cuda/set1.cu
deleted file mode 100644
index 9711e6a5e37bf..0000000000000
--- a/ggml/src/ggml-cuda/set1.cu
+++ /dev/null
@@ -1,52 +0,0 @@
-#include "ggml-cuda/common.cuh"
-#include "set.cuh"
-
-static __global__ void set_f32_cuda_copy( ...) {}
-
-
-
-static __global__ void set_f32_cuda( ...) {}
-
-
-
-
-
-void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    // nb0 is implicitly element_size because src0 and dst are contiguous
-    const int32_t nb1 = dst->op_params[0];
-    const int32_t nb2 = dst->op_params[1];
-    const int32_t nb3 = dst->op_params[2];
-    const int32_t offset = dst->op_params[3];
-    const bool inplace = dst->op_params[4];
-
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
-    // TODO: support more dtypes.
-    GGML_ASSERT(src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(src[1]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const float * src0_d = (const float *) src0->data;
-    const float * src1_d = (const float *) src1->data;
-    float * dst_d = (float *) dst->data;
-
-    cudaStream_t stream = ctx.stream();
-
-    if (!inplace) {
-        // copy: src1 -> dst.
-        set_f32_cuda_copy
-    }
-
-    // set: src0 -> dst
-    // set_f32_cuda
-
-
-
-
-
-}
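The addressing in the reworked set_f32_cuda is plain byte arithmetic: logical element (i0, i1, i2, i3) of src1 is read through src1's own byte strides nb10..nb13 and written to dst at byte position offset + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3. The same indexing expressed as a host-side helper (a hypothetical standalone function for illustration, not part of the patch):

    // Byte-stride addressing: `data` is treated as raw bytes, nb0..nb3 are byte
    // strides per dimension and `offset` is a byte offset into the tensor.
    static inline float * elem_f32(void * data, size_t offset,
                                   size_t i0, size_t i1, size_t i2, size_t i3,
                                   size_t nb0, size_t nb1, size_t nb2, size_t nb3) {
        return (float *) ((char *) data + offset + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
    }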
From bfdca26b9e523c50d844516f3175748c9cf1a925 Mon Sep 17 00:00:00 2001
From: Jeemzz
Date: Fri, 1 Aug 2025 10:42:45 +0800
Subject: [PATCH 3/3] Replace copy kernel with cudaMemcpyAsync

---
 ggml/src/ggml-cuda/set.cu | 39 +--------------------------------------
 1 file changed, 1 insertion(+), 38 deletions(-)

diff --git a/ggml/src/ggml-cuda/set.cu b/ggml/src/ggml-cuda/set.cu
index 32d934c9c63b4..fdd343e26808a 100644
--- a/ggml/src/ggml-cuda/set.cu
+++ b/ggml/src/ggml-cuda/set.cu
@@ -1,37 +1,6 @@
 #include "ggml-cuda/common.cuh"
 #include "set.cuh"
 
-static __global__ void set_f32_cuda_copy(const float * __restrict__ src0,
-                                         float * __restrict__ dst,
-                                         const size_t ne0,
-                                         const size_t ne1,
-                                         const size_t ne2,
-                                         const size_t ne3,
-                                         const size_t nb0,
-                                         const size_t nb1,
-                                         const size_t nb2,
-                                         const size_t nb3) {
-    const size_t total = ne0 * ne1 * ne2 * ne3;
-    const size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
-    if (gid >= total) {
-        return;
-    }
-
-    size_t tmp = gid;
-
-    const size_t i0 = tmp % ne0;
-    tmp /= ne0;
-    const size_t i1 = tmp % ne1;
-    tmp /= ne1;
-    const size_t i2 = tmp % ne2;
-    tmp /= ne2;
-    const size_t i3 = tmp;
-
-    const size_t pos = (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
-
-    *((float *) ((char *) dst + pos)) = *((const float *) ((const char *) src0 + pos));
-}
-
 static __global__ void set_f32_cuda(const float * __restrict__ src1,
                                     float * __restrict__ dst,
                                     const size_t ne10,
@@ -100,16 +69,10 @@ void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     if (!inplace) {
         // copy whole src0 -> dst.
-        const size_t total = ne00 * ne01 * ne02 * ne03;
-
-        const int num_blocks = (total + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;
-
-        set_f32_cuda_copy<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>(
-            src0_d, dst_d, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03);
+        CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, stream));
     }
 
     // set: src1 -> dst
-    // set_f32_cuda
 
     const size_t total = ne10 * ne11 * ne12 * ne13;
     const size_t num_blocks = (total + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;
 
     set_f32_cuda<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>(
         src1_d, dst_d, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, offset);
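Replacing the copy kernel with a single cudaMemcpyAsync is equivalent here because src0 and dst have the same shape and type and are contiguous, so a flat copy of ggml_nbytes(dst) bytes covers exactly the elements the removed element-wise kernel copied. A sketch of how that assumption could be stated explicitly with existing ggml helpers (whether to add such asserts is left open; they are not part of the patch):

    // Assumptions behind the flat device-to-device copy:
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(dst)); // same shape and type
    CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, stream));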