
Commit 01b3f68

Author: sangchengmeng
Commit message: 0401fix
1 parent 339d98e · commit 01b3f68

File tree: 16 files changed, +85 −58 lines

lightllm/models/internvl/model.py

Lines changed: 19 additions & 9 deletions

@@ -11,6 +11,7 @@
     InternVLLlamaPreAndPostLayerWeight,
     InternVLPhi3PreAndPostLayerWeight,
 )
+from lightllm.server.core.objs import SamplingParams
 from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import InternVLInternlm2PreAndPostLayerWeight
 from lightllm.models.llava.llava_visual import LlavaVisionModel
 
@@ -40,20 +41,29 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
         self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
         self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"])
 
-    def init_imageItem_extral_params(self, img: ImageItem, num_images):
-        if img.extra_params["image_patch_max_num"] > 0:
+    def init_imageItem_extral_params(self, img: ImageItem, multi_params: MultimodalParams, image_max_patch_num: int):
+        if image_max_patch_num >= 0:
+            img.extra_params["image_patch_max_num"] = image_max_patch_num
             return
-        if num_images == 1:
-            img.extra_params["image_patch_max_num"] = 12
-        elif num_images > 1 and num_images <= 6:
-            img.extra_params["image_patch_max_num"] = 6
-        elif num_images > 6:
-            img.extra_params["image_patch_max_num"] = 0
+        elif os.getenv("MAX_PATCH_NUM"):
+            img.extra_params["image_patch_max_num"] = int(os.getenv("MAX_PATCH_NUM"))
+            return
+        else:
+            num_images = len(multi_params.images)
+            if num_images == 1:
+                img.extra_params["image_patch_max_num"] = 12
+            elif num_images > 1 and num_images <= 6:
+                img.extra_params["image_patch_max_num"] = 6
+            elif num_images > 6:
+                img.extra_params["image_patch_max_num"] = 0
         return
 
     def get_image_token_length(self, img: ImageItem):
         return (
-            self.get_image_patch_func(img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True) * self.image_length
+            self.get_image_patch_func(
+                img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
+            )
+            * self.image_length
         )
 
     # only change the impl of the encode func:
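The hunk above replaces the old count-based rule with a three-level resolution order for the per-image patch budget. Below is a minimal standalone sketch of that order, assuming nothing beyond what the diff shows; the function name resolve_image_patch_max_num is invented here for illustration and is not part of lightllm.

import os

def resolve_image_patch_max_num(image_max_patch_num: int, num_images: int) -> int:
    # 1) explicit per-request value carried in SamplingParams (new in this commit)
    if image_max_patch_num >= 0:
        return image_max_patch_num
    # 2) process-wide override via the MAX_PATCH_NUM environment variable
    if os.getenv("MAX_PATCH_NUM"):
        return int(os.getenv("MAX_PATCH_NUM"))
    # 3) default keyed on how many images the request carries
    if num_images == 1:
        return 12
    if 1 < num_images <= 6:
        return 6
    return 0  # more than 6 images: no extra patches

os.environ.pop("MAX_PATCH_NUM", None)              # make the fallback path deterministic here
assert resolve_image_patch_max_num(-1, 1) == 12    # no override: single-image default
assert resolve_image_patch_max_num(-1, 4) == 6     # 2-6 images share a smaller budget
assert resolve_image_patch_max_num(4, 8) == 4      # a per-request value always wins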

lightllm/models/llava/llava_visual.py

Lines changed: 1 addition & 1 deletion

@@ -138,7 +138,7 @@ def encode(self, images: List[ImageItem]):
                 t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
                 img_tensors.append(t)
             else:
-                raise Exception("Unsupport input types: {} for {}".format(type(item), item))
+                raise Exception("Unsupport input types: {} for {}".format(type(img), img))
 
             cur_num = img_tensors[-1].shape[0]
             valid_ids.append([valid_id, valid_id + cur_num])

lightllm/models/llava/model.py

Lines changed: 2 additions & 1 deletion

@@ -33,8 +33,9 @@ def __init__(self, tokenizer, model_cfg):
         self.image_length = (image_size // patch_size) ** 2
         self.skip_start = model_cfg.get("skip_start", True)
 
-    def init_imageItem_extral_params(self, img: ImageItem, num_images):
+    def init_imageItem_extral_params(self, img: ImageItem, multi_params: MultimodalParams, image_max_patch_num: int):
         return
+
     def get_image_token_length(self, img: ImageItem):
         return self.image_length
 

lightllm/models/qwen2_vl/model.py

Lines changed: 2 additions & 2 deletions

@@ -31,9 +31,9 @@ def __init__(self, tokenizer=None, image_processor=None, **kwargs):
         self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"]
         self.image_token_id = kwargs["model_cfg"]["image_token_id"]
 
-    def init_imageItem_extral_params(self, img: ImageItem, num_images):
+    def init_imageItem_extral_params(self, img: ImageItem, multi_params: MultimodalParams, image_max_patch_num: int):
         return
-
+
     def get_image_token_length(self, img: ImageItem):
         width = img.image_w
         height = img.image_h

lightllm/models/qwen2_vl/qwen2_visual.py

Lines changed: 2 additions & 2 deletions

@@ -41,7 +41,7 @@
 from transformers import AutoProcessor
 from safetensors import safe_open
 from transformers.utils import TensorType
-from lightllm.server.multimodal_params import MultimodalParams,ImageItem
+from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor
 
 
@@ -445,7 +445,7 @@ def encode(self, images: List[ImageItem]):
                 img_tensors.append(pixel_values)
                 img_grids.append(image_grid_thw)
             else:
-                raise Exception("Unsupport input types: {} for {}".format(type(url), url))
+                raise Exception("Unsupport input types: {} for {}".format(type(img), img))
 
             # must devide merge_length
             cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2)

lightllm/models/qwen_vl/model.py

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ def __init__(self, tokenizer, model_cfg):
         # <imgpad>: 151859
         self.image_length = model_cfg["visual"].get("n_queries", 256)
 
-    def init_imageItem_extral_params(self, img: ImageItem, num_images):
+    def init_imageItem_extral_params(self, img: ImageItem, multi_params: MultimodalParams, image_max_patch_num: int):
         return
 
     def _list_find(self, input_list, target, start_idx):
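Every vision tokenizer touched by this commit now exposes the same init_imageItem_extral_params(img, multi_params, image_max_patch_num) signature. The call-site sketch below is hypothetical (the diff does not show the caller); it only illustrates what the new arguments are expected to be.

def prepare_images(vision_tokenizer, multi_params, sampling_params) -> None:
    # vision_tokenizer: any of the wrappers above (InternVL, LLaVA, Qwen-VL, Qwen2-VL);
    # sampling_params.image_max_patch_num defaults to -1, meaning "no per-request override".
    for img in multi_params.images:
        vision_tokenizer.init_imageItem_extral_params(
            img, multi_params, sampling_params.image_max_patch_num
        )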

lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 1 addition & 0 deletions

@@ -204,6 +204,7 @@ def flash_attention_v3_fwd(
 
 except ImportError:
     print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.")
+    _flash_attn_v3_available = False
 
 
 def flash_attention_fwd(q, k, v, o):
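For context, the one-line addition above completes an optional-import guard: the availability flag now gets a value on the failure path too, so code that checks it later never hits an undefined name. A self-contained sketch of the pattern follows; the dispatch helper is illustrative and not lightllm's actual flash_attention_fwd.

try:
    from hopper.flash_attn_interface import _flash_attn_forward  # optional FlashAttention-3 fast path
    _flash_attn_v3_available = True
except ImportError:
    print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.")
    _flash_attn_v3_available = False  # the commit adds this so the flag is defined on both paths

def pick_attention_impl(v3_impl, fallback_impl):
    # Dispatch on the flag once instead of re-trying the import at every call.
    return v3_impl if _flash_attn_v3_available else fallback_impl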

lightllm/server/api_http.py

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@
 from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
 from .api_lightllm import lightllm_get_score, lightllm_pd_generate_stream
 from lightllm.utils.envs_utils import get_env_start_args
-from lightllm.server.embed_cache.utils import image2base64
+from lightllm.utils.image_utils import image2base64
 
 from .api_models import (
     ChatCompletionRequest,

lightllm/server/core/objs/py_sampling_params.py

Lines changed: 3 additions & 0 deletions

@@ -37,6 +37,7 @@ def __init__(
         top_p: float = None,
         top_k: int = None,  # -1 is for all
         ignore_eos: bool = False,
+        image_max_patch_num: int = -1,
         max_new_tokens: int = 16,
         min_new_tokens: int = 1,
         stop_sequences: Optional[Union[str, List[str], List[List[int]]]] = None,  # 停止句子条件
@@ -75,6 +76,7 @@ def __init__(
         self.top_p = top_p if top_p is not None else SamplingParams._top_p
         self.top_k = top_k if top_k is not None else SamplingParams._top_k
         self.ignore_eos = ignore_eos
+        self.image_max_patch_num = image_max_patch_num
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
         self.stop_sequences = stop_sequences if stop_sequences is not None else SamplingParams._stop_sequences
@@ -254,6 +256,7 @@ def to_dict(self):
         ret["temperature"] = self.temperature
         ret["top_p"] = self.top_p
         ret["top_k"] = self.top_k
+        ret["image_max_patch_num"] = self.image_max_patch_num
         ret["min_new_tokens"] = self.min_new_tokens
         ret["ignore_eos"] = self.ignore_eos
         ret["max_new_tokens"] = self.max_new_tokens

lightllm/server/core/objs/sampling_params.py

Lines changed: 3 additions & 0 deletions

@@ -249,6 +249,7 @@ class SamplingParams(ctypes.Structure):
         ("top_p", ctypes.c_float),
         ("top_k", ctypes.c_int),
         ("ignore_eos", ctypes.c_bool),
+        ("image_max_patch_num", ctypes.c_int),
         ("max_new_tokens", ctypes.c_int),
         ("min_new_tokens", ctypes.c_int),
         # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
@@ -294,6 +295,7 @@ def init(self, tokenizer, **kwargs):
         self.top_p = kwargs.get("top_p", SamplingParams._top_p)
         self.top_k = kwargs.get("top_k", SamplingParams._top_k)
         self.ignore_eos = kwargs.get("ignore_eos", False)
+        self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
         self.max_new_tokens = kwargs.get("max_new_tokens", 16)
         self.min_new_tokens = kwargs.get("min_new_tokens", 1)
         self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
@@ -424,6 +426,7 @@ def to_dict(self):
             "top_p": self.top_p,
             "top_k": self.top_k,
             "ignore_eos": self.ignore_eos,
+            "image_max_patch_num": self.image_max_patch_num,
             "max_new_tokens": self.max_new_tokens,
             "min_new_tokens": self.min_new_tokens,
             "exponential_decay_length_penalty": self.exponential_decay_length_penalty.to_tuple(),
