@@ -571,6 +571,23 @@ struct clip_vision_model {
     struct ggml_tensor * mm_model_ln_kv_b;
     struct ggml_tensor * mm_model_ln_post_w;
     struct ggml_tensor * mm_model_ln_post_b;
+
+    // Janus attention pool with latent query
+    struct ggml_tensor * attn_pool_latent;
+    struct ggml_tensor * attn_pool_q_w;
+    struct ggml_tensor * attn_pool_q_b;
+    struct ggml_tensor * attn_pool_k_w;
+    struct ggml_tensor * attn_pool_k_b;
+    struct ggml_tensor * attn_pool_v_w;
+    struct ggml_tensor * attn_pool_v_b;
+    struct ggml_tensor * attn_pool_proj_w;
+    struct ggml_tensor * attn_pool_proj_b;
+    struct ggml_tensor * attn_pool_norm_w;
+    struct ggml_tensor * attn_pool_norm_b;
+    struct ggml_tensor * attn_pool_ffn_up_w;
+    struct ggml_tensor * attn_pool_ffn_up_b;
+    struct ggml_tensor * attn_pool_ffn_down_w;
+    struct ggml_tensor * attn_pool_ffn_down_b;
 };

 struct clip_ctx {
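At load time these tensors would presumably be resolved by name, following the pattern used for the other projector weights in clip_model_load(). A minimal sketch, assuming the existing get_tensor() helper; the "v.attn_pool.*" name strings are assumptions, not taken from this commit:

    // hypothetical GGUF tensor names -- illustration only
    vision_model.attn_pool_latent = get_tensor(new_clip->ctx_data, "v.attn_pool.latent");
    vision_model.attn_pool_q_w    = get_tensor(new_clip->ctx_data, "v.attn_pool.q.weight");
    vision_model.attn_pool_q_b    = get_tensor(new_clip->ctx_data, "v.attn_pool.q.bias");
    vision_model.attn_pool_k_w    = get_tensor(new_clip->ctx_data, "v.attn_pool.k.weight");
    vision_model.attn_pool_k_b    = get_tensor(new_clip->ctx_data, "v.attn_pool.k.bias");
    vision_model.attn_pool_v_w    = get_tensor(new_clip->ctx_data, "v.attn_pool.v.weight");
    vision_model.attn_pool_v_b    = get_tensor(new_clip->ctx_data, "v.attn_pool.v.bias");
    // ... and likewise for the proj, norm and ffn_up/ffn_down weights and biases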
@@ -580,6 +597,7 @@ struct clip_ctx {
     bool has_minicpmv_projector = false;
     bool has_glm_projector = false;
     bool has_qwen2vl_merger = false;
+    bool has_janus_attn_pool_latent = false;
     int minicpmv_version = 2;

     struct clip_vision_model vision_model;
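The new flag would likewise be populated from GGUF metadata at load time; the existing projector flags use the gguf_find_key / gguf_get_val_bool pattern. A small sketch, with a hypothetical key name:

    // hypothetical KV key -- illustration only
    int idx = gguf_find_key(ctx, "clip.has_janus_attn_pool_latent");
    if (idx != -1) {
        new_clip->has_janus_attn_pool_latent = gguf_get_val_bool(ctx, idx);
    }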
@@ -1153,6 +1171,77 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
     }

+    // Janus attention pool with latent query
+    // TODO: check the ctx0
+    else if (ctx->has_janus_attn_pool_latent) {
+        if (ctx->proj_type == PROJECTOR_TYPE_JANUS) {
+            struct ggml_tensor * latent = model.attn_pool_latent; // should be [D, 1, 1]
+            struct ggml_tensor * latent_expanded = ggml_repeat(ctx0, latent,
+                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size)); // [D, 1, B]
+
+            struct ggml_tensor * Q = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_q_w, latent_expanded),
+                model.attn_pool_q_b
+            );
+            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, 1, batch_size);
+            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float) d_head));
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+            Q = ggml_reshape_3d(ctx0, Q, d_head, 1, n_head * batch_size);
+
+            struct ggml_tensor * K = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_k_w, embeddings),
+                model.attn_pool_k_b
+            );
+            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor * V = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_v_w, embeddings),
+                model.attn_pool_v_b
+            );
+            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+
+            struct ggml_tensor * attn_scores = ggml_mul_mat(ctx0, K, Q);
+            attn_scores = ggml_soft_max_inplace(ctx0, attn_scores);
+
+            struct ggml_tensor * attn_output = ggml_mul_mat(ctx0, V, attn_scores);
+            attn_output = ggml_reshape_4d(ctx0, attn_output, d_head, 1, n_head, batch_size);
+            attn_output = ggml_cont(ctx0, ggml_permute(ctx0, attn_output, 0, 2, 1, 3));
+            attn_output = ggml_cont_3d(ctx0, attn_output, hidden_size, 1, batch_size);
+
+            attn_output = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.attn_pool_proj_w, attn_output),
+                model.attn_pool_proj_b
+            );
+
+            // pre-MLP layer norm, then MLP: fc1 -> gelu -> norm -> fc2
+            // References:
+            // https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/mlp.py#L13
+            struct ggml_tensor * cur = ggml_norm(ctx0, attn_output, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.attn_pool_norm_w), model.attn_pool_norm_b);
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.attn_pool_ffn_down_w, cur), model.attn_pool_ffn_down_b);
+            cur = ggml_gelu_inplace(ctx0, cur);
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.attn_pool_norm_w), model.attn_pool_norm_b);
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.attn_pool_ffn_up_w, cur), model.attn_pool_ffn_up_b);
+            // residual connection
+            // https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/attention_pool.py#L98
+            attn_output = ggml_add(ctx0, attn_output, cur); // [D, 1, B]
+
+            // pooling: keep the single latent token per batch entry
+            embeddings = ggml_view_2d(ctx0,
+                attn_output,
+                attn_output->ne[0],
+                attn_output->ne[2],
+                attn_output->nb[2], 0);
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    }
+
     // build the graph
     ggml_build_forward_expand(gf, embeddings);

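For intuition, the added block is single-query attention: one learned latent vector attends over the N patch embeddings, and the graph emits that single pooled token per image. Below is a standalone sketch of just that math (single head, no learned projections, dummy data; an illustration only, not part of this commit), compilable against ggml:

    // latent-query pooling in isolation: scores = softmax(K.q / sqrt(D)), pooled = V^T . scores
    #include "ggml.h"
    #include <cmath>
    #include <cstdio>

    int main() {
        const int D = 4; // hidden size
        const int N = 3; // number of input tokens

        struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * latent = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, D);    // learned query, [D]
        struct ggml_tensor * tokens = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, N); // encoder output, [D, N]

        for (int i = 0; i < D;     i++) ((float *) latent->data)[i] = 0.1f  * i; // dummy data
        for (int i = 0; i < D * N; i++) ((float *) tokens->data)[i] = 0.01f * i;

        // attention scores of the one query against all N tokens
        struct ggml_tensor * scores = ggml_mul_mat(ctx, tokens, latent);            // [N]
        scores = ggml_scale(ctx, scores, 1.0f / sqrtf((float) D));
        scores = ggml_soft_max(ctx, scores);

        // weighted sum of the token features -> one [D] vector per image
        struct ggml_tensor * v_t    = ggml_cont(ctx, ggml_transpose(ctx, tokens));  // [N, D]
        struct ggml_tensor * pooled = ggml_mul_mat(ctx, v_t, scores);               // [D]

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, pooled);
        ggml_graph_compute_with_ctx(ctx, gf, 1);

        for (int i = 0; i < D; i++) {
            printf("pooled[%d] = %f\n", i, ((float *) pooled->data)[i]);
        }

        ggml_free(ctx);
        return 0;
    }

In the actual graph the same computation is done per head and per batch entry, with learned Q/K/V and output projections around it, which is what the reshape/permute calls in the diff implement.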