Commit 77b144a

replace KEY_FULLATTN_BLK_IDX with KEY_WIN_ATTN_PATTERN

1 parent f69e9fa

File tree: 3 files changed, +27 -12 lines changed

examples/llava/clip-impl.h
Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@
 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
 #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
-#define KEY_FULLATTN_BLK_IDX "clip.vision.fullatt_block_indexes"
+#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
 #define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"

examples/llava/clip.cpp
Lines changed: 10 additions & 9 deletions

@@ -171,7 +171,7 @@ struct clip_hparams {
     int32_t image_crop_resolution;
     std::unordered_set<int32_t> vision_feature_layer;
     int32_t attn_window_size;
-    std::vector<int32_t> full_attn_layers;
+    int32_t n_wa_pattern;
 };

 struct clip_layer {

@@ -799,7 +799,8 @@ static ggml_cgraph * clip_image_build_graph_qwen2_5_vl(clip_ctx * ctx, const cli
     const int n_head = hparams.n_head;
     const int d_head = hidden_size / n_head;
     const float eps = hparams.eps;
-    const bool use_window_attn = hparams.full_attn_layers.size() > 0;
+    const int n_wa_pattern = hparams.n_wa_pattern;
+    const bool use_window_attn = hparams.n_wa_pattern > 0;
     int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};

     const int batch_size = imgs.entries.size();

@@ -926,8 +927,7 @@ static ggml_cgraph * clip_image_build_graph_qwen2_5_vl(clip_ctx * ctx, const cli
         V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);

         struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-        const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end();
-        const bool full_attn = use_window_attn ? inlist : true;
+        const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
         if (full_attn) {
             KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
         } else {
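
A minimal sketch of the schedule the new condition produces, assuming an illustrative 32-layer vision tower with n_wa_pattern = 8 (example values, not read from this commit): every n_wa_pattern-th layer runs full attention and all other layers run window attention, which is the same information the old fullatt_block_indexes array spelled out index by index.

# Illustrative sketch only: recompute the full/window-attention schedule that
# (il + 1) % n_wa_pattern == 0 implies, using assumed example values.
n_layers = 32        # assumed depth of the vision tower
n_wa_pattern = 8     # assumed value of clip.vision.n_wa_pattern

full_attn_layers = [il for il in range(n_layers) if (il + 1) % n_wa_pattern == 0]
window_attn_layers = [il for il in range(n_layers) if (il + 1) % n_wa_pattern != 0]

print(full_attn_layers)         # [7, 15, 23, 31] -- what fullatt_block_indexes used to list
print(len(window_attn_layers))  # 28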

@@ -1721,8 +1721,8 @@ struct clip_model_loader {
         get_u32(KEY_PATCH_SIZE, hparams.patch_size);
         get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
         get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, false);
+        get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern, false);
         get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
-        get_arr_int(KEY_FULLATTN_BLK_IDX, hparams.full_attn_layers, false);

         {
             std::string mm_patch_merge_type;

@@ -3074,6 +3074,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     bool support_dynamic_size = ctx->has_minicpmv_projector
         || ctx->has_qwen2vl_merger
         || ctx->proj_type == PROJECTOR_TYPE_PIXTRAL;
+    const bool use_window_attn = hparams.n_wa_pattern > 0;

     const int image_size = hparams.image_size;
     int image_size_width = image_size;

@@ -3335,7 +3336,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         }
     }

-    if (hparams.attn_window_size > 0 && ctx->proj_type == PROJECTOR_TYPE_QWEN2_5_VL) {
+    if (use_window_attn && ctx->proj_type == PROJECTOR_TYPE_QWEN2_5_VL) {
         struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
         struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
         struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");

@@ -3388,9 +3389,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             }
         }

-        if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
-        if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
-        if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
+        ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
+        ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
+        ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
     }

     ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);

examples/llava/qwen2_vl_surgery.py
Lines changed: 16 additions & 2 deletions

@@ -1,5 +1,5 @@
 import argparse
-from typing import Dict
+from typing import Dict, List, Optional

 import torch
 import numpy as np

@@ -20,6 +20,20 @@
 def k(raw_key: str, arch: str) -> str:
     return raw_key.format(arch=arch)

+
+def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]):
+    if fullatt_block_indexes is None:
+        return 0
+    n_wa = fullatt_block_indexes[0]
+    for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
+        if b - a - 1 != n_wa:
+            raise ValueError(
+                f"window/full attention layer should have fix pattern of "
+                f"for each full-attention layer followed by {n_wa} window-attention layers"
+            )
+    return n_wa + 1
+
+
 class VL2:

     @staticmethod
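
A quick sanity check of get_n_wa_pattern with made-up inputs (illustrative only, not part of the commit): an evenly spaced index list collapses to its period, a missing config falls back to 0, and irregular spacing is rejected.

# Illustrative usage of get_n_wa_pattern; the index lists below are assumed examples.
print(get_n_wa_pattern([7, 15, 23, 31]))  # 8 -> 7 window-attention layers, then 1 full-attention layer
print(get_n_wa_pattern(None))             # 0 -> window attention disabled

try:
    get_n_wa_pattern([7, 15, 24, 31])     # spacing is not constant
except ValueError as err:
    print("rejected:", err)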

@@ -152,7 +166,7 @@ def main(args):
         raise ValueError()

     if args.model_type == "qwen2.5vl":
-        fout.add_array("clip.vision.fullatt_block_indexes", vcfg.fullatt_block_indexes)
+        fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
         fout.add_uint32("clip.vision.window_size", vcfg.window_size)
         fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
         fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
