Skip to content

Commit cfc78c8

Browse files
committed
handle window attention inputs
1 parent 6a8bae0 commit cfc78c8

File tree

1 file changed

+65
-1
lines changed

1 file changed

+65
-1
lines changed

examples/llava/clip.cpp

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ static std::string format(const char * fmt, ...) {
110110
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
111111
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
112112
#define KEY_FULLATTN_BLK_IDX "clip.vision.fullatt_block_indexes"
113+
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
113114

114115

115116
//
@@ -438,6 +439,7 @@ struct clip_hparams {
438439
std::vector<int32_t> image_grid_pinpoints;
439440
int32_t image_crop_resolution;
440441
std::unordered_set<int32_t> vision_feature_layer;
442+
int32_t attn_window_size;
441443
std::vector<int32_t> full_attn_layers;
442444
};
443445

@@ -1784,8 +1786,11 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
17841786
auto n_full_attn_layers = gguf_get_arr_n(ctx, idx_full_attn_layers);
17851787
const int * full_attn_layers = (const int *)gguf_get_arr_data(ctx, idx_full_attn_layers);
17861788
hparams.full_attn_layers.assign(full_attn_layers, full_attn_layers + n_full_attn_layers);
1787-
} catch (std::runtime_error & /*e*/) {
17881789

1790+
int idx_window_size = get_key_idx(ctx, KEY_ATTN_WINDOW_SIZE);
1791+
hparams.attn_window_size = gguf_get_val_u32(ctx, idx_window_size);
1792+
} catch (std::runtime_error & /*e*/) {
1793+
hparams.attn_window_size = 0;
17891794
}
17901795

17911796
for (int i = 0; i < 3; ++i) {
@@ -2960,6 +2965,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
29602965
ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
29612966
free(data);
29622967
}
2968+
29632969
if (ctx->has_minicpmv_projector) {
29642970
{
29652971
// inspired from siglip:
@@ -3080,6 +3086,64 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
30803086
}
30813087
}
30823088

3089+
if (hparams.attn_window_size > 0 && ctx->has_qwen2vl_merger) { // TODO: add use_window_attn?
3090+
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
3091+
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
3092+
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
3093+
3094+
const int merge_ratio = 2;
3095+
const int pw = image_size_width / patch_size / merge_ratio;
3096+
const int ph = image_size_height / patch_size / merge_ratio;
3097+
const int grid_window = hparams.attn_window_size / hparams.patch_size / merge_ratio;
3098+
const int ipw = image_size_width / patch_size;
3099+
const int iph = image_size_height / patch_size;
3100+
/*
3101+
pw * ph = number of tokens output by the ViT after applying the patch merger
3102+
ipw * iph = number of vision tokens processed inside the ViT
3103+
*/
3104+
3105+
std::vector<int> idx(ph * pw);
3106+
std::vector<int> inv_idx(ph * pw);
3107+
int dst = 0;
3108+
// [num_vision_tokens, num_vision_tokens] attention mask tensor
3109+
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
3110+
int mask_row = 0;
3111+
3112+
for (int y = 0; y < ph; y+=grid_window)
3113+
{
3114+
for (int x = 0; x < pw; x+=grid_window)
3115+
{
3116+
const int win_h = std::min(grid_window, ph - y);
3117+
const int win_w = std::min(grid_window, pw - x);
3118+
const int dst_0 = dst;
3119+
// group all tokens belonging to the same window together (into a contiguous range)
3120+
for (int dy = 0; dy < win_h; dy++) {
3121+
for (int dx = 0; dx < win_w; dx++) {
3122+
const int src = (y + dy) * pw + (x + dx);
3123+
assert(src < (int)idx.size());
3124+
assert(dst < (int)inv_idx.size());
3125+
idx[src] = dst;
3126+
inv_idx[dst] = src;
3127+
dst++;
3128+
}
3129+
}
3130+
3131+
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
3132+
int row_offset = mask_row * (ipw * iph);
3133+
std::fill(
3134+
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
3135+
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
3136+
0.0);
3137+
mask_row++;
3138+
}
3139+
}
3140+
}
3141+
3142+
ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
3143+
ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
3144+
ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
3145+
}
3146+
30833147
ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
30843148

30853149
auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);

0 commit comments

Comments
 (0)