
Commit 8caeed5 (1 parent: 844a344)

fix merge conflicts

File tree

1 file changed: +94 -11 lines changed

tools/mtmd/clip.cpp

Lines changed: 94 additions & 11 deletions
@@ -194,10 +194,6 @@ struct clip_hparams {
 };
 
 struct clip_layer {
-    // layernorm 1 (input norm)
-    struct ggml_tensor * ln_1_w = nullptr;
-    struct ggml_tensor * ln_1_b = nullptr;
-
     // attention
     ggml_tensor * k_w = nullptr;
     ggml_tensor * k_b = nullptr;
@@ -526,7 +522,7 @@ struct clip_graph {
         ggml_set_input(pos_w);
 
         auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
-            return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta);
+            return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true);
         };
 
         ggml_tensor * inp = build_inp();
@@ -940,6 +936,90 @@ struct clip_graph {
         return gf;
     }
 
+    ggml_cgraph * build_llama4() {
+        GGML_ASSERT(model.class_embedding != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+
+        const int n_pos = n_patches + 1;
+
+        // 2D input positions
+        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(pos_h, "pos_h");
+        ggml_set_input(pos_h);
+
+        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(pos_w, "pos_w");
+        ggml_set_input(pos_w);
+
+        ggml_tensor * inp = build_inp_raw();
+
+        // Llama4UnfoldConvolution
+        {
+            inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp, patch_size, patch_size, 0, 0, 1, 1);
+            inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
+            inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+            cb(inp, "patch_conv", -1);
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+            cb(inp, "patch_bias", -1);
+        }
+
+        // add CLS token
+        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+
+        // build ViT with 2D position embeddings
+        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+        };
+        ggml_tensor * cur = build_vit(
+                                inp, n_pos,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                add_pos);
+
+        // remove CLS token
+        cur = ggml_view_2d(ctx0, cur,
+            n_embd, n_patches,
+            ggml_row_size(cur->type, n_embd), 0);
+
+        // pixel shuffle
+        // based on Llama4VisionPixelShuffleMLP
+        // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151
+        {
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int bsz = 1; // batch size, always 1 for now since we don't support batching
+            const int height = n_patches_y;
+            const int width = n_patches_x;
+            GGML_ASSERT(scale_factor > 0);
+            GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                height / scale_factor,
+                width / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            // flatten to 2D
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                cur->ne[1] * cur->ne[2]);
+        }
+
+        // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
+        {
+            cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
+            cur = ggml_gelu(ctx0, cur);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // this graph is used by llava, granite and glm
     // due to having embedding_stack (used by granite), we cannot reuse build_vit
     ggml_cgraph * build_llava() {
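
Note on the pixel-shuffle block in build_llama4: the reshape/permute chain is the graph-level way of folding each scale_factor x scale_factor neighbourhood of patch embeddings into a single token whose channel count grows by scale_factor^2, so the token count shrinks by the same factor before the projector MLP. The standalone sketch below only illustrates that index bookkeeping; the toy sizes and the flat [y][x][c] layout are assumptions made for illustration, not code from clip.cpp.

// Minimal standalone sketch of the pixel-shuffle idea (not the ggml graph):
// every scale x scale block of patch embeddings becomes one token with
// scale*scale times as many channels.
#include <cstdio>
#include <vector>

int main() {
    const int n_embd = 4, grid = 4, scale = 2;                // toy sizes
    std::vector<float> patches(grid * grid * n_embd);
    for (size_t i = 0; i < patches.size(); ++i) patches[i] = (float) i;

    const int out_grid = grid / scale;
    const int out_embd = n_embd * scale * scale;
    std::vector<float> shuffled(out_grid * out_grid * out_embd);

    for (int oy = 0; oy < out_grid; ++oy) {
        for (int ox = 0; ox < out_grid; ++ox) {
            int oc = 0;
            for (int dy = 0; dy < scale; ++dy) {
                for (int dx = 0; dx < scale; ++dx) {
                    for (int c = 0; c < n_embd; ++c, ++oc) {
                        const int y = oy * scale + dy;
                        const int x = ox * scale + dx;
                        shuffled[(oy * out_grid + ox) * out_embd + oc] =
                            patches[(y * grid + x) * n_embd + c];
                    }
                }
            }
        }
    }
    printf("tokens: %d -> %d, channels: %d -> %d\n",
           grid * grid, out_grid * out_grid, n_embd, out_embd);
    return 0;
}

With grid = 4 and scale = 2 this prints "tokens: 16 -> 4, channels: 4 -> 16", which mirrors what proj_scale_factor does to the vision tokens before the Llama4VisionMLP2 projection.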
@@ -1634,9 +1714,10 @@ struct clip_graph {
     static ggml_tensor * build_rope_2d(
         ggml_context * ctx0,
         ggml_tensor * cur,
-        ggml_tensor * pos_h,
-        ggml_tensor * pos_w,
-        const float freq_base
+        ggml_tensor * pos_a, // first half
+        ggml_tensor * pos_b, // second half
+        const float freq_base,
+        const bool interleave_freq
     ) {
         const int64_t n_dim = cur->ne[0];
         const int64_t n_head = cur->ne[1];
@@ -1650,7 +1731,9 @@ struct clip_graph {
         // ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
         // then for the second half, we use freq_scale to shift the inv_freq
         // ^ why? replace (2i) with (2i+1) in the above equation
-        const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
+        const float freq_scale_odd = interleave_freq
+            ? std::pow(freq_base, (float)-2/n_dim)
+            : 1.0;
 
         // first half
         ggml_tensor * first;
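
The comments in this hunk compress a bit of arithmetic: applying RoPE over n_dim/2 dims gives inv_freq = freq_base^(-2(2i)/n_dim), i.e. the even pairs of a full n_dim RoPE, and multiplying by freq_base^(-2/n_dim) shifts that to the odd pairs 2i+1. The toy check below (standalone, toy sizes, not part of the graph code) just prints both sides; with interleave_freq == false the scale is 1 and the second half reuses the even-pair frequencies.

// Toy numerical check of the inv_freq comment above; sizes are illustrative.
#include <cmath>
#include <cstdio>

int main() {
    const float base  = 10000.0f;
    const int   n_dim = 8;

    // same expression as freq_scale_odd when interleave_freq == true
    const float freq_scale_odd = std::pow(base, -2.0f / n_dim);

    for (int i = 0; i < n_dim / 4; ++i) {
        const float even_pair = std::pow(base, -2.0f * (2 * i)     / n_dim); // first half
        const float odd_pair  = std::pow(base, -2.0f * (2 * i + 1) / n_dim); // target for second half
        printf("i=%d  even=%.8f  even*scale=%.8f  odd=%.8f\n",
               i, even_pair, even_pair * freq_scale_odd, odd_pair);
    }
    return 0;
}

Each printed even*scale value matches the odd column, which is exactly the shift the new interleave_freq flag toggles on or off.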
@@ -1663,7 +1746,7 @@ struct clip_graph {
             first = ggml_rope_ext(
                 ctx0,
                 first,
-                pos_h, // positions
+                pos_a, // positions
                 nullptr, // freq factors
                 n_dim/2, // n_dims
                 0, 0, freq_base,
@@ -1683,7 +1766,7 @@ struct clip_graph {
             second = ggml_rope_ext(
                 ctx0,
                 second,
-                pos_w, // positions
+                pos_b, // positions
                 nullptr, // freq factors
                 n_dim/2, // n_dims
                 0, 0, freq_base,
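
Putting the two halves together: per attention head, the first n_dim/2 channels are rotated by angles derived from pos_a and the remaining channels by pos_b, with the second half's frequencies shifted by freq_scale_odd only when interleave_freq is set. The sketch below is an assumption-laden illustration of that split on a plain float vector (toy sizes, adjacent channels forming a rotation pair); it does not reproduce ggml_rope_ext itself.

#include <cmath>
#include <cstdio>
#include <vector>

// rotate adjacent pairs (x[2i], x[2i+1]) inside one half of the head dim,
// using theta_i = pos * base^(-2(2i)/n_dim), optionally scaled for the
// "odd" second half
static void rope_half(float * x, int half, int pos, float base, int n_dim, float freq_scale) {
    for (int i = 0; i < half / 2; ++i) {
        const float inv_freq = freq_scale * std::pow(base, -2.0f * (2 * i) / n_dim);
        const float theta    = pos * inv_freq;
        const float c = std::cos(theta), s = std::sin(theta);
        const float x0 = x[2 * i], x1 = x[2 * i + 1];
        x[2 * i]     = x0 * c - x1 * s;
        x[2 * i + 1] = x0 * s + x1 * c;
    }
}

int main() {
    const int   n_dim = 8;
    const float base  = 10000.0f;
    const bool  interleave_freq = true;

    std::vector<float> head(n_dim, 1.0f);   // one attention head, toy values
    const int pos_a = 3, pos_b = 5;         // e.g. patch row / column

    const float freq_scale_odd = interleave_freq ? std::pow(base, -2.0f / n_dim) : 1.0f;

    rope_half(head.data(),             n_dim / 2, pos_a, base, n_dim, 1.0f);           // first half
    rope_half(head.data() + n_dim / 2, n_dim / 2, pos_b, base, n_dim, freq_scale_odd); // second half

    for (int i = 0; i < n_dim; ++i) printf("%d: %.4f\n", i, head[i]);
    return 0;
}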
