8 changes: 8 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -44,6 +44,7 @@
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/set.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml.h"

@@ -2233,6 +2234,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
case GGML_OP_GET_ROWS_BACK:
ggml_cuda_op_get_rows_back(ctx, dst);
break;
case GGML_OP_SET:
ggml_cuda_op_set(ctx, dst);
break;
case GGML_OP_SET_ROWS:
ggml_cuda_op_set_rows(ctx, dst);
break;
@@ -3275,6 +3279,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
{
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
} break;
case GGML_OP_SET:
{
return op->type == GGML_TYPE_F32;
} break;
case GGML_OP_SET_ROWS:
{
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
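For reference, the supports_op and compute_forward entries above are exercised whenever a graph contains a GGML_OP_SET node. Below is a minimal sketch of how such a node is built with the public ggml API; the context size, tensor shapes, and row offset are arbitrary choices for illustration and are not part of this change.

#include "ggml.h"

int main(void) {
    // Small scratch context; 16 MiB is an arbitrary size for this sketch.
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 8); // destination
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 2); // values to write

    // Write b into a starting at row 3: nb1/nb2/nb3 are the byte strides of the
    // destination view and the last argument is its byte offset.
    struct ggml_tensor * c = ggml_set(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 3*a->nb[1]);
    (void) c; // c->op == GGML_OP_SET; with this PR the CUDA backend accepts it for F32

    ggml_free(ctx);
    return 0;
}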
119 changes: 119 additions & 0 deletions ggml/src/ggml-cuda/set.cu
@@ -0,0 +1,119 @@
#include "ggml-cuda/common.cuh"
#include "set.cuh"

// Copy all of src0 into dst: one thread per element; the flat thread index is
// decomposed into (i0,i1,i2,i3) and turned into a byte offset that is shared by
// src0 and dst (both are contiguous and have the same shape).
static __global__ void set_f32_cuda_copy(const float * __restrict__ src0,
float * __restrict__ dst,
const size_t ne0,
const size_t ne1,
const size_t ne2,
const size_t ne3,
const size_t nb0,
const size_t nb1,
const size_t nb2,
const size_t nb3) {
const size_t total = ne0 * ne1 * ne2 * ne3;
const size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid >= total) {
return;
}

size_t tmp = gid;

const size_t i0 = tmp % ne0;
tmp /= ne0;
const size_t i1 = tmp % ne1;
tmp /= ne1;
const size_t i2 = tmp % ne2;
tmp /= ne2;
const size_t i3 = tmp;

const size_t pos = (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);

*((float *) ((char *) dst + pos)) = *((const float *) ((const char *) src0 + pos));
}

// Write src1 into the view of dst described by the byte strides nb0..nb3 and the
// byte offset passed in as offset; this implements the actual GGML_OP_SET semantics.
static __global__ void set_f32_cuda(const float * __restrict__ src1,
float * __restrict__ dst,
const size_t ne10,
const size_t ne11,
const size_t ne12,
const size_t ne13,
const size_t nb10,
const size_t nb11,
const size_t nb12,
const size_t nb13,
const size_t nb0,
const size_t nb1,
const size_t nb2,
const size_t nb3,
const size_t offset) {
const size_t total = ne10 * ne11 * ne12 * ne13;
const size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid >= total) {
return;
}

size_t tmp = gid;

const size_t i0 = tmp % ne10;
tmp /= ne10;
const size_t i1 = tmp % ne11;
tmp /= ne11;
const size_t i2 = tmp % ne12;
tmp /= ne12;
const size_t i3 = tmp;

size_t dst_offset = offset + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3;
size_t src1_offset = i0 * nb10 + i1 * nb11 + i2 * nb12 + i3 * nb13;

*((float *) ((char *) dst + dst_offset)) = *((const float *) ((const char *) src1 + src1_offset));
}

void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
// nb0 is implicitly element_size because src0 and dst are contiguous
const int32_t nb1 = dst->op_params[0];
const int32_t nb2 = dst->op_params[1];
const int32_t nb3 = dst->op_params[2];
const int32_t offset = dst->op_params[3];
const bool inplace = dst->op_params[4];

const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

GGML_ASSERT(ggml_are_same_shape(src0, dst));

// TODO: support more dtypes.
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);

GGML_TENSOR_BINARY_OP_LOCALS01;
const size_t nb0 = ggml_element_size(dst);

const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;

cudaStream_t stream = ctx.stream();

if (!inplace) {
// copy whole src0 -> dst.
const size_t total = ne00 * ne01 * ne02 * ne03;

const size_t num_blocks = (total + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;

set_f32_cuda_copy<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>(
src0_d, dst_d, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03);
}

// set: scatter src1 into the dst view described by nb1/nb2/nb3 and offset

const size_t total = ne10 * ne11 * ne12 * ne13;
const size_t num_blocks = (total + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE;

set_f32_cuda<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>(
src1_d, dst_d, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, offset);
}
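Both kernels launch one thread per element: the host code flattens the index space (total = ne10*ne11*ne12*ne13 for the scatter, ne00*ne01*ne02*ne03 for the copy) and rounds up to whole blocks of CUDA_SET_BLOCK_SIZE threads. As a reading aid only, here is a plain host-side loop equivalent to what set_f32_cuda computes per thread; it is a sketch for checking the byte-offset arithmetic, not part of the change.

// Host-side reference for the scatter performed by set_f32_cuda: element
// (i0,i1,i2,i3) of src1 is written to byte offset
//     offset + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3
// of dst, while src1 itself is read through its own byte strides nb10..nb13.
static void set_f32_reference(const char * src1, char * dst,
                              size_t ne10, size_t ne11, size_t ne12, size_t ne13,
                              size_t nb10, size_t nb11, size_t nb12, size_t nb13,
                              size_t nb0, size_t nb1, size_t nb2, size_t nb3,
                              size_t offset) {
    for (size_t i3 = 0; i3 < ne13; ++i3) {
        for (size_t i2 = 0; i2 < ne12; ++i2) {
            for (size_t i1 = 0; i1 < ne11; ++i1) {
                for (size_t i0 = 0; i0 < ne10; ++i0) {
                    const size_t src_off = i0*nb10 + i1*nb11 + i2*nb12 + i3*nb13;
                    const size_t dst_off = offset + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
                    *(float *) (dst + dst_off) = *(const float *) (src1 + src_off);
                }
            }
        }
    }
}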
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/set.cuh
@@ -0,0 +1,7 @@
#pragma once

#include "common.cuh"

#define CUDA_SET_BLOCK_SIZE 256

void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
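A note for reviewers: the inplace flag read from op_params[4] in set.cu comes from the graph-construction side. ggml_set_inplace makes the output alias its first operand, which is why the host code above skips the full src0 -> dst copy and only performs the src1 scatter in that case. A minimal illustration with the public API follows; the helper function, shapes, and offsets are arbitrary and only for illustration.

#include "ggml.h"

// Illustrative only: both calls create a GGML_OP_SET node; they differ in
// op_params[4] (the inplace flag) and in whether the result aliases a.
static void set_variants(struct ggml_context * ctx,
                         struct ggml_tensor * a, struct ggml_tensor * b) {
    // Out-of-place: the backend first copies src0 into dst, then scatters src1.
    struct ggml_tensor * c0 = ggml_set(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 0);

    // In-place: dst shares a's buffer, so only the src1 scatter is performed.
    struct ggml_tensor * c1 = ggml_set_inplace(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 0);

    (void) c0;
    (void) c1;
}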