
Commit 508def2

ggml : fix get_rel_pos scaling bugs and update tests

1 parent 72cdf76 · commit 508def2

3 files changed, +169 -7 lines changed

ggml/src/ggml-cpu/ops.cpp (4 additions, 4 deletions)

@@ -9218,8 +9218,8 @@ static void ggml_compute_forward_get_rel_pos_f32(
 
     const int64_t kh = ne1;
     const int64_t qh = ne2;
-    const float k_scale = MAX(qh / kh, 1.0f);
-    const float q_scale = MAX(kh / qh, 1.0f);
+    const float k_scale = MAX((float)qh / kh, 1.0f);
+    const float q_scale = MAX((float)kh / qh, 1.0f);
 
     float * src0_data = (float *) src0->data;
     float * dst_data  = (float *) dst->data;
@@ -9247,8 +9247,8 @@ static void ggml_compute_forward_get_rel_pos_f16(
 
     const int64_t kh = ne1;
    const int64_t qh = ne2;
-    const float k_scale = MAX(qh / kh, 1.0f);
-    const float q_scale = MAX(kh / qh, 1.0f);
+    const float k_scale = MAX((float)qh / kh, 1.0f);
+    const float q_scale = MAX((float)kh / qh, 1.0f);
 
     ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;

ggml/src/ggml-cuda/rel-pos.cu (3 additions, 3 deletions)

@@ -6,9 +6,9 @@
 template <typename T>
 __global__ static void get_rel_pos_kernel(const void * src, void * dst, int C) {
     int kh = gridDim.x;
-    int qh = gridDim.x;
-    float k_scale = MAX(qh / kh, 1.0f);
-    float q_scale = MAX(kh / qh, 1.0f);
+    int qh = gridDim.y;
+    float k_scale = MAX((float)qh / kh, 1.0f);
+    float q_scale = MAX((float)kh / qh, 1.0f);
     int ki = blockIdx.x;
     int qi = blockIdx.y;
     int pos = int(qi*q_scale - ki*k_scale + (kh - 1)*k_scale);
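The CUDA kernel had the same truncation plus a copy-paste slip: qh was read from gridDim.x, which made qh == kh unconditionally and silently disabled the rectangular-attention path; the fix reads it from gridDim.y. A host-side sketch of the index math the fixed kernel performs, assuming a (kh, qh) launch grid as the fix implies:

#include <algorithm>
#include <cstdio>

// Mirrors get_rel_pos_kernel's pos computation on the host (sketch only).
int main() {
    const int qh = 14; // gridDim.y in the kernel
    const int kh = 7;  // gridDim.x in the kernel

    const float k_scale = std::max((float)qh / kh, 1.0f); // 2.0
    const float q_scale = std::max((float)kh / qh, 1.0f); // 1.0

    // Each (qi, ki) block selects one row of the [C, 2*max(qh,kh)-1] table.
    for (int qi = 0; qi < qh; qi++) {
        for (int ki = 0; ki < kh; ki++) {
            const int pos = int(qi*q_scale - ki*k_scale + (kh - 1)*k_scale);
            // pos stays within [0, 2*max(qh,kh) - 2], i.e. [0, 26] here
            printf("%2d ", pos);
        }
        printf("\n");
    }
    return 0;
}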

tests/test-backend-ops.cpp (162 additions, 0 deletions)

@@ -5518,6 +5518,121 @@ struct test_pad_reflect_1d : public test_case {
     }
 };
 
+// GGML_OP_WIN_PART
+struct test_win_part : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a; // [C, W, H, B]
+    const int w;  // window size
+    const bool v; // view (non-contiguous input)
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne_a, w, v);
+    }
+
+    test_win_part(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {64, 14, 14, 2},
+            int w = 7,
+            bool v = false)
+        : type(type), ne_a(ne_a), w(w), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a;
+        if (v) {
+            auto ne = ne_a; ne[0] *= 2; ne[1] *= 2;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3],
+                    a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_name(a, "a");
+        }
+
+        ggml_tensor * out = ggml_win_part(ctx, a, w);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_WIN_UNPART
+struct test_win_unpart : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a; // [C, w, w, NPX*NPY*B]
+    const int w0; // output width
+    const int h0; // output height
+    const int w;  // window size
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne_a, w0, h0, w);
+    }
+
+    test_win_unpart(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {64, 7, 7, 8},
+            int w0 = 14, int h0 = 14,
+            int w = 7)
+        : type(type), ne_a(ne_a), w0(w0), h0(h0), w(w) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_win_unpart(ctx, a, w0, h0, w);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_GET_REL_POS
+struct test_get_rel_pos : public test_case {
+    const ggml_type type;
+    const int C;  // channels
+    const int qh; // query height
+    const int kh; // key height
+    const bool v; // view (non-contiguous input)
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, C, qh, kh, v);
+    }
+
+    test_get_rel_pos(ggml_type type = GGML_TYPE_F32,
+            int C = 64,
+            int qh = 7,
+            int kh = 7,
+            bool v = false)
+        : type(type), C(C), qh(qh), kh(kh), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // Input tensor has relative position embeddings table
+        // Shape: [C, 2*max(qh,kh)-1, 1, 1]
+        const int64_t ne_a[4] = {C, 2*std::max(qh, kh) - 1, 1, 1};
+
+        ggml_tensor * a;
+        if (v) {
+            // Create larger tensor and view into it (non-contiguous)
+            int64_t ne_large[4] = {C * 2, 2*std::max(qh, kh) - 1, 1, 1};
+            a = ggml_new_tensor(ctx, type, 4, ne_large);
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, C, 2*std::max(qh, kh) - 1, 1, 1,
+                    a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a);
+            ggml_set_name(a, "a");
+        }
+
+        // Output shape: [C, kh, qh, 1]
+        ggml_tensor * out = ggml_get_rel_pos(ctx, a, qh, kh);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_ROLL
 struct test_roll : public test_case {
     const int shift0;
@@ -7565,6 +7680,53 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_pad_ext());
     test_cases.emplace_back(new test_pad_reflect_1d());
     test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
+
+    // Window partition tests
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
+        for (bool v : {false, true}) {
+            // Exact division: 14x14 -> 2x2 windows of 7x7
+            test_cases.emplace_back(new test_win_part(type, {64, 14, 14, 2}, 7, v));
+            // With padding: 15x15 -> 3x3 windows of 7x7 (padded)
+            test_cases.emplace_back(new test_win_part(type, {64, 15, 15, 2}, 7, v));
+            // Single window: 7x7 -> 1x1 windows of 7x7
+            test_cases.emplace_back(new test_win_part(type, {64, 7, 7, 1}, 7, v));
+            // Larger: 28x28 -> 4x4 windows of 7x7
+            test_cases.emplace_back(new test_win_part(type, {128, 28, 28, 4}, 7, v));
+            // Window size 8: 16x16 -> 2x2 windows of 8x8
+            test_cases.emplace_back(new test_win_part(type, {96, 16, 16, 1}, 8, v));
+        }
+    }
+
+    // Window unpartition tests (inverse of partition)
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
+        // Exact division: 2x2 windows of 7x7 -> 14x14
+        test_cases.emplace_back(new test_win_unpart(type, {64, 7, 7, 4}, 14, 14, 7));
+        // With padding: 3x3 windows of 7x7 -> 15x15
+        test_cases.emplace_back(new test_win_unpart(type, {64, 7, 7, 9}, 15, 15, 7));
+        // Single window: 1x1 windows of 7x7 -> 7x7
+        test_cases.emplace_back(new test_win_unpart(type, {64, 7, 7, 1}, 7, 7, 7));
+        // Larger: 4x4 windows of 7x7 -> 28x28
+        test_cases.emplace_back(new test_win_unpart(type, {128, 7, 7, 16}, 28, 28, 7));
+        // Window size 8: 2x2 windows of 8x8 -> 16x16
+        test_cases.emplace_back(new test_win_unpart(type, {96, 8, 8, 4}, 16, 16, 8));
+    }
+
+    // Relative position embedding tests (used in SAM)
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
+        for (bool v : {false, true}) {
+            // Square small: 3x3 attention
+            test_cases.emplace_back(new test_get_rel_pos(type, 5, 3, 3, v));
+            // Square medium: 7x7 attention (typical SAM)
+            test_cases.emplace_back(new test_get_rel_pos(type, 13, 7, 7, v));
+            // Square large: 14x14 attention
+            test_cases.emplace_back(new test_get_rel_pos(type, 27, 14, 14, v));
+            // Rectangular: 14x7 attention
+            test_cases.emplace_back(new test_get_rel_pos(type, 27, 14, 7, v));
+            // Edge case: 1x1 attention (minimum)
+            test_cases.emplace_back(new test_get_rel_pos(type, 1, 1, 1, v));
+        }
+    }
+
     test_cases.emplace_back(new test_roll());
     test_cases.emplace_back(new test_arange());
     test_cases.emplace_back(new test_timestep_embedding());
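For orientation, a minimal usage sketch of two of the ops the new tests cover, against the public ggml API (context size and shapes are illustrative, not from the commit; the rectangular qh != kh case assumes the path this commit's tests exercise):

#include <algorithm>
#include <cstdio>

#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int C = 64, qh = 14, kh = 7, w = 7;

    // GGML_OP_WIN_PART: [C, W, H, B] -> [C, w, w, npx*npy*B]
    struct ggml_tensor * img = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, C, 14, 14, 1);
    struct ggml_tensor * win = ggml_win_part(ctx, img, w);

    // GGML_OP_GET_REL_POS: table [C, 2*max(qh,kh)-1] -> [C, kh, qh, 1]
    struct ggml_tensor * table = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, C, 2*std::max(qh, kh) - 1);
    struct ggml_tensor * rel   = ggml_get_rel_pos(ctx, table, qh, kh);

    printf("win_part:    [%lld, %lld, %lld, %lld]\n",
           (long long) win->ne[0], (long long) win->ne[1], (long long) win->ne[2], (long long) win->ne[3]);
    printf("get_rel_pos: [%lld, %lld, %lld, %lld]\n",
           (long long) rel->ne[0], (long long) rel->ne[1], (long long) rel->ne[2], (long long) rel->ne[3]);

    ggml_free(ctx);
    return 0;
}

The rectangular shapes (14x7) and the non-contiguous view inputs in test_get_rel_pos are exactly the paths the scaling fix touches, which is why they appear alongside the square SAM-style cases.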
