
Commit 6575c14

ggml : fix get_rel_pos scaling bugs and update tests
1 parent 3310566 commit 6575c14

3 files changed, +169 -7 lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 4 additions & 4 deletions
@@ -9316,8 +9316,8 @@ static void ggml_compute_forward_get_rel_pos_f32(
 
     const int64_t kh = ne1;
     const int64_t qh = ne2;
-    const float k_scale = MAX(qh / kh, 1.0f);
-    const float q_scale = MAX(kh / qh, 1.0f);
+    const float k_scale = MAX((float)qh / kh, 1.0f);
+    const float q_scale = MAX((float)kh / qh, 1.0f);
 
     float * src0_data = (float *) src0->data;
     float * dst_data = (float *) dst->data;
@@ -9345,8 +9345,8 @@ static void ggml_compute_forward_get_rel_pos_f16(
 
     const int64_t kh = ne1;
     const int64_t qh = ne2;
-    const float k_scale = MAX(qh / kh, 1.0f);
-    const float q_scale = MAX(kh / qh, 1.0f);
+    const float k_scale = MAX((float)qh / kh, 1.0f);
+    const float q_scale = MAX((float)kh / qh, 1.0f);
 
     ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
     ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
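
Note on the CPU fix: qh and kh are int64_t, so the removed lines did integer division before MAX was applied, truncating the scale whenever the query and key heights differ by a non-integer factor; casting one operand to float keeps the fractional part. A minimal standalone sketch of the difference, assuming the usual MAX(a, b) ((a) > (b) ? (a) : (b)) macro and illustrative values qh = 13, kh = 7 (neither is taken from the commit):

    #include <stdint.h>
    #include <stdio.h>

    #define MAX(a, b) ((a) > (b) ? (a) : (b))

    int main(void) {
        const int64_t qh = 13, kh = 7;                        // hypothetical rectangular attention
        const float k_scale_old = MAX(qh / kh, 1.0f);         // int64 division: 13/7 == 1 -> 1.000000
        const float k_scale_new = MAX((float)qh / kh, 1.0f);  // float division: ~1.857143
        printf("old k_scale = %f, new k_scale = %f\n", k_scale_old, k_scale_new);
        return 0;
    }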

ggml/src/ggml-cuda/rel-pos.cu

Lines changed: 3 additions & 3 deletions
@@ -6,9 +6,9 @@
 template <typename T>
 __global__ static void get_rel_pos_kernel(const void * src, void * dst, int C) {
     int kh = gridDim.x;
-    int qh = gridDim.x;
-    float k_scale = MAX(qh / kh, 1.0f);
-    float q_scale = MAX(kh / qh, 1.0f);
+    int qh = gridDim.y;
+    float k_scale = MAX((float)qh / kh, 1.0f);
+    float q_scale = MAX((float)kh / qh, 1.0f);
     int ki = blockIdx.x;
     int qi = blockIdx.y;
     int pos = int(qi*q_scale - ki*k_scale + (kh - 1)*k_scale);
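
Note on the CUDA fix: this hunk repairs two problems at once. qh was read from gridDim.x, i.e. it always equaled kh, so rectangular query/key heights collapsed to the key height; it now reads gridDim.y. The scale computation also gets the same float cast as the CPU path. A small probe kernel, assuming the kernel is launched with one block per (key, query) position pair, i.e. a grid of (kh, qh) blocks (the launch code is not part of this diff), and the usual MAX macro:

    // probe.cu -- prints the scales the corrected indexing produces (illustrative only)
    #include <cstdio>

    #define MAX(a, b) ((a) > (b) ? (a) : (b))

    __global__ static void scale_probe_kernel() {
        const int kh = gridDim.x;                        // key positions
        const int qh = gridDim.y;                        // query positions (was gridDim.x before the fix)
        const float k_scale = MAX((float)qh / kh, 1.0f);
        const float q_scale = MAX((float)kh / qh, 1.0f);
        if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
            printf("kh=%d qh=%d k_scale=%f q_scale=%f\n", kh, qh, k_scale, q_scale);
        }
    }

    int main() {
        const dim3 grid(7, 14, 1);                       // kh = 7, qh = 14 (hypothetical shapes)
        scale_probe_kernel<<<grid, 1>>>();
        cudaDeviceSynchronize();
        return 0;
    }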

tests/test-backend-ops.cpp

Lines changed: 162 additions & 0 deletions
@@ -5455,6 +5455,121 @@ struct test_pad_reflect_1d : public test_case {
     }
 };
 
+// GGML_OP_WIN_PART
+struct test_win_part : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a; // [C, W, H, B]
+    const int w; // window size
+    const bool v; // view (non-contiguous input)
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne_a, w, v);
+    }
+
+    test_win_part(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {64, 14, 14, 2},
+            int w = 7,
+            bool v = false)
+        : type(type), ne_a(ne_a), w(w), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a;
+        if (v) {
+            auto ne = ne_a; ne[0] *= 2; ne[1] *= 2;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3],
+                    a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_name(a, "a");
+        }
+
+        ggml_tensor * out = ggml_win_part(ctx, a, w);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_WIN_UNPART
+struct test_win_unpart : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a; // [C, w, w, NPX*NPY*B]
+    const int w0; // output width
+    const int h0; // output height
+    const int w; // window size
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne_a, w0, h0, w);
+    }
+
+    test_win_unpart(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {64, 7, 7, 8},
+            int w0 = 14, int h0 = 14,
+            int w = 7)
+        : type(type), ne_a(ne_a), w0(w0), h0(h0), w(w) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_win_unpart(ctx, a, w0, h0, w);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_GET_REL_POS
+struct test_get_rel_pos : public test_case {
+    const ggml_type type;
+    const int C; // channels
+    const int qh; // query height
+    const int kh; // key height
+    const bool v; // view (non-contiguous input)
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, C, qh, kh, v);
+    }
+
+    test_get_rel_pos(ggml_type type = GGML_TYPE_F32,
+            int C = 64,
+            int qh = 7,
+            int kh = 7,
+            bool v = false)
+        : type(type), C(C), qh(qh), kh(kh), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // Input tensor has relative position embeddings table
+        // Shape: [C, 2*max(qh,kh)-1, 1, 1]
+        const int64_t ne_a[4] = {C, 2*std::max(qh, kh) - 1, 1, 1};
+
+        ggml_tensor * a;
+        if (v) {
+            // Create larger tensor and view into it (non-contiguous)
+            int64_t ne_large[4] = {C * 2, 2*std::max(qh, kh) - 1, 1, 1};
+            a = ggml_new_tensor(ctx, type, 4, ne_large);
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, C, 2*std::max(qh, kh) - 1, 1, 1,
+                    a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a);
+            ggml_set_name(a, "a");
+        }
+
+        // Output shape: [C, kh, qh, 1]
+        ggml_tensor * out = ggml_get_rel_pos(ctx, a, qh, kh);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_ROLL
 struct test_roll : public test_case {
     const int shift0;
@@ -7326,6 +7441,53 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_pad_ext());
     test_cases.emplace_back(new test_pad_reflect_1d());
     test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
+
+    // Window partition tests
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
+        for (bool v : {false, true}) {
+            // Exact division: 14x14 -> 2x2 windows of 7x7
+            test_cases.emplace_back(new test_win_part(type, {64, 14, 14, 2}, 7, v));
+            // With padding: 15x15 -> 3x3 windows of 7x7 (padded)
+            test_cases.emplace_back(new test_win_part(type, {64, 15, 15, 2}, 7, v));
+            // Single window: 7x7 -> 1x1 windows of 7x7
+            test_cases.emplace_back(new test_win_part(type, {64, 7, 7, 1}, 7, v));
+            // Larger: 28x28 -> 4x4 windows of 7x7
+            test_cases.emplace_back(new test_win_part(type, {128, 28, 28, 4}, 7, v));
+            // Window size 8: 16x16 -> 2x2 windows of 8x8
+            test_cases.emplace_back(new test_win_part(type, {96, 16, 16, 1}, 8, v));
+        }
+    }
+
+    // Window unpartition tests (inverse of partition)
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
+        // Exact division: 2x2 windows of 7x7 -> 14x14
+        test_cases.emplace_back(new test_win_unpart(type, {64, 7, 7, 4}, 14, 14, 7));
+        // With padding: 3x3 windows of 7x7 -> 15x15
+        test_cases.emplace_back(new test_win_unpart(type, {64, 7, 7, 9}, 15, 15, 7));
+        // Single window: 1x1 windows of 7x7 -> 7x7
+        test_cases.emplace_back(new test_win_unpart(type, {64, 7, 7, 1}, 7, 7, 7));
+        // Larger: 4x4 windows of 7x7 -> 28x28
+        test_cases.emplace_back(new test_win_unpart(type, {128, 7, 7, 16}, 28, 28, 7));
+        // Window size 8: 2x2 windows of 8x8 -> 16x16
+        test_cases.emplace_back(new test_win_unpart(type, {96, 8, 8, 4}, 16, 16, 8));
+    }
+
+    // Relative position embedding tests (used in SAM)
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
+        for (bool v : {false, true}) {
+            // Square small: 3x3 attention
+            test_cases.emplace_back(new test_get_rel_pos(type, 5, 3, 3, v));
+            // Square medium: 7x7 attention (typical SAM)
+            test_cases.emplace_back(new test_get_rel_pos(type, 13, 7, 7, v));
+            // Square large: 14x14 attention
+            test_cases.emplace_back(new test_get_rel_pos(type, 27, 14, 14, v));
+            // Rectangular: 14x7 attention
+            test_cases.emplace_back(new test_get_rel_pos(type, 27, 14, 7, v));
+            // Edge case: 1x1 attention (minimum)
+            test_cases.emplace_back(new test_get_rel_pos(type, 1, 1, 1, v));
+        }
+    }
+
     test_cases.emplace_back(new test_roll());
     test_cases.emplace_back(new test_arange());
     test_cases.emplace_back(new test_timestep_embedding());
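
The new test structs drive the three SAM-related ops (GGML_OP_WIN_PART, GGML_OP_WIN_UNPART, GGML_OP_GET_REL_POS) through the test_case harness. Outside that harness, the same op can be exercised with a few lines of host code; a rough standalone sketch, assuming the CPU backend, the ggml.h / ggml-cpu.h headers of the current source layout, and illustrative sizes C = 4, qh = kh = 7 (none of this is part of the commit):

    #include "ggml.h"
    #include "ggml-cpu.h"   // assumed location of ggml_graph_compute_with_ctx in the split-backend layout
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int C = 4, qh = 7, kh = 7;

        // relative position table with shape [C, 2*max(qh, kh) - 1]
        struct ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, C, 2*kh - 1);
        for (int64_t i = 0; i < ggml_nelements(rel); ++i) {
            ((float *) rel->data)[i] = (float) i;
        }

        // gather one embedding per (query, key) offset; result shape is [C, kh, qh]
        struct ggml_tensor * out = ggml_get_rel_pos(ctx, rel, qh, kh);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        printf("out[0] = %f\n", ((float *) out->data)[0]);
        ggml_free(ctx);
        return 0;
    }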
