Skip to content

Commit 893ad9c

Browse files
committed
reshape patch_embeddings_0
1 parent 7341e70 commit 893ad9c

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

tools/llava/clip.cpp

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -828,7 +828,7 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
828828
const int n_layer = hparams.n_layer;
829829
const float eps = hparams.eps;
830830

831-
struct ggml_init_params params = {
831+
ggml_init_params params = {
832832
/*.mem_size =*/ ctx->buf_compute_meta.size(),
833833
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
834834
/*.no_alloc =*/ true,
@@ -837,15 +837,20 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
837837
ggml_context_ptr ctx0_ptr(ggml_init(params));
838838
auto ctx0 = ctx0_ptr.get();
839839

840-
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
840+
ggml_cgraph * gf = ggml_new_graph(ctx0);
841841

842842
// input raw
843-
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3);
843+
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3);
844844
ggml_set_name(inp_raw, "inp_raw");
845845
ggml_set_input(inp_raw);
846846

847847
// create patches
848-
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
848+
ggml_tensor * patch_embd_view = ggml_view_4d(ctx0, model.patch_embeddings_0,
849+
hidden_size, patch_size, patch_size, 3,
850+
ggml_row_size(model.patch_embeddings_0->type, hidden_size),
851+
ggml_row_size(model.patch_embeddings_0->type, hidden_size * patch_size),
852+
ggml_row_size(model.patch_embeddings_0->type, hidden_size * patch_size * 3), 0);
853+
ggml_tensor * inp = ggml_conv_2d(ctx0, patch_embd_view, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
849854
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
850855
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
851856
inp = ggml_add(ctx0, inp, model.patch_bias);
@@ -854,19 +859,19 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
854859
inp_raw = ggml_concat(ctx0, inp_raw, model.class_embedding, 0);
855860

856861
// 2D input positions
857-
struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos);
862+
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos);
858863
ggml_set_name(pos_h, "pos_h");
859864
ggml_set_input(pos_h);
860-
struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos);
865+
ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos);
861866
ggml_set_name(pos_w, "pos_w");
862867
ggml_set_input(pos_w);
863868

864869
// position embeddings
865-
struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
870+
ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
866871

867872
// loop over layers
868873
for (int il = 0; il < n_layer; il++) {
869-
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
874+
ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
870875

871876
// layernorm1
872877
{
@@ -877,30 +882,30 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
877882
// self-attention
878883
{
879884

880-
struct ggml_tensor * Q =
885+
ggml_tensor * Q =
881886
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
882887

883888
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
884889
Q = build_rope_2d(ctx0, Q, pos_w, pos_h, hparams.rope_theta, false);
885890
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
886891

887-
struct ggml_tensor * K =
892+
ggml_tensor * K =
888893
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
889894

890895
K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
891896
K = build_rope_2d(ctx0, K, pos_w, pos_h, hparams.rope_theta, false);
892897
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
893898

894-
struct ggml_tensor * V =
899+
ggml_tensor * V =
895900
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
896901

897902
V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
898903
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
899904

900-
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
905+
ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
901906
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
902907

903-
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
908+
ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
904909
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
905910
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
906911

0 commit comments

Comments (0)