Commit 448c62e

Add Janus Attention Pool support in CLIP model
1 parent 7850716 commit 448c62e

File tree: 1 file changed, +42 -5 lines


examples/llava/clip.cpp

Lines changed: 42 additions & 5 deletions
@@ -103,6 +103,7 @@ static std::string format(const char * fmt, ...) {
 #define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
 #define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
 #define KEY_HAS_GLM_PROJ "clip.has_glm_projector"
+#define KEY_HAS_JANUS_ATTN_POOL "clip.has_janus_attn_pool"
 #define KEY_MINICPMV_VERSION "clip.minicpmv_version"
 #define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
 #define KEY_USE_GELU "clip.use_gelu"
@@ -170,6 +171,15 @@ static std::string format(const char * fmt, ...) {
 #define TN_GLM_BOI_W "adapter.boi"
 #define TN_GLM_EOI_W "adapter.eoi"
 
+#define TN_JANUS_ATTN_POOL_LATENT "attn_pool_latent"
+#define TN_JANUS_ATTN_POOL_Q "attn_pool_q.%s"
+#define TN_JANUS_ATTN_POOL_K "attn_pool_k.%s"
+#define TN_JANUS_ATTN_POOL_V "attn_pool_v.%s"
+#define TN_JANUS_ATTN_POOL_PROJ "attn_pool_proj.%s"
+#define TN_JANUS_ATTN_POOL_FFN_DOWN "attn_pool_ffn_down.%s"
+#define TN_JANUS_ATTN_POOL_NORM "attn_pool_norm.%s"
+#define TN_JANUS_ATTN_POOL_FFN_UP "attn_pool_ffn_up.%s"
+
 
 enum projector_type {
     PROJECTOR_TYPE_MLP,
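
Note: the "%s" in each TN_JANUS_ATTN_POOL_* template is filled in by the file-local format() helper named in the hunk header, so format(TN_JANUS_ATTN_POOL_Q, "weight") resolves to the GGUF tensor name "attn_pool_q.weight". A standalone sketch of that expansion, using snprintf directly (format() wrapping vsnprintf is an assumption about this file's internals):

#include <cstdio>

#define TN_JANUS_ATTN_POOL_Q "attn_pool_q.%s"

int main() {
    char name[64];
    // expand the template the same way clip.cpp's format() helper does
    std::snprintf(name, sizeof(name), TN_JANUS_ATTN_POOL_Q, "weight");
    std::printf("%s\n", name); // prints: attn_pool_q.weight
    return 0;
}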
@@ -179,6 +189,7 @@ enum projector_type {
     PROJECTOR_TYPE_RESAMPLER,
     PROJECTOR_TYPE_GLM_EDGE,
     PROJECTOR_TYPE_MERGER,
+    PROJECTOR_TYPE_ATTN_POOL,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -189,6 +200,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_RESAMPLER, "resampler"},
     { PROJECTOR_TYPE_GLM_EDGE, "adapter"},
     { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
+    { PROJECTOR_TYPE_ATTN_POOL, "janus_attn_pool"},
 };
 
 
@@ -597,7 +609,7 @@ struct clip_ctx {
     bool has_minicpmv_projector = false;
     bool has_glm_projector = false;
     bool has_qwen2vl_merger = false;
-    bool has_janus_attn_pool_latent = false;
+    bool has_janus_attn_pool = false;
     int minicpmv_version = 2;
 
     struct clip_vision_model vision_model;
@@ -1172,9 +1184,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     }
 
     // janus attn pool with latent query
-    // TODO: Check the ctx0
-    else if (ctx->has_janus_attn_pool_latent) {
-        if (ctx->proj_type == PROJECTOR_TYPE_JANUS) {
+    // TODO: Check the ctx0 and memory usage
+    else if (ctx->has_janus_attn_pool) {
+        if (ctx->proj_type == PROJECTOR_TYPE_ATTN_POOL) {
             struct ggml_tensor * latent = model.attn_pool_latent; // Should be [D, 1, 1]
             struct ggml_tensor * latent_expanded = ggml_repeat(ctx0, latent,
                 ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size)); // [D, 1, B]
@@ -1236,7 +1248,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
                 attn_output,
                 attn_output->ne[0],
                 attn_output->ne[2],
-                attn_output->nb[2]);
+                attn_output->nb[2],
+                0);
         } else {
             GGML_ABORT("fatal error");
         }
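
Note: the pooling built in these two hunks uses one learned latent vector as the attention query over all patch tokens, producing a single pooled embedding per image. Below is a minimal single-head sketch of that idea in ggml; the tensor names (q_w, k_w, v_w) and the single-head simplification are illustrative assumptions, not the exact multi-head graph this commit builds:

#include <cmath>
#include "ggml.h"

// Single-head latent-query attention pooling (illustrative sketch only).
// embeddings: [D, N, B] (hidden size D, N patch tokens, batch B)
// latent:     [D, 1, 1] learned query, broadcast across the batch
static struct ggml_tensor * janus_attn_pool_sketch(
        struct ggml_context * ctx0,
        struct ggml_tensor  * embeddings,
        struct ggml_tensor  * latent,
        struct ggml_tensor  * q_w,
        struct ggml_tensor  * k_w,
        struct ggml_tensor  * v_w,
        int hidden_size, int batch_size) {
    // broadcast the single learned latent across the batch: [D, 1, B]
    struct ggml_tensor * q = ggml_repeat(ctx0, latent,
        ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size));
    q = ggml_mul_mat(ctx0, q_w, q);                               // [D, 1, B]
    struct ggml_tensor * k = ggml_mul_mat(ctx0, k_w, embeddings); // [D, N, B]
    struct ggml_tensor * v = ggml_mul_mat(ctx0, v_w, embeddings); // [D, N, B]
    // one attention score per patch token for the single query: [N, 1, B]
    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
    kq = ggml_soft_max(ctx0, ggml_scale(ctx0, kq, 1.0f / sqrtf((float) hidden_size)));
    // attention-weighted sum of values: one pooled vector per image, [D, 1, B]
    return ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v)), kq);
}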
@@ -1426,6 +1439,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         if (idx != -1) {
             new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx);
         }
+        idx = gguf_find_key(ctx, KEY_HAS_JANUS_ATTN_POOL);
+        if (idx != -1) {
+            new_clip->has_janus_attn_pool = gguf_get_val_bool(ctx, idx);
+        }
         // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
 
         GGML_ASSERT(new_clip->has_vision_encoder);
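
Note: the lookup follows the same two-step pattern as the other capability flags: find the key index, then read the bool only if the key exists, leaving the default false so older GGUF files keep loading. A hypothetical helper (not part of this commit) capturing that pattern:

#include "ggml.h" // gguf_* API lives here in this tree

// Hypothetical helper (not in this commit): read an optional bool key,
// falling back to `def` when the key is absent from the GGUF metadata.
static bool gguf_read_bool(const struct gguf_context * ctx, const char * key, bool def) {
    const int idx = gguf_find_key(ctx, key);
    return idx == -1 ? def : gguf_get_val_bool(ctx, idx);
}

// usage, equivalent to the added lines above:
//   new_clip->has_janus_attn_pool = gguf_read_bool(ctx, KEY_HAS_JANUS_ATTN_POOL, false);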
@@ -1447,6 +1464,8 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
         LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector);
         LOG_INF("%s: glm_projector: %d\n", __func__, new_clip->has_glm_projector);
+        LOG_INF("%s: qwen2vl_merger: %d\n", __func__, new_clip->has_qwen2vl_merger);
+        LOG_INF("%s: janus_attn_pool: %d\n", __func__, new_clip->has_janus_attn_pool);
         LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
         LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
     }
@@ -1732,6 +1751,24 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
         vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
     }
+    else if (new_clip->proj_type == PROJECTOR_TYPE_ATTN_POOL) {
+        vision_model.attn_pool_latent = get_tensor(new_clip->ctx_data, TN_JANUS_ATTN_POOL_LATENT);
+        vision_model.attn_pool_q_w = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_Q, "weight"));
+        vision_model.attn_pool_q_b = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_Q, "bias"));
+        vision_model.attn_pool_k_w = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_K, "weight"));
+        vision_model.attn_pool_k_b = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_K, "bias"));
+        vision_model.attn_pool_v_w = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_V, "weight"));
+        vision_model.attn_pool_v_b = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_V, "bias"));
+        vision_model.attn_pool_proj_w = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_PROJ, "weight"));
+        vision_model.attn_pool_proj_b = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_PROJ, "bias"));
+        vision_model.attn_pool_ffn_down_w = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_FFN_DOWN, "weight"));
+        vision_model.attn_pool_ffn_down_b = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_FFN_DOWN, "bias"));
+        vision_model.attn_pool_norm_w = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_NORM, "weight"));
+        vision_model.attn_pool_norm_b = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_NORM, "bias"));
+        vision_model.attn_pool_ffn_up_w = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_FFN_UP, "weight"));
+        vision_model.attn_pool_ffn_up_b = get_tensor(new_clip->ctx_data, format(TN_JANUS_ATTN_POOL_FFN_UP, "bias"));
+
+    }
     else {
         std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
         throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
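
Note: get_tensor here is the loader's helper, which throws when a name is missing from the GGUF file. A hypothetical pre-flight check (not part of this commit) could probe the same names without throwing via ggml_get_tensor, which returns NULL for unknown names:

#include "ggml.h"

// Hypothetical guard (not in this commit): report whether every Janus
// attn-pool tensor is present in the model's ggml context.
static bool has_janus_attn_pool_tensors(struct ggml_context * ctx_data) {
    static const char * names[] = {
        "attn_pool_latent",
        "attn_pool_q.weight",        "attn_pool_q.bias",
        "attn_pool_k.weight",        "attn_pool_k.bias",
        "attn_pool_v.weight",        "attn_pool_v.bias",
        "attn_pool_proj.weight",     "attn_pool_proj.bias",
        "attn_pool_ffn_down.weight", "attn_pool_ffn_down.bias",
        "attn_pool_norm.weight",     "attn_pool_norm.bias",
        "attn_pool_ffn_up.weight",   "attn_pool_ffn_up.bias",
    };
    for (const char * name : names) {
        // ggml_get_tensor returns NULL instead of throwing for unknown names
        if (ggml_get_tensor(ctx_data, name) == NULL) {
            return false;
        }
    }
    return true;
}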
