Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions examples/llava/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,15 +556,14 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
}

// implementation of the 2D RoPE without adding a new op in ggml
// this is not efficient (use double the memory), but works on all backends
static ggml_tensor * build_rope_2d(
ggml_cgraph * gf,
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * pos_h,
ggml_tensor * pos_w,
const float freq_base
) {
ggml_tensor * tmp;
const int64_t n_dim = cur->ne[0];
const int64_t n_head = cur->ne[1];
const int64_t n_pos = cur->ne[2];
Expand All @@ -573,18 +572,24 @@ static ggml_tensor * build_rope_2d(
// we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
// first half of cur will use 1e-0, 1e-2 (even)
// second half of cur will use 1e-1, 1e-3 (odd)
//
// for the first half, the trick here is to rotate n_dim/2, so inv_freq will be even
// the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
// ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
// then for the second half, we use freq_scale to shift the inv_freq
// ^ why? replace (2i) with (2i+1) in the above equation
const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);

// first half
ggml_tensor * first;
{
cur = ggml_rope_ext_inplace(
first = ggml_view_3d(ctx0, cur,
n_dim/2, n_head, n_pos,
ggml_row_size(cur->type, n_dim),
ggml_row_size(cur->type, n_dim*n_head),
0);
// first = ggml_cont(ctx0, first);
first = ggml_rope_ext(
ctx0,
cur,
first,
pos_h, // positions
nullptr, // freq factors
n_dim/2, // n_dims
Expand All @@ -593,27 +598,28 @@ static ggml_tensor * build_rope_2d(
);
}

// second half
// second half (write to tmp)
ggml_tensor * second = cur;
{
tmp = ggml_view_3d(ctx0, cur,
second = ggml_view_3d(ctx0, cur,
n_dim/2, n_head, n_pos,
ggml_row_size(cur->type, n_dim),
ggml_row_size(cur->type, n_dim*n_head),
n_dim/2 * ggml_element_size(cur));
tmp = ggml_rope_ext_inplace(
second = ggml_cont(ctx0, second); // copy, because ggml_rope don't play well with non-contiguous tensors
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you suspect that ggml_rope is not implemented correctly for non-contiguous tensors, please add a test to test-backend-ops that shows the problem.

second = ggml_rope_ext(
ctx0,
tmp,
second,
pos_w, // positions
nullptr, // freq factors
n_dim/2, // n_dims
0, 0, freq_base,
freq_scale_odd,
0.0f, 1.0f, 0.0f, 0.0f
);
// calculate inplace (modify cur directly)
ggml_build_forward_expand(gf, tmp);
}

cur = ggml_concat(ctx0, first, second, 0);
return cur;
}

Expand Down Expand Up @@ -682,13 +688,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);

Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
Q = build_rope_2d(gf, ctx0, Q, pos_h, pos_w, hparams.rope_theta);
Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));

struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);

K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
K = build_rope_2d(gf, ctx0, K, pos_h, pos_w, hparams.rope_theta);
K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta);
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));

struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
Expand Down
Loading