Commit 4d52d20

cuda : implement CUDA backend support for rel-pos and window operations

1 parent 73a186b commit 4d52d20

File tree

8 files changed: +642 -19 lines changed

ggml/include/ggml.h

Lines changed: 20 additions & 0 deletions
@@ -2346,6 +2346,20 @@ extern "C" {
             int                   h0,
             int                   w);

+    // reverse of ggml_win_part with explicit output dimensions
+    // a: [C, w, w, B*NPY*NPX]
+    // result: [C, w0, h0, b0]
+    // w0, h0: output width and height (may differ from input due to padding removal)
+    // b0: output batch size
+    // w: window size (must match the one used in ggml_win_part)
+    GGML_API struct ggml_tensor * ggml_win_unpart_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   w0,
+            int                   h0,
+            int                   b0,
+            int                   w);

     GGML_API struct ggml_tensor * ggml_unary(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -2356,6 +2370,12 @@ extern "C" {
             struct ggml_tensor  * a,
             enum ggml_unary_op    op);

+    // relative position encoding
+    // a: [C, rel_pos_size]
+    // res: [C, kh, qh]
+    // where rel_pos_size >= qh + kh - 1
+    // extracts relative position embeddings for attention
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
     GGML_API struct ggml_tensor * ggml_get_rel_pos(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
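For orientation, a minimal usage sketch of the two new entry points (not part of the commit; the shapes follow the header comments above, and the context setup is illustrative):

    #include "ggml.h"

    // build-time sketch only; assumes enough context memory and no backend setup
    struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // partition a [C, W, H, B] feature map into w x w windows and undo it;
    // ggml_win_unpart_ext restores the pre-padding W/H/B explicitly
    struct ggml_tensor * x     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 256, 14, 14, 1);
    struct ggml_tensor * parts = ggml_win_part(ctx, x, 8);                      // pads 14 -> 16, 2x2 windows per image
    struct ggml_tensor * y     = ggml_win_unpart_ext(ctx, parts, 14, 14, 1, 8); // same shape as x

    // rel-pos table needs rel_pos_size >= qh + kh - 1 = 27 rows for qh = kh = 14
    struct ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 256, 27);
    struct ggml_tensor * rp  = ggml_get_rel_pos(ctx, rel, 14, 14);              // -> [256, 14, 14]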

ggml/src/ggml-cpu/ops.cpp

Lines changed: 16 additions & 16 deletions
@@ -8863,13 +8863,13 @@ static void ggml_compute_forward_win_part_f32(
                     const int64_t i01 = px*w + i1;
                     const int64_t i00 = i0;

-                    void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
-                    void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
+                    const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
+                    char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;

                     if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
                         *((float *) dp) = 0;
                     } else {
-                        *((float *) dp) = *((float *) sp);
+                        *((float *) dp) = *((const float *) sp);
                     }
                 }
             }
@@ -8907,13 +8907,13 @@ static void ggml_compute_forward_win_part_f16(
                     const int64_t i01 = px*w + i1;
                     const int64_t i00 = i0;

-                    void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
-                    void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
+                    const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
+                    char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;

                     if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
                         *((ggml_fp16_t *) dp) = 0;
                     } else {
-                        *((ggml_fp16_t *) dp) = *((ggml_fp16_t *) sp);
+                        *((ggml_fp16_t *) dp) = *((const ggml_fp16_t *) sp);
                     }
                 }
             }
@@ -8981,10 +8981,10 @@ static void ggml_compute_forward_win_unpart_f32(
                     const int64_t i01 = i1%w;
                     const int64_t i00 = i0;

-                    void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
-                    void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
+                    const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
+                    char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;

-                    *((float *) dp) = *((float *) sp);
+                    *((float *) dp) = *((const float *) sp);
                 }
             }
         }
@@ -9025,10 +9025,10 @@ static void ggml_compute_forward_win_unpart_f16(
                     const int64_t i01 = i1%w;
                     const int64_t i00 = i0;

-                    void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
-                    void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
+                    const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
+                    char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;

-                    *((ggml_fp16_t *) dp) = *((ggml_fp16_t *) sp);
+                    *((ggml_fp16_t *) dp) = *((const ggml_fp16_t *) sp);
                 }
             }
         }
@@ -9216,14 +9216,14 @@ static void ggml_compute_forward_get_rel_pos_f32(

     GGML_TENSOR_UNARY_OP_LOCALS

-    const int64_t w = ne1;
+    const int64_t kh = ne1;

     float * src0_data = (float *) src0->data;
     float * dst_data = (float *) dst->data;

     for (int64_t i2 = 0; i2 < ne2; ++i2) {
         for (int64_t i1 = 0; i1 < ne1; ++i1) {
-            const int64_t pos = (w - i1 - 1) + i2;
+            const int64_t pos = (kh - i1 - 1) + i2;
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
                 dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
             }
@@ -9242,14 +9242,14 @@ static void ggml_compute_forward_get_rel_pos_f16(

     GGML_TENSOR_UNARY_OP_LOCALS

-    const int64_t w = ne1;
+    const int64_t kh = ne1;

     ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
     ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;

     for (int64_t i2 = 0; i2 < ne2; ++i2) {
         for (int64_t i1 = 0; i1 < ne1; ++i1) {
-            const int64_t pos = (w - i1 - 1) + i2;
+            const int64_t pos = (kh - i1 - 1) + i2;
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
                 dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
             }
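The cast changes in the first four hunks are a const-correctness and portability fix: arithmetic on void * is a GNU extension, while ISO C and C++ require a complete object type, so ggml steps through raw bytes with char * using the nb* strides (which are in bytes) and only then loads a typed value. A self-contained illustration of the pattern (not from the commit):

    // illustrative only: compute a byte offset via char *, then do a typed load
    #include <stdio.h>

    int main(void) {
        float row[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
        const size_t nb0 = sizeof(float);          // stride per element, in bytes
        const char * base = (const char *) row;

        const float v = *(const float *) (base + 2*nb0);
        printf("%g\n", v);                         // prints 3
        return 0;
    }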

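The rename from w to kh also makes the rel-pos indexing readable: in the loops above, i1 indexes the key position (ki) and i2 the query position (qi), so pos = (kh - i1 - 1) + i2 selects row (kh - 1) + qi - ki of the relative-position table, matching SAM's decomposed relative attention. A standalone check of the index math (not part of the commit):

    #include <stdio.h>

    int main(void) {
        const int qh = 3, kh = 3;   // requires rel_pos_size >= qh + kh - 1 = 5 table rows
        for (int qi = 0; qi < qh; ++qi) {
            for (int ki = 0; ki < kh; ++ki) {
                printf("%d ", (kh - 1) + qi - ki);  // table row used for (qi, ki)
            }
            printf("\n");
        }
        // prints:
        // 2 1 0
        // 3 2 1
        // 4 3 2
        return 0;
    }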
ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 22 additions & 0 deletions
@@ -34,6 +34,7 @@
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
 #include "ggml-cuda/quantize.cuh"
+#include "ggml-cuda/rel-pos.cuh"
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/roll.cuh"
 #include "ggml-cuda/scale.cuh"
@@ -48,6 +49,7 @@
 #include "ggml-cuda/topk-moe.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/win.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
 #include "ggml-cuda/set.cuh"
@@ -2717,6 +2719,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_SGD:
             ggml_cuda_opt_step_sgd(ctx, dst);
             break;
+        case GGML_OP_WIN_PART:
+            ggml_cuda_op_win_part(ctx, dst);
+            break;
+        case GGML_OP_WIN_UNPART:
+            ggml_cuda_op_win_unpart(ctx, dst);
+            break;
+        case GGML_OP_GET_REL_POS:
+            ggml_cuda_op_get_rel_pos(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -4152,6 +4163,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
             return true;
+        case GGML_OP_WIN_PART:
+        case GGML_OP_WIN_UNPART:
+        case GGML_OP_GET_REL_POS:
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F16:
+                case GGML_TYPE_F32:
+                case GGML_TYPE_BF16:
+                    return true;
+                default:
+                    return false;
+            }
         default:
             return false;
     }
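The second hunk gates the new ops on the input type; op/type combinations the device does not claim are scheduled onto another backend (typically the CPU). A hedged sketch of probing this through the public backend API (ggml_backend_supports_op is existing ggml API; the helper name is illustrative):

    #include "ggml.h"
    #include "ggml-backend.h"

    static bool cuda_claims(ggml_backend_t cuda_backend, struct ggml_tensor * node) {
        // true on this backend for GGML_OP_WIN_PART / GGML_OP_WIN_UNPART /
        // GGML_OP_GET_REL_POS only when src0 is F32, F16, or BF16
        return ggml_backend_supports_op(cuda_backend, node);
    }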

ggml/src/ggml-cuda/rel-pos.cu

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+#include "common.cuh"
+#include "ggml.h"
+#include "ggml-cuda/rel-pos.cuh"
+
+/*
+
+static void ggml_compute_forward_get_rel_pos_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    GGML_UNUSED(params);
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int64_t kh = ne1;
+
+    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
+    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            const int64_t pos = (kh - i1 - 1) + i2;
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
+            }
+        }
+    }
+}
+
+
+void ggml_compute_forward_get_rel_pos(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_get_rel_pos_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_get_rel_pos_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+struct ggml_tensor * ggml_get_rel_pos(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   qh,
+        int                   kh) {
+    GGML_ASSERT(qh + kh - 1 <= a->ne[1]);
+
+    const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 3, ne);
+
+    result->op     = GGML_OP_GET_REL_POS;
+    result->src[0] = a;
+
+    return result;
+}
+
+*/
+
+template <typename T>
+__global__ static void get_rel_pos_kernel(const void * src, void * dst, int C) {
+    int kh  = gridDim.x;
+    int ki  = blockIdx.x;
+    int qi  = blockIdx.y;
+    int pos = (kh - 1) + qi - ki;
+
+    int s0 = C;
+    int s1 = C * kh;
+
+    for (int ci = threadIdx.x; ci < C; ci += blockDim.x) {
+        ((T *) dst)[qi*s1 + ki*s0 + ci] = ((const T *) src)[pos*C + ci];
+    }
+}
+
+static unsigned int round_to_pow2(unsigned int v) {
+    v--;
+    v |= v >> 1;
+    v |= v >> 2;
+    v |= v >> 4;
+    v |= v >> 8;
+    v |= v >> 16;
+    v++;
+
+    return v;
+}
+
+void ggml_cuda_op_get_rel_pos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == dst->type);
+
+    int C  = ne0;
+    int kh = ne1;
+    int qh = ne2;
+
+    int num_threads = MIN(CUDA_GET_REL_POS_BLOCK_SIZE, MAX(32, round_to_pow2(C)));
+    dim3 grid { (unsigned int)kh, (unsigned int)qh, 1 };
+
+    const void * src0_d = (const void *) src0->data;
+    void       * dst_d  = (void       *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            get_rel_pos_kernel<float><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
+            break;
+        case GGML_TYPE_F16:
+            get_rel_pos_kernel<half><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
+            break;
+        case GGML_TYPE_BF16:
+            get_rel_pos_kernel<nv_bfloat16><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
+            break;
+        default:
+            GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
+            break;
+    }
+}
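Launch geometry, for reference: one block per (ki, qi) cell of the [C, kh, qh] output, with threads striding over the C channels. round_to_pow2 rounds C up to the next power of two, and the block size is clamped to [32, CUDA_GET_REL_POS_BLOCK_SIZE]: C = 5 gives 32 threads, C = 100 gives 128, C = 1000 caps at 256. Below is a minimal standalone harness (not part of the commit) that runs the kernel on a tiny F32 table and recovers pos = (kh - 1) + qi - ki per cell:

    #include <cstdio>
    #include <cuda_runtime.h>

    template <typename T>
    __global__ void get_rel_pos_kernel(const void * src, void * dst, int C) {
        int kh  = gridDim.x;
        int ki  = blockIdx.x;
        int qi  = blockIdx.y;
        int pos = (kh - 1) + qi - ki;
        for (int ci = threadIdx.x; ci < C; ci += blockDim.x) {
            ((T *) dst)[qi*C*kh + ki*C + ci] = ((const T *) src)[pos*C + ci];
        }
    }

    int main() {
        const int C = 4, qh = 3, kh = 3;
        const int rows = qh + kh - 1;                // rel_pos_size = 5
        float h_src[rows*C], h_dst[qh*kh*C];
        for (int i = 0; i < rows*C; ++i) h_src[i] = (float) i;  // row r starts at value r*C

        float *d_src, *d_dst;
        cudaMalloc(&d_src, sizeof(h_src));
        cudaMalloc(&d_dst, sizeof(h_dst));
        cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);

        dim3 grid(kh, qh, 1);                        // one block per (ki, qi) cell
        get_rel_pos_kernel<float><<<grid, 32>>>(d_src, d_dst, C);
        cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost);

        for (int qi = 0; qi < qh; ++qi) {
            for (int ki = 0; ki < kh; ++ki) {
                // first channel of the cell, divided by C, is the source row index
                printf("%g ", h_dst[qi*C*kh + ki*C] / C);
            }
            printf("\n");                            // expect: 2 1 0 / 3 2 1 / 4 3 2
        }

        cudaFree(d_src);
        cudaFree(d_dst);
        return 0;
    }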

ggml/src/ggml-cuda/rel-pos.cuh

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+#pragma once
+#include "common.cuh"
+
+#define CUDA_GET_REL_POS_BLOCK_SIZE 256
+
+void ggml_cuda_op_get_rel_pos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
