
Commit 7e9fbec

mtmd: fix get_rel_pos
1 parent 5e6cf3c commit 7e9fbec

File tree

1 file changed (+90 / -84 lines)


tools/mtmd/clip.cpp

Lines changed: 90 additions & 84 deletions
@@ -2467,101 +2467,107 @@ struct clip_graph {
     }
 
     // attn: [k_h*k_w, q_h*q_w]
-    // rel_h: [q_h, q_w, k_h]
-    // rel_w: [q_h, q_w, k_w]
-
-    static ggml_tensor * add_rel_pos_inplace(
-            ggml_context * ctx,
-            ggml_tensor * attn,
-            ggml_tensor * rel_w,
-            ggml_tensor * rel_h,
-            int q_size
-        ) {
-
-        ggml_tensor * attn_4d =
-            ggml_reshape_4d(ctx, attn, q_size, q_size, attn->ne[1], attn->ne[2]);
-
-        ggml_tensor * rel_h_4d =
-            ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]);
-
-        ggml_tensor * rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_4d
-
-        ggml_tensor * rel_w_4d =
-            ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]);
-
-        ggml_tensor * rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_4d
-
-        ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep));
-        result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
-
-        return result;
-    }
-
-    static ggml_tensor * get_rel_pos(
-            ggml_context * ctx,
-            ggml_tensor * rel_pos, // [L, C]
-            int q_size,
-            int k_size
-        ) {
-
-        const auto dtype = rel_pos->type;
-
-        const int64_t L = rel_pos->ne[0]; // length
-        const int64_t C = rel_pos->ne[1]; // channels
-
-        // -------------------------------------------------
-        // 1) q_idx ← arange(0..q_size-1)   [q_size]
-        // 2) k_idx ← arange(0..k_size-1)   [k_size]
-        // -------------------------------------------------
-
-        ggml_tensor * q_coord = ggml_cast(ctx,
-            ggml_arange(ctx, 0.0f, static_cast<float>(q_size), 1.0f),
-            GGML_TYPE_F32); // [q_size]
-        ggml_tensor * k_coord = ggml_cast(ctx,
-            ggml_arange(ctx, 0.0f, static_cast<float>(k_size), 1.0f),
-            GGML_TYPE_F32); // [k_size]
+    // rel_h: [q_h, q_w, k_h]
+    // rel_w: [q_h, q_w, k_w]
+
+    static ggml_tensor * add_rel_pos_inplace(
+            ggml_context * ctx,
+            ggml_tensor * attn,
+            ggml_tensor * rel_w,
+            ggml_tensor * rel_h,
+            int q_size
+        ) {
 
-        ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, q_size, k_size);
-        q_coord = ggml_cont(ctx, ggml_repeat(ctx, q_coord, rel)); // [q_size, k_size]
+        ggml_tensor * attn_4d =
+            ggml_reshape_4d(ctx, attn, q_size, q_size, attn->ne[1], attn->ne[2]);
 
-        // broadcast reshape:
-        k_coord = ggml_reshape_2d(ctx, k_coord, 1, k_size); // [1, k_size]
-        k_coord = ggml_cont(ctx, ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size]
+        ggml_tensor * rel_h_4d =
+            ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]);
 
-        // -------------------------------------------------
-        // relative_coords = q - k + (k_size - 1)  // SAME as PyTorch when no scaling
-        // -------------------------------------------------
-        rel = ggml_sub(ctx, k_coord, q_coord); // [q_size, k_size]
+        ggml_tensor * rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_4d
 
-        rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast<float>(k_size) - 1.0f); // [q_size, k_size]
+        ggml_tensor * rel_w_4d =
+            ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]);
 
-        // -------------------------------------------------
-        // clamp to [0, L-1] and cast to int32 (for ggml_get_rows)
-        // -------------------------------------------------
+        ggml_tensor * rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_4d
 
-        ggml_tensor * rel_clamped = ggml_clamp(ctx, rel, 0, static_cast<float>(L - 1));
+        ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep));
+        result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
 
-        ggml_tensor * idx_2d = ggml_cast(ctx, rel_clamped, GGML_TYPE_I32); // [q_size, k_size]
 
-        // flatten to 1D for ggml_get_rows
-        const int64_t qk = static_cast<int64_t>(q_size) * static_cast<int64_t>(k_size);
-        ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk]
+        return result;
+    }
 
-        // -------------------------------------------------
-        // Gather from rel_pos → [qk, C]
-        // -------------------------------------------------
-        ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C]
 
-        // reshape to final output → [q_size, k_size, C]
-        ggml_tensor * out = ggml_reshape_3d(ctx, gathered, rel_pos->ne[0],
-            q_size,
-            k_size);
+    static ggml_tensor * get_rel_pos(
+            ggml_context * ctx,
+            ggml_tensor * rel_pos, // [L, C]
+            int q_size,
+            int k_size
+        ) {
+        const int64_t C = rel_pos->ne[0]; // channels
+        const int64_t L = rel_pos->ne[1]; // length
+
+        GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L);
+
+        // -------------------------------------------------
+        // 1) q_idx ← arange(0..q_size-1)   [q_size]
+        // 2) k_idx ← arange(0..k_size-1)   [k_size]
+        // -------------------------------------------------
+
+        // ggml_arange always returns an F32 tensor
+        ggml_tensor * q_coord = ggml_arange(ctx, 0.0f, static_cast<float>(q_size), 1.0f); // [q_size]
+        ggml_tensor * k_coord = ggml_arange(ctx, 0.0f, static_cast<float>(k_size), 1.0f); // [k_size]
+        ggml_tensor * rel     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k_size, q_size);
+
+        // broadcast reshape:
+        q_coord = ggml_cont(ctx,
+            ggml_repeat(ctx,
+                ggml_reshape_2d(ctx, q_coord, 1, q_size), // [q_size, 1]
+                rel
+            )
+        ); // [q_size, k_size]
+        k_coord = ggml_cont(ctx, ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size]
+
+        // This wouldn't be triggered in DeepSeek-OCR. Just for compatibility with
+        // the original implementation.
+        if (q_size != k_size) {
+            q_coord = ggml_scale_inplace(ctx, q_coord, std::max((float) k_size / q_size, 1.0f));
+            k_coord = ggml_scale_inplace(ctx, k_coord, std::max((float) q_size / k_size, 1.0f));
+        }
 
-        return out; // [q_size, k_size, C]
-    }
+        // -------------------------------------------------
+        // relative_coords = q - k + (k_size - 1)  // SAME as PyTorch when no scaling
+        // -------------------------------------------------
+
+        rel = ggml_sub(ctx, q_coord, k_coord); // [q_size, k_size]
+        rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast<float>(k_size) - 1.0f); // [q_size, k_size]
+        // clamp to [0, L-1] for valid indexing
+        rel = ggml_clamp(ctx, rel, 0.0f, static_cast<float>(rel_pos->ne[1] - 1));
+
+        // cast to int32 (for ggml_get_rows)
+        ggml_tensor * idx_2d = ggml_cast(ctx, rel, GGML_TYPE_I32); // [q_size, k_size]
+
+        // flatten to 1D for ggml_get_rows
+        int qk = q_size * k_size;
+        ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk]
+
+        // gather from rel_pos → [qk, C]
+        ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C]
+
+        // reshape to final output
+        ggml_tensor * out = ggml_reshape_3d(ctx, gathered, C, k_size, q_size);
+
+        return out; // [q_size, k_size, C]
+    }
 
     // Implementation based on approach suggested by Acly
     // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091
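Editor's note on the fix: the removed get_rel_pos read length and channels from the wrong dims (ne[0] and ne[1] swapped), built the coordinate grid transposed, and subtracted q_coord from k_coord. The committed version reads C = ne[0] and L = ne[1], asserts L == 2*max(q_size, k_size) - 1, computes q - k + (k_size - 1), and rescales the coordinates by max(k_size/q_size, 1) and max(q_size/k_size, 1) when the query and key counts differ, as in the PyTorch reference. Below is a minimal standalone sketch of the resulting index math, assuming q_size == k_size so the rescaling branch is a no-op (the DeepSeek-OCR case); it is an illustration, not code from the commit:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int q_size = 4;
        const int k_size = 4;
        const int L = 2 * std::max(q_size, k_size) - 1; // rel_pos table length (here 7)

        for (int q = 0; q < q_size; ++q) {
            for (int k = 0; k < k_size; ++k) {
                // the rescaling is a no-op when q_size == k_size
                const float qc = q * std::max((float) k_size / q_size, 1.0f);
                const float kc = k * std::max((float) q_size / k_size, 1.0f);
                // fixed sign: q - k (the removed code effectively computed k - q)
                const float rel = qc - kc + (float) (k_size - 1);
                const int   idx = (int) std::min(std::max(rel, 0.0f), (float) (L - 1));
                printf("%d ", idx);
            }
            printf("\n");
        }
        return 0;
    }

For q_size == k_size == 4 this prints row q=0: 3 2 1 0, row q=1: 4 3 2 1, and so on; these are exactly the rows that ggml_get_rows then gathers from the [L, C] rel_pos table.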
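add_rel_pos_inplace is unchanged in substance by this commit; it implements the decomposed relative-position update (as in SAM-style attention), adding rel_h along the key-height axis and rel_w along the key-width axis of the attention logits. A scalar sketch of what the reshape/repeat/add graph computes, assuming square maps (k_h == k_w == q_size) as the reshapes require; the function name and flattened layout here are invented for illustration:

    #include <vector>

    // attn  : [B, Q, q_size*q_size] logits, flattened row-major (Q = q_size*q_size)
    // rel_h : [B, Q, q_size], rel_w : [B, Q, q_size]
    void add_rel_pos_scalar(std::vector<float> & attn,
                            const std::vector<float> & rel_h,
                            const std::vector<float> & rel_w,
                            int B, int q_size) {
        const int Q = q_size * q_size; // number of query positions
        for (int b = 0; b < B; ++b) {
            for (int q = 0; q < Q; ++q) {
                for (int kh = 0; kh < q_size; ++kh) {
                    for (int kw = 0; kw < q_size; ++kw) {
                        // rel_h broadcasts over kw, rel_w broadcasts over kh
                        attn[(b * Q + q) * Q + kh * q_size + kw] +=
                            rel_h[(b * Q + q) * q_size + kh] +
                            rel_w[(b * Q + q) * q_size + kw];
                    }
                }
            }
        }
    }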
