
Commit 7850716

committed
Add Janus Attention Pool with Latent Query support in CLIP model
1 parent 3667a0a commit 7850716

File tree

1 file changed: +89 −0 lines changed

examples/llava/clip.cpp

Lines changed: 89 additions & 0 deletions
@@ -571,6 +571,23 @@ struct clip_vision_model {
     struct ggml_tensor * mm_model_ln_kv_b;
     struct ggml_tensor * mm_model_ln_post_w;
     struct ggml_tensor * mm_model_ln_post_b;
+
+    // Janus Attention Pool with Latent Query
+    struct ggml_tensor * attn_pool_latent;
+    struct ggml_tensor * attn_pool_q_w;
+    struct ggml_tensor * attn_pool_q_b;
+    struct ggml_tensor * attn_pool_k_w;
+    struct ggml_tensor * attn_pool_k_b;
+    struct ggml_tensor * attn_pool_v_w;
+    struct ggml_tensor * attn_pool_v_b;
+    struct ggml_tensor * attn_pool_proj_w;
+    struct ggml_tensor * attn_pool_proj_b;
+    struct ggml_tensor * attn_pool_norm_w;
+    struct ggml_tensor * attn_pool_norm_b;
+    struct ggml_tensor * attn_pool_ffn_up_w;
+    struct ggml_tensor * attn_pool_ffn_up_b;
+    struct ggml_tensor * attn_pool_ffn_down_w;
+    struct ggml_tensor * attn_pool_ffn_down_b;
 };
 
 struct clip_ctx {
@@ -580,6 +597,7 @@ struct clip_ctx {
     bool has_minicpmv_projector = false;
     bool has_glm_projector = false;
     bool has_qwen2vl_merger = false;
+    bool has_janus_attn_pool_latent = false;
     int minicpmv_version = 2;
 
     struct clip_vision_model vision_model;
@@ -1153,6 +1171,77 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
     }
 
+    // janus attn pool with latent query
+    // TODO: Check the ctx0
+    else if (ctx->has_janus_attn_pool_latent) {
+        if (ctx->proj_type == PROJECTOR_TYPE_JANUS) {
+            struct ggml_tensor * latent = model.attn_pool_latent; // Should be [D, 1, 1]
+            struct ggml_tensor * latent_expanded = ggml_repeat(ctx0, latent,
+                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size)); // [D, 1, B]
+
+            struct ggml_tensor * Q = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_q_w, latent_expanded),
+                model.attn_pool_q_b);
+            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, 1, batch_size);
+            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+            Q = ggml_reshape_3d(ctx0, Q, d_head, 1, n_head * batch_size);
+
+            struct ggml_tensor * K = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_k_w, embeddings),
+                model.attn_pool_k_b);
+            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor * V = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_v_w, embeddings),
+                model.attn_pool_v_b);
+            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+
+            struct ggml_tensor * attn_scores = ggml_mul_mat(ctx0, K, Q);
+            attn_scores = ggml_soft_max_inplace(ctx0, attn_scores);
+
+            struct ggml_tensor * attn_output = ggml_mul_mat(ctx0, V, attn_scores);
+            attn_output = ggml_reshape_4d(ctx0, attn_output, d_head, 1, n_head, batch_size);
+            attn_output = ggml_cont(ctx0, ggml_permute(ctx0, attn_output, 0, 2, 1, 3));
+            attn_output = ggml_cont_3d(ctx0, attn_output, hidden_size, 1, batch_size);
+
+            attn_output = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_proj_w, attn_output),
+                model.attn_pool_proj_b);
+
+            // MLP with pre-norm: norm -> fc1 -> gelu -> fc2
+            // References:
+            // https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/mlp.py#L13
+            struct ggml_tensor * cur = attn_output;
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.attn_pool_norm_w), model.attn_pool_norm_b);
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.attn_pool_ffn_down_w, cur), model.attn_pool_ffn_down_b);
+            cur = ggml_gelu_inplace(ctx0, cur);
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.attn_pool_ffn_up_w, cur), model.attn_pool_ffn_up_b);
+            // Residual connection
+            // https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/attention_pool.py#L98
+            attn_output = ggml_add(ctx0, attn_output, cur); // [D, 1, B]
+
+            // Pooling, select first token
+            embeddings = ggml_view_2d(ctx0,
+                attn_output,
+                attn_output->ne[0],
+                attn_output->ne[2],
+                attn_output->nb[2],
+                0);
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    }
+
     // build the graph
     ggml_build_forward_expand(gf, embeddings);
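For context on what the new branch computes: Janus pools the patch embeddings with a single learned latent query (in the style of timm's AttentionPoolLatent), so the whole image collapses into one embedding of size hidden_size. The snippet below is a minimal, single-head sketch of that idea against the raw ggml API. It is illustrative only, not code from this commit: the q/k/v/output projections, the multi-head split, the MLP and the residual from the diff are omitted, the sizes are toy values, and it assumes a standalone program built against ggml.h.

// Single-head latent-query attention pooling, reduced to its core:
// score each token against one learned latent vector, softmax the
// scores, and take the weighted sum of the tokens.
#include "ggml.h"
#include <cmath>
#include <cstdio>

int main() {
    const int D = 4;  // toy hidden size
    const int N = 3;  // toy number of patch tokens

    struct ggml_init_params params = { 16 * 1024 * 1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * tokens = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, N); // [D, N]
    struct ggml_tensor * latent = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, 1); // [D, 1]
    ggml_set_f32(tokens, 0.5f);
    ggml_set_f32(latent, 1.0f);

    // scores[i] = <latent, tokens[:, i]> / sqrt(D), softmax over the N tokens
    struct ggml_tensor * scores = ggml_mul_mat(ctx, tokens, latent);            // [N, 1]
    scores = ggml_soft_max(ctx, ggml_scale(ctx, scores, 1.0f / sqrtf((float)D)));

    // pooled = sum_i scores[i] * tokens[:, i]  -> a single [D, 1] embedding
    struct ggml_tensor * tokens_t = ggml_cont(ctx, ggml_transpose(ctx, tokens)); // [N, D]
    struct ggml_tensor * pooled   = ggml_mul_mat(ctx, tokens_t, scores);         // [D, 1]

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, pooled);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    for (int i = 0; i < D; ++i) {
        printf("pooled[%d] = %f\n", i, ggml_get_f32_1d(pooled, i));
    }

    ggml_free(ctx);
    return 0;
}

In the graph built by the diff above, the same dot-product / softmax / weighted-sum happens per head, with `embeddings` in the role of `tokens` and `model.attn_pool_latent` (projected through the q/k/v weights) in the role of `latent`.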

0 commit comments