|
11 | 11 | InternVLLlamaPreAndPostLayerWeight, |
12 | 12 | InternVLPhi3PreAndPostLayerWeight, |
13 | 13 | ) |
| 14 | +from lightllm.server.core.objs import SamplingParams |
14 | 15 | from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import InternVLInternlm2PreAndPostLayerWeight |
15 | 16 | from lightllm.models.llava.llava_visual import LlavaVisionModel |
16 | 17 |
|
@@ -40,20 +41,29 @@ def __init__(self, tokenizer, model_cfg, **kwargs): |
40 | 41 | self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag) |
41 | 42 | self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"]) |
42 | 43 |
|
43 | | - def init_imageItem_extral_params(self, img: ImageItem, num_images): |
44 | | - if img.extra_params["image_patch_max_num"] > 0: |
| 44 | + def init_imageItem_extral_params(self, img: ImageItem, multi_params: MultimodalParams, image_max_patch_num: int): |
| 45 | + if image_max_patch_num >= 0: |
| 46 | + img.extra_params["image_patch_max_num"] = image_max_patch_num |
45 | 47 | return |
46 | | - if num_images == 1: |
47 | | - img.extra_params["image_patch_max_num"] = 12 |
48 | | - elif num_images > 1 and num_images <= 6: |
49 | | - img.extra_params["image_patch_max_num"] = 6 |
50 | | - elif num_images > 6: |
51 | | - img.extra_params["image_patch_max_num"] = 0 |
| 48 | + elif os.getenv("MAX_PATCH_NUM"): |
| 49 | + img.extra_params["image_patch_max_num"] = int(os.getenv("MAX_PATCH_NUM")) |
| 50 | + return |
| 51 | + else: |
| 52 | + num_images = len(multi_params.images) |
| 53 | + if num_images == 1: |
| 54 | + img.extra_params["image_patch_max_num"] = 12 |
| 55 | + elif num_images > 1 and num_images <= 6: |
| 56 | + img.extra_params["image_patch_max_num"] = 6 |
| 57 | + elif num_images > 6: |
| 58 | + img.extra_params["image_patch_max_num"] = 0 |
52 | 59 | return |
53 | 60 |
|
54 | 61 | def get_image_token_length(self, img: ImageItem): |
55 | 62 | return ( |
56 | | - self.get_image_patch_func(img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True) * self.image_length |
| 63 | + self.get_image_patch_func( |
| 64 | + img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True |
| 65 | + ) |
| 66 | + * self.image_length |
57 | 67 | ) |
58 | 68 |
|
59 | 69 | # only change the impl of the encode func: |
|
0 commit comments