Commit edb87de

Author: sangchengmeng
Message: 0401fix
Parent: 339d98e

13 files changed: +67 −54 lines

lightllm/models/internvl/model.py

Lines changed: 19 additions & 9 deletions

@@ -11,6 +11,7 @@
     InternVLLlamaPreAndPostLayerWeight,
     InternVLPhi3PreAndPostLayerWeight,
 )
+from lightllm.server.core.objs import SamplingParams
 from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import InternVLInternlm2PreAndPostLayerWeight
 from lightllm.models.llava.llava_visual import LlavaVisionModel
 
@@ -40,20 +41,29 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
         self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
         self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"])
 
-    def init_imageItem_extral_params(self, img: ImageItem, num_images):
-        if img.extra_params["image_patch_max_num"] > 0:
+    def init_imageItem_extral_params(self, img: ImageItem, multi_params: MultimodalParams, image_max_patch_num: int):
+        if image_max_patch_num >= 0:
+            img.extra_params["image_patch_max_num"] = image_max_patch_num
             return
-        if num_images == 1:
-            img.extra_params["image_patch_max_num"] = 12
-        elif num_images > 1 and num_images <= 6:
-            img.extra_params["image_patch_max_num"] = 6
-        elif num_images > 6:
-            img.extra_params["image_patch_max_num"] = 0
+        elif os.getenv("MAX_PATCH_NUM"):
+            img.extra_params["image_patch_max_num"] = int(os.getenv("MAX_PATCH_NUM"))
+            return
+        else:
+            num_images = len(multi_params.images)
+            if num_images == 1:
+                img.extra_params["image_patch_max_num"] = 12
+            elif num_images > 1 and num_images <= 6:
+                img.extra_params["image_patch_max_num"] = 6
+            elif num_images > 6:
+                img.extra_params["image_patch_max_num"] = 0
         return
 
     def get_image_token_length(self, img: ImageItem):
         return (
-            self.get_image_patch_func(img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True) * self.image_length
+            self.get_image_patch_func(
+                img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
+            )
+            * self.image_length
         )
 
     # only change the impl of the encode func:
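The resolution order can be read directly from the hunk above: an explicit per-request cap (image_max_patch_num >= 0) wins, then the MAX_PATCH_NUM environment variable, and only then the image-count heuristic. A minimal standalone sketch of that priority chain; the function name resolve_patch_max is a hypothetical stand-in, not a lightllm API:

import os

def resolve_patch_max(image_max_patch_num: int, num_images: int) -> int:
    # 1. explicit per-request sampling param (-1 means "unset")
    if image_max_patch_num >= 0:
        return image_max_patch_num
    # 2. process-wide override via environment variable
    if os.getenv("MAX_PATCH_NUM"):
        return int(os.getenv("MAX_PATCH_NUM"))
    # 3. heuristic: more images per request -> fewer patches per image
    if num_images == 1:
        return 12
    elif 1 < num_images <= 6:
        return 6
    else:
        return 0

# assuming MAX_PATCH_NUM is unset in the environment:
assert resolve_patch_max(-1, 1) == 12
assert resolve_patch_max(4, 1) == 4  # the request param overrides the heuristic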

lightllm/models/llava/llava_visual.py

Lines changed: 1 addition & 1 deletion

@@ -138,7 +138,7 @@ def encode(self, images: List[ImageItem]):
                 t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
                 img_tensors.append(t)
             else:
-                raise Exception("Unsupport input types: {} for {}".format(type(item), item))
+                raise Exception("Unsupport input types: {} for {}".format(type(img), img))
 
             cur_num = img_tensors[-1].shape[0]
             valid_ids.append([valid_id, valid_id + cur_num])
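This one-line fix matters more than it looks: item is not defined in encode's scope, so the error branch itself would crash with a NameError instead of raising the intended message. The same bug (an undefined url) is fixed in qwen2_visual.py below. A tiny repro of the failure mode, as a hypothetical function rather than lightllm code:

def broken(img):
    # `item` is undefined here: this line raises NameError before the
    # intended Exception can even be constructed.
    raise Exception("Unsupport input types: {} for {}".format(type(item), item))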

lightllm/models/llava/model.py

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ def __init__(self, tokenizer, model_cfg):
 
     def init_imageItem_extral_params(self, img: ImageItem, num_images):
         return
+
     def get_image_token_length(self, img: ImageItem):
         return self.image_length
 

lightllm/models/qwen2_vl/model.py

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ def __init__(self, tokenizer=None, image_processor=None, **kwargs):
 
     def init_imageItem_extral_params(self, img: ImageItem, num_images):
         return
-
+
     def get_image_token_length(self, img: ImageItem):
         width = img.image_w
         height = img.image_h

(The - and + lines are both blank; this is a whitespace-only change, e.g. trailing spaces removed.)

lightllm/models/qwen2_vl/qwen2_visual.py

Lines changed: 2 additions & 2 deletions

@@ -41,7 +41,7 @@
 from transformers import AutoProcessor
 from safetensors import safe_open
 from transformers.utils import TensorType
-from lightllm.server.multimodal_params import MultimodalParams,ImageItem
+from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor
 
 
@@ -445,7 +445,7 @@ def encode(self, images: List[ImageItem]):
                 img_tensors.append(pixel_values)
                 img_grids.append(image_grid_thw)
             else:
-                raise Exception("Unsupport input types: {} for {}".format(type(url), url))
+                raise Exception("Unsupport input types: {} for {}".format(type(img), img))
 
             # must devide merge_length
             cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2)

lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 1 addition & 0 deletions

@@ -204,6 +204,7 @@ def flash_attention_v3_fwd(
 
 except ImportError:
     print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.")
+    _flash_attn_v3_available = False
 
 
 def flash_attention_fwd(q, k, v, o):
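The added line completes a standard optional-dependency guard: without it, a failed import would leave _flash_attn_v3_available undefined (or stuck at a stale value) and any caller branching on it would crash. A minimal sketch of the pattern, assuming (not shown in this diff) that the try block sets the flag to True on success; flash_attention_v3_fwd and the Triton fallback name are stand-ins for functions defined elsewhere in the file:

try:
    from hopper.flash_attn_interface import _flash_attn_forward  # optional Hopper dependency
    _flash_attn_v3_available = True  # assumed success path; not shown in the diff
except ImportError:
    print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.")
    _flash_attn_v3_available = False  # the line this commit adds

def flash_attention_fwd(q, k, v, o):
    if _flash_attn_v3_available:
        # Hopper (FA3) fast path, only valid when the import succeeded
        return flash_attention_v3_fwd(q, k, v, o)
    # otherwise fall back to the Triton kernel implementation
    return triton_flash_attention_fwd(q, k, v, o)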

lightllm/server/core/objs/py_sampling_params.py

Lines changed: 3 additions & 0 deletions

@@ -37,6 +37,7 @@ def __init__(
         top_p: float = None,
         top_k: int = None,  # -1 is for all
         ignore_eos: bool = False,
+        image_max_patch_num: int = -1,
         max_new_tokens: int = 16,
         min_new_tokens: int = 1,
         stop_sequences: Optional[Union[str, List[str], List[List[int]]]] = None,  # stop-sequence condition
@@ -75,6 +76,7 @@ def __init__(
         self.top_p = top_p if top_p is not None else SamplingParams._top_p
         self.top_k = top_k if top_k is not None else SamplingParams._top_k
         self.ignore_eos = ignore_eos
+        self.image_max_patch_num = image_max_patch_num
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
         self.stop_sequences = stop_sequences if stop_sequences is not None else SamplingParams._stop_sequences
@@ -254,6 +256,7 @@ def to_dict(self):
         ret["temperature"] = self.temperature
         ret["top_p"] = self.top_p
         ret["top_k"] = self.top_k
+        ret["image_max_patch_num"] = self.image_max_patch_num
         ret["min_new_tokens"] = self.min_new_tokens
         ret["ignore_eos"] = self.ignore_eos
         ret["max_new_tokens"] = self.max_new_tokens

lightllm/server/core/objs/sampling_params.py

Lines changed: 3 additions & 0 deletions

@@ -249,6 +249,7 @@ class SamplingParams(ctypes.Structure):
         ("top_p", ctypes.c_float),
         ("top_k", ctypes.c_int),
         ("ignore_eos", ctypes.c_bool),
+        ("image_max_patch_num", ctypes.c_int),
         ("max_new_tokens", ctypes.c_int),
         ("min_new_tokens", ctypes.c_int),
         # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
@@ -294,6 +295,7 @@ def init(self, tokenizer, **kwargs):
         self.top_p = kwargs.get("top_p", SamplingParams._top_p)
         self.top_k = kwargs.get("top_k", SamplingParams._top_k)
         self.ignore_eos = kwargs.get("ignore_eos", False)
+        self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
         self.max_new_tokens = kwargs.get("max_new_tokens", 16)
         self.min_new_tokens = kwargs.get("min_new_tokens", 1)
         self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
@@ -424,6 +426,7 @@ def to_dict(self):
             "top_p": self.top_p,
             "top_k": self.top_k,
             "ignore_eos": self.ignore_eos,
+            "image_max_patch_num": self.image_max_patch_num,
             "max_new_tokens": self.max_new_tokens,
             "min_new_tokens": self.min_new_tokens,
             "exponential_decay_length_penalty": self.exponential_decay_length_penalty.to_tuple(),

lightllm/server/embed_cache/impl/naive_memory_cache.py

Lines changed: 1 addition & 1 deletion

@@ -109,4 +109,4 @@ def set_item_embed(self, id: int) -> None:
         self._records[id].embed = True
 
     def get_item_embed(self, id: int) -> bool:
-        return self._records[id].embed
+        return self._records[id].embed

(The deleted and added lines are textually identical; the change is most likely adding a missing newline at end of file.)

lightllm/server/httpserver/manager.py

Lines changed: 14 additions & 15 deletions

@@ -108,8 +108,9 @@ def __init__(
         return
 
     # connect cache server, calculate md5, alloc resource, return uuid
-    async def _alloc_resource(self, img:ImageItem, num_tokens):
+    async def _alloc_resource(self, img: ImageItem):
         data = img.read()
+        num_tokens = self.tokenizer.get_image_token_length(img)
         md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params)))
         wait_time = 1
         while True:
@@ -126,16 +127,12 @@ async def _alloc_resource(self, img:ImageItem, num_tokens):
             await asyncio.sleep(wait_time)
             wait_time = min(wait_time + 2, 9)
 
-    async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams):
+    async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, image_max_patch_num):
         # only P and NORMAL nodes really need to manage multimodal resources
         if self.pd_mode.is_P_or_NORMAL():
-            num_images = len(multimodal_params.images)
             for img in multimodal_params.images:
-                self.tokenizer.init_imageItem_extral_params(img, num_images)
-                num_tokens = self.tokenizer.get_image_token_length(img)
-                record = await self._alloc_resource(
-                    img, num_tokens
-                )
+                self.tokenizer.init_imageItem_extral_params(img, multimodal_params, image_max_patch_num)
+                record = await self._alloc_resource(img)
                 img.uuid = record["id"]
                 img.token_id = record["token_id"]
                 img.token_num = record["token_num"]
@@ -234,9 +231,7 @@ async def generate(
         await self._log_req_header(request_headers, group_request_id)
         # monitoring
 
-        prompt_ids = await self._encode(
-            prompt, multimodal_params, add_special_tokens=sampling_params.add_special_tokens
-        )
+        prompt_ids = await self._encode(prompt, multimodal_params, sampling_params)
         prompt_tokens = len(prompt_ids)
         # monitoring
         if group_request_id > 0:
@@ -307,15 +302,19 @@ async def _log_req_header(self, request_headers, group_request_id: int):
         return
 
     async def _encode(
-        self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams, add_special_tokens: bool
+        self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams, sampling_params: SamplingParams
     ):
         if isinstance(prompt, str):
            if self.enable_multimodal:
                 assert len(multimodal_params.images) <= self.args.cache_capacity, "too many images!"
-                await self._alloc_multimodal_resources(multimodal_params)
-                prompt_ids = self.tokenizer.encode(prompt, multimodal_params, add_special_tokens=add_special_tokens)
+                await self._alloc_multimodal_resources(
+                    multimodal_params, image_max_patch_num=sampling_params.image_max_patch_num
+                )
+                prompt_ids = self.tokenizer.encode(
+                    prompt, multimodal_params, add_special_tokens=sampling_params.add_special_tokens
+                )
             else:
-                prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+                prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=sampling_params.add_special_tokens)
         return prompt_ids
 
     # the validation here is not very thorough for multimodal, to do
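Taken together, the manager changes reroute the data flow: _encode now receives the whole SamplingParams object instead of a bare add_special_tokens flag, so the per-request patch cap can ride along, and token counting moves inside _alloc_resource. A condensed view of the reworked call chain, with method bodies elided and indentation showing caller to callee:

generate(...)
  └─ _encode(prompt, multimodal_params, sampling_params)
       └─ _alloc_multimodal_resources(multimodal_params,
              image_max_patch_num=sampling_params.image_max_patch_num)
            └─ for each img:
                   tokenizer.init_imageItem_extral_params(img, multimodal_params, image_max_patch_num)
                   _alloc_resource(img)  # now computes num_tokens itself via
                                         # tokenizer.get_image_token_length(img)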
