Skip to content

Commit 1617cdd

Browse files
committed
add repeat
1 parent 0f90aba commit 1617cdd

File tree

3 files changed

+140
-0
lines changed

3 files changed

+140
-0
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ set(GGML_OPENCL_KERNELS
9494
upscale
9595
unary
9696
pad
97+
repeat
9798
)
9899

99100
foreach (K ${GGML_OPENCL_KERNELS})

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ struct ggml_backend_opencl_context {
301301
cl_program program_mul;
302302
cl_program program_norm;
303303
cl_program program_group_norm; // Added for group_norm
304+
cl_program program_repeat;
304305
cl_program program_pad;
305306
cl_program program_unary;
306307
cl_program program_upscale;
@@ -328,6 +329,7 @@ struct ggml_backend_opencl_context {
328329
cl_kernel kernel_clamp;
329330
cl_kernel kernel_norm;
330331
cl_kernel kernel_group_norm; // Added for group_norm
332+
cl_kernel kernel_repeat;
331333
cl_kernel kernel_pad;
332334
cl_kernel kernel_upscale;
333335
cl_kernel kernel_upscale_bilinear;
@@ -871,6 +873,27 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
871873
GGML_LOG_CONT(".");
872874
}
873875

876+
// repeat
877+
{
878+
#ifdef GGML_OPENCL_EMBED_KERNELS
879+
const std::string kernel_src {
880+
#include "repeat.cl.h"
881+
};
882+
#else
883+
const std::string kernel_src = read_file("repeat.cl");
884+
#endif
885+
if (!kernel_src.empty()) {
886+
backend_ctx->program_repeat =
887+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
888+
CL_CHECK((backend_ctx->kernel_repeat = clCreateKernel(backend_ctx->program_repeat, "kernel_repeat", &err), err));
889+
GGML_LOG_CONT(".");
890+
} else {
891+
GGML_LOG_WARN("ggml_opencl: repeat kernel source not found or empty. Repeat operations will not be available.\n");
892+
backend_ctx->program_repeat = nullptr;
893+
backend_ctx->kernel_repeat = nullptr;
894+
}
895+
}
896+
874897
// pad
875898
{
876899
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2043,6 +2066,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
20432066
case GGML_OP_NORM:
20442067
case GGML_OP_RMS_NORM:
20452068
return true;
2069+
case GGML_OP_REPEAT:
2070+
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
20462071
case GGML_OP_PAD:
20472072
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
20482073
op->src[0]->ne[3] == 1 && op->ne[3] == 1; // Only 3D tensors for now
@@ -3973,6 +3998,75 @@ static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0,
39733998
#endif
39743999
}
39754000

4001+
static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {
4002+
GGML_ASSERT(src0);
4003+
GGML_ASSERT(src0->extra);
4004+
GGML_ASSERT(dst);
4005+
GGML_ASSERT(dst->extra);
4006+
GGML_ASSERT(dst->type == src0->type);
4007+
4008+
UNUSED(src1_shape_def);
4009+
4010+
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
4011+
cl_command_queue queue = backend_ctx->queue;
4012+
4013+
if (backend_ctx->kernel_repeat == nullptr) {
4014+
GGML_LOG_WARN("%s: repeat kernel not available, skipping OpenCL execution.\n", __func__);
4015+
return;
4016+
}
4017+
4018+
ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;
4019+
ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra;
4020+
4021+
cl_ulong off_src0 = extra_src0->offset + src0->view_offs;
4022+
cl_ulong off_dst = extra_dst->offset + dst->view_offs;
4023+
4024+
const int src0_ne0 = src0->ne[0]; const int src0_ne1 = src0->ne[1]; const int src0_ne2 = src0->ne[2]; const int src0_ne3 = src0->ne[3];
4025+
const cl_ulong src0_nb0 = src0->nb[0]; const cl_ulong src0_nb1 = src0->nb[1]; const cl_ulong src0_nb2 = src0->nb[2]; const cl_ulong src0_nb3 = src0->nb[3];
4026+
4027+
const int dst_ne0 = dst->ne[0]; const int dst_ne1 = dst->ne[1]; const int dst_ne2 = dst->ne[2]; const int dst_ne3 = dst->ne[3];
4028+
const cl_ulong dst_nb0 = dst->nb[0]; const cl_ulong dst_nb1 = dst->nb[1]; const cl_ulong dst_nb2 = dst->nb[2]; const cl_ulong dst_nb3 = dst->nb[3];
4029+
4030+
cl_kernel kernel = backend_ctx->kernel_repeat;
4031+
4032+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
4033+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra_dst->data_device));
4034+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_src0));
4035+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
4036+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &src0_ne0));
4037+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &src0_ne1));
4038+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &src0_ne2));
4039+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &src0_ne3));
4040+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &src0_nb0));
4041+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &src0_nb1));
4042+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &src0_nb2));
4043+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &src0_nb3));
4044+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &dst_ne0));
4045+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &dst_ne1));
4046+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &dst_ne2));
4047+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dst_ne3));
4048+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &dst_nb0));
4049+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &dst_nb1));
4050+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &dst_nb2));
4051+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &dst_nb3));
4052+
4053+
size_t gws0 = dst_ne1 > 0 ? (size_t)dst_ne1 : 1;
4054+
size_t gws1 = dst_ne2 > 0 ? (size_t)dst_ne2 : 1;
4055+
size_t gws2 = dst_ne3 > 0 ? (size_t)dst_ne3 : 1;
4056+
4057+
size_t global_work_size[] = { gws0, gws1, gws2 };
4058+
4059+
#ifdef GGML_OPENCL_PROFILING
4060+
cl_event evt;
4061+
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, &evt));
4062+
4063+
g_profiling_info.emplace_back();
4064+
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, (size_t[3]){0,0,0}, dst);
4065+
#else
4066+
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, NULL));
4067+
#endif
4068+
}
4069+
39764070
static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {
39774071
GGML_ASSERT(src0);
39784072
GGML_ASSERT(src0->extra);
@@ -5820,6 +5914,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
58205914
}
58215915
func = ggml_cl_group_norm;
58225916
break;
5917+
case GGML_OP_REPEAT:
5918+
if (!any_on_device) {
5919+
return false;
5920+
}
5921+
func = ggml_cl_repeat;
5922+
break;
58235923
case GGML_OP_PAD:
58245924
if (!any_on_device) {
58255925
return false;
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
kernel void kernel_repeat(
2+
global const char * src0_data_in,
3+
global char * dst_data_in,
4+
ulong src0_offset,
5+
ulong dst_offset,
6+
int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3,
7+
ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3,
8+
int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3,
9+
ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3
10+
) {
11+
global const char * src0_data = src0_data_in + src0_offset;
12+
global char * dst_data = dst_data_in + dst_offset;
13+
14+
const int d3 = get_global_id(2);
15+
const int d2 = get_global_id(1);
16+
const int d1 = get_global_id(0);
17+
18+
if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) {
19+
return;
20+
}
21+
22+
const int s3 = d3 % src0_ne3;
23+
const int s2 = d2 % src0_ne2;
24+
const int s1 = d1 % src0_ne1;
25+
26+
const global char * p_src0_slice = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1;
27+
global char * p_dst_slice = dst_data + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1;
28+
29+
for (int d0 = 0; d0 < dst_ne0; ++d0) {
30+
// Determine source index for dimension 0 based on tiling/broadcasting.
31+
const int s0 = d0 % src0_ne0;
32+
33+
const global char * restrict current_src_el_ptr = p_src0_slice + (ulong)s0*src0_nb0;
34+
global char * restrict current_dst_el_ptr = p_dst_slice + (ulong)d0*dst_nb0;
35+
for (int k = 0; k < src0_nb0; ++k) {
36+
current_dst_el_ptr[k] = current_src_el_ptr[k];
37+
}
38+
}
39+
}

0 commit comments

Comments
 (0)