
Commit 6a8bae0

implement vision model architecture, GGUF converter
1 parent: f4c3dd5

2 files changed: +245 −86 lines


examples/llava/clip.cpp

Lines changed: 141 additions & 22 deletions
@@ -89,6 +89,8 @@ static std::string format(const char * fmt, ...) {
 #define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
 #define KEY_USE_GELU "clip.use_gelu"
 #define KEY_USE_SILU "clip.use_silu"
+#define KEY_USE_GLU_MLP "clip.use_glu_mlp"
+#define KEY_USE_RMS_NORM "clip.use_rms_norm"
 #define KEY_N_EMBD "clip.%s.embedding_length"
 #define KEY_N_FF "clip.%s.feed_forward_length"
 #define KEY_N_BLOCK "clip.%s.block_count"
@@ -107,6 +109,7 @@ static std::string format(const char * fmt, ...) {
 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
 #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
+#define KEY_FULLATTN_BLK_IDX "clip.vision.fullatt_block_indexes"


 //
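Note: the new keys have a counterpart on the conversion side. The GGUF convertor (the second file in this commit, not shown in this diff) must write them so that `clip_init()` below can read them back. A minimal sketch using the GGUF C API; the function name, the include choice, and the example block indexes are assumptions for illustration, not code from this commit (the actual convertor may well be a Python script).

```cpp
#include <cstdint>

#include "ggml.h" // gguf_* API (newer trees declare it in "gguf.h")

// Illustrative only: emit the new vision-model flags into a GGUF being written.
static void write_vision_flags(struct gguf_context * gctx) {
    gguf_set_val_bool(gctx, "clip.use_glu_mlp",  true);
    gguf_set_val_bool(gctx, "clip.use_rms_norm", true);

    // hypothetical indexes of blocks that keep full (non-windowed) attention
    const int32_t full_attn_blocks[] = { 7, 15, 23, 31 };
    gguf_set_arr_data(gctx, "clip.vision.fullatt_block_indexes",
                      GGUF_TYPE_INT32, full_attn_blocks, 4);
}
```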
@@ -125,6 +128,7 @@ static std::string format(const char * fmt, ...) {
 #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
 #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
 #define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
+#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
 #define TN_LN_1 "%s.blk.%d.ln1.%s"
 #define TN_LN_2 "%s.blk.%d.ln2.%s"
 #define TN_LN_PRE "%s.pre_ln.%s"
@@ -434,6 +438,7 @@ struct clip_hparams {
     std::vector<int32_t> image_grid_pinpoints;
     int32_t image_crop_resolution;
     std::unordered_set<int32_t> vision_feature_layer;
+    std::vector<int32_t> full_attn_layers;
 };

 struct clip_layer {
@@ -459,6 +464,9 @@ struct clip_layer {
     struct ggml_tensor * ff_o_w;
     struct ggml_tensor * ff_o_b;

+    struct ggml_tensor * ff_g_w = NULL;
+    struct ggml_tensor * ff_g_b = NULL;
+
     // layernorm 2
     struct ggml_tensor * ln_2_w;
     struct ggml_tensor * ln_2_b;
@@ -582,6 +590,8 @@ struct clip_ctx {
     float image_std[3];
     bool use_gelu = false;
     bool use_silu = false;
+    bool use_glu_mlp = false;
+    bool use_rms_norm = false;
     int32_t ftype = 1;

     bool has_class_embedding = true;
@@ -833,6 +843,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     const int n_head = hparams.n_head;
     const int d_head = hidden_size / n_head;
     const float eps = hparams.eps;
+    const bool use_window_attn = hparams.full_attn_layers.size() > 0;
     int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};

     const int batch_size = imgs->size;
@@ -883,8 +894,10 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
         inp = ggml_add(ctx0, inp, model.patch_bias);
     }
-    struct ggml_tensor * embeddings = inp;
-    struct ggml_tensor * pos_embed = nullptr;
+    struct ggml_tensor * embeddings = inp;
+    struct ggml_tensor * pos_embed = nullptr;
+    struct ggml_tensor * window_mask = nullptr;
+    struct ggml_tensor * window_idx = nullptr;

     if (ctx->has_llava_projector) {
         // concat class_embeddings and patch_embeddings
@@ -936,6 +949,28 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     const auto & vision_feature_layer = hparams.vision_feature_layer;

     // loop over layers
+
+    if (use_window_attn) {
+        window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
+        ggml_set_name(window_idx, "window_idx");
+        ggml_set_input(window_idx);
+
+        // mask for window attention
+        window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions);
+        ggml_set_name(window_mask, "window_mask");
+        ggml_set_input(window_mask);
+
+        // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+        GGML_ASSERT(batch_size == 1);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
+        embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
+
+        positions = ggml_reshape_2d(ctx0, positions, 16, num_position_ids / 4 / 4);
+        positions = ggml_get_rows(ctx0, positions, window_idx);
+        positions = ggml_reshape_1d(ctx0, positions, num_position_ids);
+    }
+
     for (int il = 0; il < ctx->max_feature_layer; il++) {
         struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
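Note: `window_idx` (and `inv_window_idx` further down) are declared as graph inputs, so the permutation itself is computed by the caller on the CPU. Below is a minimal sketch of how such an ordering and its inverse could be built for a grid of merged patches (the `/ 4` above corresponds to 2×2 patch merging); the grid/window parameters and names are illustrative assumptions, not code from this commit.

```cpp
#include <cstdint>
#include <vector>

// Group a grid of merged patches (grid_h x grid_w) into square windows of side
// `win` merged patches, and return the permutation plus its inverse.
//   idx[k] = original row-major position of the k-th element in window order
//   inv[p] = position in window order of original element p
static void build_window_order(int grid_h, int grid_w, int win,
                               std::vector<int32_t> & idx,
                               std::vector<int32_t> & inv) {
    idx.clear();
    idx.reserve((size_t) grid_h * grid_w);
    for (int wy = 0; wy < grid_h; wy += win) {
        for (int wx = 0; wx < grid_w; wx += win) {
            for (int y = wy; y < wy + win && y < grid_h; ++y) {
                for (int x = wx; x < wx + win && x < grid_w; ++x) {
                    idx.push_back(y * grid_w + x); // position in the original grid
                }
            }
        }
    }
    inv.assign(idx.size(), 0);
    for (int32_t k = 0; k < (int32_t) idx.size(); ++k) {
        inv[idx[k]] = k; // undoes the shuffle after the projector
    }
}
```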
@@ -948,9 +983,12 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         //const size_t nb_q_w = model.layers[il].q_w->nb[0];

         // layernorm1
-        {
+        if (ctx->use_rms_norm) {
+            cur = ggml_rms_norm(ctx0, cur, eps);
+            cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w);
+        }
+        else {
             cur = ggml_norm(ctx0, cur, eps);
-
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
                            model.layers[il].ln_1_b);
         }
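For reference, the numerical difference between the two branches above: RMSNorm divides by the root mean square and applies only a weight, while LayerNorm also subtracts the mean and adds a bias, which is why the `ln_*_b` tensors are skipped when `use_rms_norm` is set (see the tensor-loading hunk at the end). A standalone C++ sketch of both, independent of ggml:

```cpp
#include <cmath>
#include <vector>

// LayerNorm: y = (x - mean) / sqrt(var + eps) * w + b
static std::vector<float> layer_norm(const std::vector<float> & x,
                                     const std::vector<float> & w,
                                     const std::vector<float> & b, float eps) {
    float mean = 0.0f;
    for (float v : x) mean += v;
    mean /= x.size();
    float var = 0.0f;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= x.size();
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        y[i] = (x[i] - mean) / std::sqrt(var + eps) * w[i] + b[i];
    }
    return y;
}

// RMSNorm: y = x / sqrt(mean(x^2) + eps) * w   (no mean subtraction, no bias)
static std::vector<float> rms_norm(const std::vector<float> & x,
                                   const std::vector<float> & w, float eps) {
    float ms = 0.0f;
    for (float v : x) ms += v * v;
    ms /= x.size();
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        y[i] = x[i] / std::sqrt(ms + eps) * w[i];
    }
    return y;
}
```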
@@ -991,7 +1029,14 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);

         struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-        KQ = ggml_soft_max_inplace(ctx0, KQ);
+        const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end();
+        const bool full_attn = use_window_attn ? inlist : true;
+        if (full_attn) {
+            KQ = ggml_soft_max_inplace(ctx0, KQ);
+        } else {
+            KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f, 0.0f);
+        }
+
         struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
         KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
         KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
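`ggml_soft_max_ext(KQ, window_mask, 1.0f, 0.0f)` adds the mask to the logits before the softmax, so a `window_mask` holding 0.0f for patch pairs in the same window and -INFINITY elsewhere confines attention to each window, while the blocks listed in `fullatt_block_indexes` keep the plain softmax. A scalar sketch of that masking behaviour (plain C++, not ggml):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// softmax over one row of logits after applying: logits = scale * logits + mask
// mask[j] is 0.0f where attention is allowed and -INFINITY where it is blocked.
static std::vector<float> masked_softmax(std::vector<float> logits,
                                         const std::vector<float> & mask,
                                         float scale) {
    float max_v = -INFINITY;
    for (size_t j = 0; j < logits.size(); ++j) {
        logits[j] = scale * logits[j] + mask[j];
        max_v = std::max(max_v, logits[j]);
    }
    float sum = 0.0f;
    for (float & v : logits) {
        v = std::exp(v - max_v); // masked entries contribute exp(-inf) = 0
        sum += v;
    }
    for (float & v : logits) v /= sum;
    return logits;
}
```

In window order the mask is block-diagonal: entry (i, j) is 0.0f when patches i and j fall in the same window and -INFINITY otherwise.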
@@ -1008,25 +1053,50 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         embeddings = cur; // embeddings = residual, cur = hidden_states

         // layernorm2
-        {
+        if (ctx->use_rms_norm) {
+            cur = ggml_rms_norm(ctx0, cur, eps);
+            cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w);
+        } else {
             cur = ggml_norm(ctx0, cur, eps);
-
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
         }

-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+        // mlp
+        if (ctx->use_glu_mlp) {
+            // ffn_up
+            auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+            cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);

-        if (ctx->use_gelu) {
-            cur = ggml_gelu_inplace(ctx0, cur);
-        } else if (ctx->use_silu) {
-            cur = ggml_silu_inplace(ctx0, cur);
-        } else {
-            cur = ggml_gelu_quick_inplace(ctx0, cur);
+            auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
+            cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
+            if (ctx->use_gelu) {
+                cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
+            } else if (ctx->use_silu) {
+                cur_gate = ggml_silu_inplace(ctx0, cur_gate);
+            } else {
+                cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate);
+            }
+            cur = ggml_mul(ctx0, cur_gate, cur_up);
+
+            // ffn_down
+            cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+            cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
         }
+        else {
+            cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+            cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);

-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+            if (ctx->use_gelu) {
+                cur = ggml_gelu_inplace(ctx0, cur);
+            } else if (ctx->use_silu) {
+                cur = ggml_silu_inplace(ctx0, cur);
+            } else {
+                cur = ggml_gelu_quick_inplace(ctx0, cur);
+            }
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+            cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+        }

         // residual 2
         cur = ggml_add(ctx0, embeddings, cur);
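In the GLU branch the existing tensors are reused with shifted roles: `ff_o_*` acts as the up projection, the new `ff_g_*` as the gate, and `ff_i_*` as the down projection, computing down(act(gate(x)) ⊙ up(x)). A plain C++ sketch of that dataflow on a single vector (shapes and helper names are assumptions for illustration):

```cpp
#include <cmath>
#include <vector>

// y = W x + b for a row-major W of shape [out][in]
static std::vector<float> linear(const std::vector<std::vector<float>> & W,
                                 const std::vector<float> & b,
                                 const std::vector<float> & x) {
    std::vector<float> y(W.size(), 0.0f);
    for (size_t o = 0; o < W.size(); ++o) {
        for (size_t i = 0; i < x.size(); ++i) y[o] += W[o][i] * x[i];
        y[o] += b[o];
    }
    return y;
}

static float silu(float v) { return v / (1.0f + std::exp(-v)); }

// Gated MLP: down( act(gate(x)) * up(x) ) -- the use_glu_mlp branch above
static std::vector<float> glu_mlp(const std::vector<std::vector<float>> & Wup,   const std::vector<float> & bup,
                                  const std::vector<std::vector<float>> & Wgate, const std::vector<float> & bgate,
                                  const std::vector<std::vector<float>> & Wdown, const std::vector<float> & bdown,
                                  const std::vector<float> & x) {
    std::vector<float> up   = linear(Wup,   bup,   x);
    std::vector<float> gate = linear(Wgate, bgate, x);
    for (size_t i = 0; i < gate.size(); ++i) gate[i] = silu(gate[i]) * up[i];
    return linear(Wdown, bdown, gate);
}
```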
@@ -1036,10 +1106,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     // post-layernorm
     if (ctx->has_post_norm) {
-        embeddings = ggml_norm(ctx0, embeddings, eps);
-        ggml_set_name(embeddings, "post_ln");
+        if (ctx->use_rms_norm) {
+            embeddings = ggml_rms_norm(ctx0, embeddings, eps);
+            ggml_set_name(embeddings, "post_ln");

-        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+            embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
+        } else {
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            ggml_set_name(embeddings, "post_ln");
+
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+        }
     }

     // final layer is a vision feature layer
@@ -1352,6 +1429,18 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
     }

+    if (use_window_attn) {
+        struct ggml_tensor * inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
+        ggml_set_name(inv_window_idx, "inv_window_idx");
+        ggml_set_input(inv_window_idx);
+
+        // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
+        GGML_ASSERT(batch_size == 1);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
+        embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size);
+    }
+
     // build the graph
     ggml_build_forward_expand(gf, embeddings);
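Since `window_idx`, `inv_window_idx` and `window_mask` are graph inputs, the caller has to copy host-side buffers into them before the graph is computed, in the same way other named inputs (e.g. "positions") are filled at encode time. A hedged sketch of such a helper; the function name and the assumption that exactly these tensor names are looked up at encode time are mine, not part of this hunk:

```cpp
#include <cstdint>
#include <vector>

#include "ggml.h"
#include "ggml-backend.h"

// Illustrative helper: copy host-side index/mask buffers into the named graph
// inputs created above, before ggml_backend_graph_compute() is called.
static void set_window_inputs(struct ggml_cgraph * gf,
                              const std::vector<int32_t> & idx,
                              const std::vector<int32_t> & inv_idx,
                              const std::vector<float>   & mask) {
    if (struct ggml_tensor * t = ggml_graph_get_tensor(gf, "window_idx")) {
        ggml_backend_tensor_set(t, idx.data(), 0, idx.size() * sizeof(int32_t));
    }
    if (struct ggml_tensor * t = ggml_graph_get_tensor(gf, "inv_window_idx")) {
        ggml_backend_tensor_set(t, inv_idx.data(), 0, inv_idx.size() * sizeof(int32_t));
    }
    if (struct ggml_tensor * t = ggml_graph_get_tensor(gf, "window_mask")) {
        ggml_backend_tensor_set(t, mask.data(), 0, mask.size() * sizeof(float));
    }
}
```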
@@ -1542,6 +1631,20 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
         new_clip->use_silu = false;
     }

+    try {
+        idx = get_key_idx(ctx, KEY_USE_GLU_MLP);
+        new_clip->use_glu_mlp = gguf_get_val_bool(ctx, idx);
+    } catch (std::runtime_error & /*e*/) {
+        new_clip->use_glu_mlp = false;
+    }
+
+    try {
+        idx = get_key_idx(ctx, KEY_USE_RMS_NORM);
+        new_clip->use_rms_norm = gguf_get_val_bool(ctx, idx);
+    } catch (std::runtime_error & /*e*/) {
+        new_clip->use_rms_norm = false;
+    }
+
     if (verbosity >= 1) {
         LOG_INF("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
         LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
@@ -1676,6 +1779,15 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
     const float * mean_data = (const float *)gguf_get_arr_data(ctx, idx_mean);
     const float * std_data = (const float *)gguf_get_arr_data(ctx, idx_std);

+    try {
+        int idx_full_attn_layers = get_key_idx(ctx, KEY_FULLATTN_BLK_IDX);
+        auto n_full_attn_layers = gguf_get_arr_n(ctx, idx_full_attn_layers);
+        const int * full_attn_layers = (const int *)gguf_get_arr_data(ctx, idx_full_attn_layers);
+        hparams.full_attn_layers.assign(full_attn_layers, full_attn_layers + n_full_attn_layers);
+    } catch (std::runtime_error & /*e*/) {
+        // key is optional: models without windowed attention simply omit it
+    }
+
     for (int i = 0; i < 3; ++i) {
         new_clip->image_mean[i] = mean_data[i];
         new_clip->image_std[i] = std_data[i];
@@ -1887,10 +1999,17 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
             layer.q_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_Q, "v", il, "bias"));
             layer.v_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_V, "v", il, "bias"));
             layer.o_b = get_tensor(new_clip->ctx_data, format(TN_ATTN_OUTPUT, "v", il, "bias"));
-            layer.ln_1_b = get_tensor(new_clip->ctx_data, format(TN_LN_1, "v", il, "bias"));
-            layer.ln_2_b = get_tensor(new_clip->ctx_data, format(TN_LN_2, "v", il, "bias"));
             layer.ff_i_b = get_tensor(new_clip->ctx_data, format(TN_FFN_DOWN, "v", il, "bias"));
             layer.ff_o_b = get_tensor(new_clip->ctx_data, format(TN_FFN_UP, "v", il, "bias"));
+
+            if (!new_clip->use_rms_norm) {
+                layer.ln_1_b = get_tensor(new_clip->ctx_data, format(TN_LN_1, "v", il, "bias"));
+                layer.ln_2_b = get_tensor(new_clip->ctx_data, format(TN_LN_2, "v", il, "bias"));
+            }
+            if (new_clip->use_glu_mlp) {
+                layer.ff_g_w = get_tensor(new_clip->ctx_data, format(TN_FFN_GATE, "v", il, "weight"));
+                layer.ff_g_b = get_tensor(new_clip->ctx_data, format(TN_FFN_GATE, "v", il, "bias"));
+            }
         }
     }