Skip to content

Commit 72cdf76

Browse files
committed
ggml : add scaling to get_rel_pos for different query/key heights
1 parent 4d52d20 commit 72cdf76

File tree

3 files changed

+14
-73
lines changed

3 files changed

+14
-73
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9217,13 +9217,16 @@ static void ggml_compute_forward_get_rel_pos_f32(
92179217
GGML_TENSOR_UNARY_OP_LOCALS
92189218

92199219
const int64_t kh = ne1;
9220+
const int64_t qh = ne2;
9221+
const float k_scale = MAX(qh / kh, 1.0f);
9222+
const float q_scale = MAX(kh / qh, 1.0f);
92209223

92219224
float * src0_data = (float *) src0->data;
92229225
float * dst_data = (float *) dst->data;
92239226

92249227
for (int64_t i2 = 0; i2 < ne2; ++i2) {
92259228
for (int64_t i1 = 0; i1 < ne1; ++i1) {
9226-
const int64_t pos = (kh - i1 - 1) + i2;
9229+
const int pos = int(i2*q_scale - i1*k_scale + (kh - 1)*k_scale);
92279230
for (int64_t i0 = 0; i0 < ne0; ++i0) {
92289231
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
92299232
}
@@ -9243,13 +9246,16 @@ static void ggml_compute_forward_get_rel_pos_f16(
92439246
GGML_TENSOR_UNARY_OP_LOCALS
92449247

92459248
const int64_t kh = ne1;
9249+
const int64_t qh = ne2;
9250+
const float k_scale = MAX(qh / kh, 1.0f);
9251+
const float q_scale = MAX(kh / qh, 1.0f);
92469252

92479253
ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
92489254
ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
92499255

92509256
for (int64_t i2 = 0; i2 < ne2; ++i2) {
92519257
for (int64_t i1 = 0; i1 < ne1; ++i1) {
9252-
const int64_t pos = (kh - i1 - 1) + i2;
9258+
const int pos = int(i2*q_scale - i1*k_scale + (kh - 1)*k_scale);
92539259
for (int64_t i0 = 0; i0 < ne0; ++i0) {
92549260
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
92559261
}

ggml/src/ggml-cuda/rel-pos.cu

Lines changed: 4 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2,82 +2,16 @@
22
#include "ggml.h"
33
#include "ggml-cuda/rel-pos.cuh"
44

5-
/*
6-
7-
static void ggml_compute_forward_get_rel_pos_f16(
8-
const ggml_compute_params * params,
9-
ggml_tensor * dst) {
10-
GGML_UNUSED(params);
11-
12-
const ggml_tensor * src0 = dst->src[0];
13-
14-
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
15-
16-
GGML_TENSOR_UNARY_OP_LOCALS
17-
18-
const int64_t kh = ne1;
19-
20-
ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
21-
ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
22-
23-
for (int64_t i2 = 0; i2 < ne2; ++i2) {
24-
for (int64_t i1 = 0; i1 < ne1; ++i1) {
25-
const int64_t pos = (kh - i1 - 1) + i2;
26-
for (int64_t i0 = 0; i0 < ne0; ++i0) {
27-
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
28-
}
29-
}
30-
}
31-
}
32-
33-
34-
void ggml_compute_forward_get_rel_pos(
35-
const ggml_compute_params * params,
36-
ggml_tensor * dst) {
37-
38-
const ggml_tensor * src0 = dst->src[0];
39-
40-
switch (src0->type) {
41-
case GGML_TYPE_F32:
42-
{
43-
ggml_compute_forward_get_rel_pos_f32(params, dst);
44-
} break;
45-
case GGML_TYPE_F16:
46-
case GGML_TYPE_BF16:
47-
{
48-
ggml_compute_forward_get_rel_pos_f16(params, dst);
49-
} break;
50-
default:
51-
{
52-
GGML_ABORT("fatal error");
53-
}
54-
}
55-
}
56-
57-
struct ggml_tensor * ggml_get_rel_pos(
58-
struct ggml_context * ctx,
59-
struct ggml_tensor * a,
60-
int qh,
61-
int kh) {
62-
GGML_ASSERT(qh + kh - 1 <= a->ne[1]);
63-
64-
const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
65-
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 3, ne);
66-
67-
result->op = GGML_OP_GET_REL_POS;
68-
result->src[0] = a;
69-
70-
return result;
71-
}
72-
73-
*/
745

756
template <typename T>
767
__global__ static void get_rel_pos_kernel(const void * src, void * dst, int C) {
778
int kh = gridDim.x;
9+
int qh = gridDim.x;
10+
float k_scale = MAX(qh / kh, 1.0f);
11+
float q_scale = MAX(kh / qh, 1.0f);
7812
int ki = blockIdx.x;
7913
int qi = blockIdx.y;
80-
int pos = (kh - 1) + qi - ki;
14+
int pos = int(qi*q_scale - ki*k_scale + (kh - 1)*k_scale);
8115

8216
int s0 = C;
8317
int s1 = C * kh;

ggml/src/ggml.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5479,7 +5479,8 @@ struct ggml_tensor * ggml_get_rel_pos(
54795479
struct ggml_tensor * a,
54805480
int qh,
54815481
int kh) {
5482-
GGML_ASSERT(qh + kh - 1 <= a->ne[1]);
5482+
GGML_ASSERT(qh >= 1 && kh >= 1);
5483+
GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]);
54835484

54845485
const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
54855486
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 3, ne);

0 commit comments

Comments
 (0)