Skip to content

Commit 91f35b5

Browse files
committed
cuda : implement CUDA backend support for rel-pos and window operations
1 parent 4e6310c commit 91f35b5

File tree

8 files changed

+642
-19
lines changed

8 files changed

+642
-19
lines changed

ggml/include/ggml.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2296,6 +2296,20 @@ extern "C" {
22962296
int h0,
22972297
int w);
22982298

2299+
// reverse of ggml_win_part with explicit output dimensions
2300+
// a: [C, w, w, B*NPY*NPX]
2301+
// result: [C, w0, h0, b0]
2302+
// w0, h0: output width and height (may differ from input due to padding removal)
2303+
// b0: output batch size
2304+
// w: window size (must match the one used in ggml_win_part)
2305+
GGML_API struct ggml_tensor * ggml_win_unpart_ext(
2306+
struct ggml_context * ctx,
2307+
struct ggml_tensor * a,
2308+
int w0,
2309+
int h0,
2310+
int b0,
2311+
int w);
2312+
22992313
GGML_API struct ggml_tensor * ggml_unary(
23002314
struct ggml_context * ctx,
23012315
struct ggml_tensor * a,
@@ -2306,6 +2320,12 @@ extern "C" {
23062320
struct ggml_tensor * a,
23072321
enum ggml_unary_op op);
23082322

2323+
// relative position encoding
2324+
// a: [C, rel_pos_size]
2325+
// res: [C, kh, qh]
2326+
// where rel_pos_size >= qh + kh - 1
2327+
// extracts relative position embeddings for attention
2328+
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
23092329
GGML_API struct ggml_tensor * ggml_get_rel_pos(
23102330
struct ggml_context * ctx,
23112331
struct ggml_tensor * a,

ggml/src/ggml-cpu/ops.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8969,13 +8969,13 @@ static void ggml_compute_forward_win_part_f32(
89698969
const int64_t i01 = px*w + i1;
89708970
const int64_t i00 = i0;
89718971

8972-
void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
8973-
void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
8972+
const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
8973+
char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
89748974

89758975
if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
89768976
*((float *) dp) = 0;
89778977
} else {
8978-
*((float *) dp) = *((float *) sp);
8978+
*((float *) dp) = *((const float *) sp);
89798979
}
89808980
}
89818981
}
@@ -9013,13 +9013,13 @@ static void ggml_compute_forward_win_part_f16(
90139013
const int64_t i01 = px*w + i1;
90149014
const int64_t i00 = i0;
90159015

9016-
void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
9017-
void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
9016+
const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
9017+
char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
90189018

90199019
if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
90209020
*((ggml_fp16_t *) dp) = 0;
90219021
} else {
9022-
*((ggml_fp16_t *) dp) = *((ggml_fp16_t *) sp);
9022+
*((ggml_fp16_t *) dp) = *((const ggml_fp16_t *) sp);
90239023
}
90249024
}
90259025
}
@@ -9087,10 +9087,10 @@ static void ggml_compute_forward_win_unpart_f32(
90879087
const int64_t i01 = i1%w;
90889088
const int64_t i00 = i0;
90899089

9090-
void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
9091-
void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
9090+
const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
9091+
char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
90929092

9093-
*((float *) dp) = *((float *) sp);
9093+
*((float *) dp) = *((const float *) sp);
90949094
}
90959095
}
90969096
}
@@ -9131,10 +9131,10 @@ static void ggml_compute_forward_win_unpart_f16(
91319131
const int64_t i01 = i1%w;
91329132
const int64_t i00 = i0;
91339133

9134-
void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
9135-
void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
9134+
const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
9135+
char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
91369136

9137-
*((ggml_fp16_t *) dp) = *((ggml_fp16_t *) sp);
9137+
*((ggml_fp16_t *) dp) = *((const ggml_fp16_t *) sp);
91389138
}
91399139
}
91409140
}
@@ -9314,14 +9314,14 @@ static void ggml_compute_forward_get_rel_pos_f32(
93149314

93159315
GGML_TENSOR_UNARY_OP_LOCALS
93169316

9317-
const int64_t w = ne1;
9317+
const int64_t kh = ne1;
93189318

93199319
float * src0_data = (float *) src0->data;
93209320
float * dst_data = (float *) dst->data;
93219321

93229322
for (int64_t i2 = 0; i2 < ne2; ++i2) {
93239323
for (int64_t i1 = 0; i1 < ne1; ++i1) {
9324-
const int64_t pos = (w - i1 - 1) + i2;
9324+
const int64_t pos = (kh - i1 - 1) + i2;
93259325
for (int64_t i0 = 0; i0 < ne0; ++i0) {
93269326
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
93279327
}
@@ -9340,14 +9340,14 @@ static void ggml_compute_forward_get_rel_pos_f16(
93409340

93419341
GGML_TENSOR_UNARY_OP_LOCALS
93429342

9343-
const int64_t w = ne1;
9343+
const int64_t kh = ne1;
93449344

93459345
ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
93469346
ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
93479347

93489348
for (int64_t i2 = 0; i2 < ne2; ++i2) {
93499349
for (int64_t i1 = 0; i1 < ne1; ++i1) {
9350-
const int64_t pos = (w - i1 - 1) + i2;
9350+
const int64_t pos = (kh - i1 - 1) + i2;
93519351
for (int64_t i0 = 0; i0 < ne0; ++i0) {
93529352
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
93539353
}

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "ggml-cuda/pad.cuh"
3535
#include "ggml-cuda/pool2d.cuh"
3636
#include "ggml-cuda/quantize.cuh"
37+
#include "ggml-cuda/rel-pos.cuh"
3738
#include "ggml-cuda/rope.cuh"
3839
#include "ggml-cuda/roll.cuh"
3940
#include "ggml-cuda/scale.cuh"
@@ -48,6 +49,7 @@
4849
#include "ggml-cuda/topk-moe.cuh"
4950
#include "ggml-cuda/unary.cuh"
5051
#include "ggml-cuda/upscale.cuh"
52+
#include "ggml-cuda/win.cuh"
5153
#include "ggml-cuda/wkv.cuh"
5254
#include "ggml-cuda/gla.cuh"
5355
#include "ggml-cuda/set.cuh"
@@ -2711,6 +2713,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
27112713
case GGML_OP_OPT_STEP_SGD:
27122714
ggml_cuda_opt_step_sgd(ctx, dst);
27132715
break;
2716+
case GGML_OP_WIN_PART:
2717+
ggml_cuda_op_win_part(ctx, dst);
2718+
break;
2719+
case GGML_OP_WIN_UNPART:
2720+
ggml_cuda_op_win_unpart(ctx, dst);
2721+
break;
2722+
case GGML_OP_GET_REL_POS:
2723+
ggml_cuda_op_get_rel_pos(ctx, dst);
2724+
break;
27142725
default:
27152726
return false;
27162727
}
@@ -4091,6 +4102,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
40914102
case GGML_OP_OPT_STEP_ADAMW:
40924103
case GGML_OP_OPT_STEP_SGD:
40934104
return true;
4105+
case GGML_OP_WIN_PART:
4106+
case GGML_OP_WIN_UNPART:
4107+
case GGML_OP_GET_REL_POS:
4108+
switch (op->src[0]->type) {
4109+
case GGML_TYPE_F16:
4110+
case GGML_TYPE_F32:
4111+
case GGML_TYPE_BF16:
4112+
return true;
4113+
default:
4114+
return false;
4115+
}
40944116
default:
40954117
return false;
40964118
}

ggml/src/ggml-cuda/rel-pos.cu

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
#include "common.cuh"
2+
#include "ggml.h"
3+
#include "ggml-cuda/rel-pos.cuh"
4+
5+
/*
6+
7+
static void ggml_compute_forward_get_rel_pos_f16(
8+
const ggml_compute_params * params,
9+
ggml_tensor * dst) {
10+
GGML_UNUSED(params);
11+
12+
const ggml_tensor * src0 = dst->src[0];
13+
14+
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
15+
16+
GGML_TENSOR_UNARY_OP_LOCALS
17+
18+
const int64_t kh = ne1;
19+
20+
ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
21+
ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
22+
23+
for (int64_t i2 = 0; i2 < ne2; ++i2) {
24+
for (int64_t i1 = 0; i1 < ne1; ++i1) {
25+
const int64_t pos = (kh - i1 - 1) + i2;
26+
for (int64_t i0 = 0; i0 < ne0; ++i0) {
27+
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
28+
}
29+
}
30+
}
31+
}
32+
33+
34+
void ggml_compute_forward_get_rel_pos(
35+
const ggml_compute_params * params,
36+
ggml_tensor * dst) {
37+
38+
const ggml_tensor * src0 = dst->src[0];
39+
40+
switch (src0->type) {
41+
case GGML_TYPE_F32:
42+
{
43+
ggml_compute_forward_get_rel_pos_f32(params, dst);
44+
} break;
45+
case GGML_TYPE_F16:
46+
case GGML_TYPE_BF16:
47+
{
48+
ggml_compute_forward_get_rel_pos_f16(params, dst);
49+
} break;
50+
default:
51+
{
52+
GGML_ABORT("fatal error");
53+
}
54+
}
55+
}
56+
57+
struct ggml_tensor * ggml_get_rel_pos(
58+
struct ggml_context * ctx,
59+
struct ggml_tensor * a,
60+
int qh,
61+
int kh) {
62+
GGML_ASSERT(qh + kh - 1 <= a->ne[1]);
63+
64+
const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
65+
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 3, ne);
66+
67+
result->op = GGML_OP_GET_REL_POS;
68+
result->src[0] = a;
69+
70+
return result;
71+
}
72+
73+
*/
74+
75+
// Extracts relative-position embeddings (GGML_OP_GET_REL_POS).
//
// src: [C, rel_pos_size] contiguous embedding table, rel_pos_size >= qh + kh - 1
// dst: [C, kh, qh]      contiguous output
//
// Expected launch layout: grid = (kh, qh, 1) — one block per output (ki, qi)
// row — with the block's threads striding over the C channels.
// Mirrors the CPU path: dst[qi][ki][ci] = src[(kh - 1) + qi - ki][ci].
template <typename T>
__global__ static void get_rel_pos_kernel(const void * src, void * dst, int C) {
    const int kh  = gridDim.x;
    const int ki  = blockIdx.x;
    const int qi  = blockIdx.y;
    const int pos = (kh - 1) + qi - ki; // relative position index into the table

    // compute row offsets in 64 bits: qi*kh*C can overflow a 32-bit int for
    // large tensors even though each individual extent fits in an int
    const int64_t src_row = (int64_t) pos * C;
    const int64_t dst_row = ((int64_t) qi * kh + ki) * C;

    const T * src_t = (const T *) src;
    T       * dst_t = (T *) dst;

    for (int ci = threadIdx.x; ci < C; ci += blockDim.x) {
        dst_t[dst_row + ci] = src_t[src_row + ci];
    }
}
89+
90+
// Round v up to the nearest power of two.
// A value that is already a power of two is returned unchanged;
// round_to_pow2(0) == 0 (the decrement wraps and the increment wraps back).
static unsigned int round_to_pow2(unsigned int v) {
    v -= 1;

    // smear the highest set bit into every lower bit position
    for (unsigned int shift = 1; shift < 32; shift <<= 1) {
        v |= v >> shift;
    }

    return v + 1;
}
101+
102+
// Launch GGML_OP_GET_REL_POS on the GPU.
//
// dst->src[0]: [C, rel_pos_size] embedding table, rel_pos_size >= qh + kh - 1
// dst:         [C, kh, qh]
//
// One block per output (ki, qi) pair; the block size is the smallest power of
// two covering C, clamped to [32, CUDA_GET_REL_POS_BLOCK_SIZE].
void ggml_cuda_op_get_rel_pos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_TENSOR_UNARY_OP_LOCALS

    GGML_ASSERT(src0->type == dst->type);
    // the kernel indexes both tensors as dense row-major arrays, so neither
    // may have permuted/padded strides
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));

    const int C  = ne0; // channels per position
    const int kh = ne1; // key height
    const int qh = ne2; // query height

    const int num_threads = MIN(CUDA_GET_REL_POS_BLOCK_SIZE, MAX(32, round_to_pow2(C)));
    const dim3 grid((unsigned int) kh, (unsigned int) qh, 1);

    const void * src0_d = (const void *) src0->data;
    void       * dst_d  = (void *) dst->data;

    cudaStream_t stream = ctx.stream();

    switch (src0->type) {
        case GGML_TYPE_F32:
            get_rel_pos_kernel<float><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
            break;
        case GGML_TYPE_F16:
            get_rel_pos_kernel<half><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
            break;
        case GGML_TYPE_BF16:
            get_rel_pos_kernel<nv_bfloat16><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
            break;
        default:
            GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
            break;
    }
}

ggml/src/ggml-cuda/rel-pos.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#pragma once

#include "common.cuh"

// upper bound on threads per block for the get_rel_pos kernel launch
#define CUDA_GET_REL_POS_BLOCK_SIZE 256

// GPU implementation of GGML_OP_GET_REL_POS (relative position embeddings);
// reads dst->src[0] and writes dst on the context's stream
void ggml_cuda_op_get_rel_pos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

0 commit comments

Comments
 (0)