Commit 97a5cd1

cgraph ok
1 parent 32a62d1 commit 97a5cd1

File tree

1 file changed: +46 -31 lines changed


tools/llava/clip.cpp

Lines changed: 46 additions & 31 deletions
@@ -816,7 +816,9 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
     const auto & hparams = model.hparams;

     const int patch_size = hparams.patch_size;
-    const int num_patches = ((img.nx / patch_size) * (img.ny / patch_size));
+    const int px = img.nx / patch_size;
+    const int py = img.ny / patch_size;
+    const int num_patches = px * py;
     const int num_pos = num_patches + 1; // +1 for [CLS]
     const int hidden_size = hparams.hidden_size;
     const int n_head = hparams.n_head;
@@ -849,10 +851,9 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
     ggml_tensor * inp = ggml_conv_2d(ctx0, patch_embd_view, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
     inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
     inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
-    inp = ggml_add(ctx0, inp, model.patch_bias);

     // add CLS
-    inp_raw = ggml_concat(ctx0, inp_raw, model.class_embedding, 0);
+    inp = ggml_concat(ctx0, inp, model.class_embedding, 1);

     // 2D input positions
     ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos);
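
For context (not part of the commit): a minimal standalone sketch of the shape bookkeeping this hunk changes. The hidden_size and patch-grid values are assumptions for illustration only; ggml's ne[0] is the fastest-varying dimension.

// illustration only: appending [CLS] to the patch embeddings grows the
// sequence from num_patches to num_pos, which is why the attention code
// below switches its reshapes from num_patches to num_pos
#include <cstdio>

int main() {
    const int hidden_size = 1408;        // assumed embedding width
    const int px = 24, py = 24;          // assumed patch grid (img.nx/ny divided by patch_size)
    const int num_patches = px * py;     // patch embeddings after conv + reshape
    // old: class_embedding was concatenated to inp_raw along dim 0
    // new: it is appended to inp ([hidden_size, num_patches]) along dim 1
    const int num_pos = num_patches + 1; // +1 for [CLS], matching the diff
    printf("inp after concat: [%d, %d]\n", hidden_size, num_pos);
    return 0;
}
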
@@ -881,31 +882,31 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
             ggml_tensor * Q =
                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);

-            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
+            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_pos);
             Q = build_rope_2d(ctx0, Q, pos_w, pos_h, hparams.rope_theta, false);
             Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));

             ggml_tensor * K =
                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);

-            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
+            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_pos);
             K = build_rope_2d(ctx0, K, pos_w, pos_h, hparams.rope_theta, false);
             K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));

             ggml_tensor * V =
                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);

-            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
+            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_pos);
             V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));

             ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
             KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);

             ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
+            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_pos, n_head);
             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

-            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_pos);
         }

         // attention output
@@ -922,8 +923,8 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
         }

-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b);

         if (ctx->use_silu) {
             cur = ggml_silu(ctx0, cur);
@@ -933,8 +934,8 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
             GGML_ABORT("llama4: Unsupported activation");
         }

-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);

         // residual 2
         cur = ggml_add(ctx0, embeddings, cur);
@@ -950,33 +951,43 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
         embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
     }

-    // Llama4VisionPixelShuffleMLP
+    // based on Llama4VisionPixelShuffleMLP
+    // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151
     {
         ggml_tensor * cur = embeddings;
+        const int batch_size = 1; // always 1 for now since we don't support batching
         const int scale_factor = model.hparams.proj_scale_factor;
-        const int n_embd = cur->ne[0];
-        const int seq = cur->ne[1];
-        const int bsz = 1; // batch size, always 1 for now since we don't support batching
-        const int height = std::sqrt(seq);
-        const int width = std::sqrt(seq);
         GGML_ASSERT(scale_factor != 0);
-        cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
+
+        // remove CLS by doing a view
+        cur = ggml_view_3d(ctx0, cur,
+            hidden_size, num_patches, batch_size,
+            ggml_row_size(cur->type, hidden_size),
+            ggml_row_size(cur->type, hidden_size * num_patches), 0);
+
+        cur = ggml_reshape_3d(ctx0, cur,
+            hidden_size * scale_factor,
+            num_patches / scale_factor,
+            batch_size);
         cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+
         cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
-            n_embd * scale_factor * scale_factor,
-            height / scale_factor,
-            width / scale_factor,
-            bsz);
+            hidden_size * scale_factor * scale_factor,
+            py / scale_factor,
+            px / scale_factor,
+            batch_size);
         cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-        cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
-            n_embd * scale_factor * scale_factor,
-            seq / (scale_factor * scale_factor),
-            bsz);

-        cur = ggml_mul_mat(ctx0, model.projection, cur);
+        // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
+        cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
+        cur = ggml_gelu(ctx0, cur);
+        cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
         embeddings = cur;
     }

+    // based on Llama4MultiModalProjector
+    embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
+
     // build the graph
     ggml_build_forward_expand(gf, embeddings);

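For intuition (my own sketch, not part of the commit): how the pixel-shuffle view/reshape/permute sequence trades sequence length for embedding width before the two projection matmuls. The toy values below are assumptions, not the model's actual hparams.

// illustration only: pixel shuffle folds scale_factor^2 neighbouring patches
// into a single token, widening each token's embedding by the same factor
#include <cassert>
#include <cstdio>

int main() {
    const int hidden_size  = 1408;  // assumed
    const int px = 24, py = 24;     // assumed patch grid
    const int scale_factor = 2;     // assumed proj_scale_factor
    const int num_patches  = px * py;

    assert(px % scale_factor == 0 && py % scale_factor == 0);

    // after the [CLS]-dropping view and the reshape/permute pairs above:
    const int out_width  = hidden_size * scale_factor * scale_factor;
    const int out_tokens = num_patches / (scale_factor * scale_factor);

    printf("before: [%d, %d]  after: [%d, %d]\n",
           hidden_size, num_patches, out_width, out_tokens);
    // the total number of values is unchanged, only the grouping differs
    assert(hidden_size * num_patches == out_width * out_tokens);
    return 0;
}
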

@@ -3135,6 +3146,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
     const auto & params = ctx->vision_model.hparams;

     int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
+    int scale_factor = ctx->vision_model.hparams.proj_scale_factor;

     if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
         n_patches /= 4;
@@ -3158,8 +3170,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         n_patches = x_patch * y_patch;
     } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
         n_patches = 256;
-    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
-        n_patches /= ctx->vision_model.hparams.proj_scale_factor;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+        n_patches /= scale_factor;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
+        n_patches /= (scale_factor * scale_factor);
     } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
         int n_merge = ctx->vision_model.hparams.spatial_merge_size;
         int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
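
As a quick sanity check of the new branches (with assumed, purely illustrative hparams): image_size 336 and patch_size 14 give n_patches = 24 * 24 = 576; with proj_scale_factor 2 the IDEFICS3 branch would return 576 / 2 = 288, while the LLAMA4 branch returns 576 / (2 * 2) = 144, matching the scale_factor^2 reduction performed by the pixel shuffle in the graph above.
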
@@ -3757,8 +3771,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_GEMMA3:
             return ctx->vision_model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
-        case PROJECTOR_TYPE_LLAMA4:
             return ctx->vision_model.projection->ne[1];
+        case PROJECTOR_TYPE_LLAMA4:
+            return ctx->vision_model.mm_model_proj->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }

0 commit comments
