Commit 7f580f1

Author: sangchengmeng
Commit message: 0401fix
1 parent 339d98e commit 7f580f1

Summary: replaces the process-wide image-patch cap (MultimodalParams.max_num / the MAX_PATCH_NUM env var set at request-object construction) with a per-request image_max_patch_num sampling parameter threaded through SamplingParams, passes ImageItem objects instead of uuids through the visual server and re-enables its embed-cache check, and fixes the FlashAttention-3 availability flag on import failure.

7 files changed: +57 −46 lines

File tree:
- lightllm/models/internvl/model.py
- lightllm/models/vit/triton_kernel/flashattention_nopad.py
- lightllm/server/core/objs/py_sampling_params.py
- lightllm/server/core/objs/sampling_params.py
- lightllm/server/httpserver/manager.py
- lightllm/server/multimodal_params.py
- lightllm/server/visualserver/manager.py

lightllm/models/internvl/model.py

Lines changed: 19 additions & 9 deletions

```diff
@@ -11,6 +11,7 @@
     InternVLLlamaPreAndPostLayerWeight,
     InternVLPhi3PreAndPostLayerWeight,
 )
+from lightllm.server.core.objs import SamplingParams
 from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import InternVLInternlm2PreAndPostLayerWeight
 from lightllm.models.llava.llava_visual import LlavaVisionModel
@@ -40,20 +41,29 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
         self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
         self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"])
 
-    def init_imageItem_extral_params(self, img: ImageItem, num_images):
-        if img.extra_params["image_patch_max_num"] > 0:
+    def init_imageItem_extral_params(self, img: ImageItem, multi_params: MultimodalParams, image_max_patch_num: int):
+        if image_max_patch_num >= 0:
+            img.extra_params["image_patch_max_num"] = image_max_patch_num
             return
-        if num_images == 1:
-            img.extra_params["image_patch_max_num"] = 12
-        elif num_images > 1 and num_images <= 6:
-            img.extra_params["image_patch_max_num"] = 6
-        elif num_images > 6:
-            img.extra_params["image_patch_max_num"] = 0
+        elif os.getenv("MAX_PATCH_NUM"):
+            img.extra_params["image_patch_max_num"] = int(os.getenv("MAX_PATCH_NUM"))
+            return
+        else:
+            num_images = len(multi_params.images)
+            if num_images == 1:
+                img.extra_params["image_patch_max_num"] = 12
+            elif num_images > 1 and num_images <= 6:
+                img.extra_params["image_patch_max_num"] = 6
+            elif num_images > 6:
+                img.extra_params["image_patch_max_num"] = 0
         return
 
     def get_image_token_length(self, img: ImageItem):
         return (
-            self.get_image_patch_func(img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True) * self.image_length
+            self.get_image_patch_func(
+                img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
+            )
+            * self.image_length
         )
 
     # only change the impl of the encode func:
```
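Net effect of this hunk: the patch budget is now resolved with a three-level precedence. An explicit per-request image_max_patch_num wins, the MAX_PATCH_NUM environment variable comes second, and the old image-count heuristic is only the fallback. A minimal standalone sketch of that precedence (the helper name is illustrative, and the asserts assume MAX_PATCH_NUM is unset):

```python
import os

# Hypothetical standalone helper mirroring init_imageItem_extral_params;
# precedence: per-request value > MAX_PATCH_NUM env var > image-count heuristic.
def resolve_patch_max_num(image_max_patch_num: int, num_images: int) -> int:
    if image_max_patch_num >= 0:
        return image_max_patch_num       # explicit per-request cap wins
    env_val = os.getenv("MAX_PATCH_NUM")
    if env_val:
        return int(env_val)              # server-wide override comes second
    if num_images == 1:                  # fallback heuristic from the diff:
        return 12                        # a single image gets the full budget
    elif num_images <= 6:
        return 6                         # a few images share a smaller budget
    return 0                             # many images: thumbnail only

assert resolve_patch_max_num(4, 1) == 4    # request-level cap beats the heuristic
assert resolve_patch_max_num(-1, 3) == 6   # -1 means "no per-request cap"
```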

lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -204,6 +204,7 @@ def flash_attention_v3_fwd(
 
 except ImportError:
     print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.")
+    _flash_attn_v3_available = False
 
 
 def flash_attention_fwd(q, k, v, o):
```
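This one-liner completes the FA3 import guard: without it the availability flag keeps its optimistic default after a failed import, and dispatch would still try the missing kernel. A generic sketch of the pattern (the module path is from the diff; the dispatch body is illustrative):

```python
# Probe once at import time whether the FlashAttention-3 hopper kernel exists.
_flash_attn_v3_available = True
try:
    from hopper.flash_attn_interface import _flash_attn_forward  # noqa: F401
except ImportError:
    print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.")
    _flash_attn_v3_available = False  # the reset this commit adds


def flash_attention_fwd(q, k, v, o):
    # Illustrative dispatch: prefer the FA3 kernel, else fall back to Triton.
    if _flash_attn_v3_available:
        ...  # call the FA3 forward
    else:
        ...  # call the Triton nopad fallback
```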

lightllm/server/core/objs/py_sampling_params.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -37,6 +37,7 @@ def __init__(
         top_p: float = None,
         top_k: int = None,  # -1 is for all
         ignore_eos: bool = False,
+        image_max_patch_num: int = -1,
         max_new_tokens: int = 16,
         min_new_tokens: int = 1,
         stop_sequences: Optional[Union[str, List[str], List[List[int]]]] = None,  # stop-sequence conditions
@@ -75,6 +76,7 @@ def __init__(
         self.top_p = top_p if top_p is not None else SamplingParams._top_p
         self.top_k = top_k if top_k is not None else SamplingParams._top_k
         self.ignore_eos = ignore_eos
+        self.image_max_patch_num = image_max_patch_num
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
         self.stop_sequences = stop_sequences if stop_sequences is not None else SamplingParams._stop_sequences
@@ -254,6 +256,7 @@ def to_dict(self):
         ret["temperature"] = self.temperature
         ret["top_p"] = self.top_p
         ret["top_k"] = self.top_k
+        ret["image_max_patch_num"] = self.image_max_patch_num
         ret["min_new_tokens"] = self.min_new_tokens
         ret["ignore_eos"] = self.ignore_eos
         ret["max_new_tokens"] = self.max_new_tokens
```

lightllm/server/core/objs/sampling_params.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -249,6 +249,7 @@ class SamplingParams(ctypes.Structure):
         ("top_p", ctypes.c_float),
         ("top_k", ctypes.c_int),
         ("ignore_eos", ctypes.c_bool),
+        ("image_max_patch_num", ctypes.c_int),
         ("max_new_tokens", ctypes.c_int),
         ("min_new_tokens", ctypes.c_int),
         # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
@@ -294,6 +295,7 @@ def init(self, tokenizer, **kwargs):
         self.top_p = kwargs.get("top_p", SamplingParams._top_p)
         self.top_k = kwargs.get("top_k", SamplingParams._top_k)
         self.ignore_eos = kwargs.get("ignore_eos", False)
+        self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
         self.max_new_tokens = kwargs.get("max_new_tokens", 16)
         self.min_new_tokens = kwargs.get("min_new_tokens", 1)
         self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
@@ -424,6 +426,7 @@ def to_dict(self):
         "top_p": self.top_p,
         "top_k": self.top_k,
         "ignore_eos": self.ignore_eos,
+        "image_max_patch_num": self.image_max_patch_num,
         "max_new_tokens": self.max_new_tokens,
         "min_new_tokens": self.min_new_tokens,
         "exponential_decay_length_penalty": self.exponential_decay_length_penalty.to_tuple(),
```

lightllm/server/httpserver/manager.py

Lines changed: 14 additions & 15 deletions

```diff
@@ -108,8 +108,9 @@ def __init__(
         return
 
     # connect cache server, calculate md5, alloc resource, return uuid
-    async def _alloc_resource(self, img:ImageItem, num_tokens):
+    async def _alloc_resource(self, img: ImageItem):
         data = img.read()
+        num_tokens = self.tokenizer.get_image_token_length(img)
         md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params)))
         wait_time = 1
         while True:
@@ -126,16 +127,12 @@ async def _alloc_resource(self, img: ImageItem):
                 await asyncio.sleep(wait_time)
                 wait_time = min(wait_time + 2, 9)
 
-    async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams):
+    async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, image_max_patch_num):
         # only P and NORMAL nodes actually need to manage multimodal resources
         if self.pd_mode.is_P_or_NORMAL():
-            num_images = len(multimodal_params.images)
             for img in multimodal_params.images:
-                self.tokenizer.init_imageItem_extral_params(img, num_images)
-                num_tokens = self.tokenizer.get_image_token_length(img)
-                record = await self._alloc_resource(
-                    img, num_tokens
-                )
+                self.tokenizer.init_imageItem_extral_params(img, multimodal_params, image_max_patch_num)
+                record = await self._alloc_resource(img)
                 img.uuid = record["id"]
                 img.token_id = record["token_id"]
                 img.token_num = record["token_num"]
@@ -234,9 +231,7 @@ async def generate(
         await self._log_req_header(request_headers, group_request_id)
         # monitoring
 
-        prompt_ids = await self._encode(
-            prompt, multimodal_params, add_special_tokens=sampling_params.add_special_tokens
-        )
+        prompt_ids = await self._encode(prompt, multimodal_params, sampling_params)
         prompt_tokens = len(prompt_ids)
         # monitoring
         if group_request_id > 0:
@@ -307,15 +302,19 @@ async def _log_req_header(self, request_headers, group_request_id: int):
         return
 
     async def _encode(
-        self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams, add_special_tokens: bool
+        self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams, sampling_params: SamplingParams
     ):
         if isinstance(prompt, str):
             if self.enable_multimodal:
                 assert len(multimodal_params.images) <= self.args.cache_capacity, "too many images!"
-                await self._alloc_multimodal_resources(multimodal_params)
-                prompt_ids = self.tokenizer.encode(prompt, multimodal_params, add_special_tokens=add_special_tokens)
+                await self._alloc_multimodal_resources(
+                    multimodal_params, image_max_patch_num=sampling_params.image_max_patch_num
+                )
+                prompt_ids = self.tokenizer.encode(
+                    prompt, multimodal_params, add_special_tokens=sampling_params.add_special_tokens
+                )
             else:
-                prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+                prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=sampling_params.add_special_tokens)
         return prompt_ids
 
     # the validation here is not very thorough for multimodal, to do
```
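The reordering in _alloc_multimodal_resources matters because _alloc_resource keys the embed cache on both the image bytes and extra_params: init_imageItem_extral_params must finalize the patch limit before allocation, or identical images with different limits would share a cache entry. A minimal sketch of the key computation, mirroring the md5sum context line above (frozendict as used by the manager):

```python
import hashlib
from frozendict import frozendict

def cache_key(image_bytes: bytes, extra_params: dict) -> str:
    # Same image data with a different patch limit yields a different key,
    # so cached embeds are never reused across patch configurations.
    return hashlib.md5(image_bytes).hexdigest() + "_" + str(hash(frozendict(extra_params)))

k1 = cache_key(b"\x89PNG...", {"image_patch_max_num": 6})
k2 = cache_key(b"\x89PNG...", {"image_patch_max_num": 12})
assert k1 != k2
```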

lightllm/server/multimodal_params.py

Lines changed: 1 addition & 6 deletions

```diff
@@ -8,7 +8,6 @@
 
 
 class ImageItem:
-
     def __init__(self, **kwargs):
         self._type = kwargs["type"]
         self._data = kwargs["data"]
@@ -22,7 +21,7 @@ def __init__(self, **kwargs):
         self.image_h = 0
 
         self._preload_data = None
-        self.extra_params = {"image_patch_max_num": kwargs.get("max_num", None)}
+        self.extra_params = {}
 
     def preload(self):
         try:
@@ -74,12 +73,8 @@ class MultimodalParams:
     def __init__(
         self,
        images: List[dict] = [],
-        max_num: int = -1,
     ) -> None:
         self.images = [ImageItem(**i) for i in images]
-        max_num = int(os.getenv("MAX_PATCH_NUM", max_num))
-        for image in self.images:
-            image.extra_params["image_patch_max_num"] = max_num
 
     def verify_and_preload(self):
         for image in self.images:
```

lightllm/server/visualserver/manager.py

Lines changed: 16 additions & 16 deletions

```diff
@@ -10,7 +10,7 @@
 from lightllm.server.core.objs import ShmReqManager
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-
+from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from .model_infer.model_rpc import start_model_process, VisualModelRpcClient
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.graceful_utils import graceful_registry
@@ -80,16 +80,16 @@ async def wait_to_model_ready(self):
             await asyncio.gather(*init_model_ret)
         return
 
-    async def infer_imgs(self, uuids):
-        if len(uuids) == 0:
+    async def infer_imgs(self, images: List[ImageItem]):
+        if len(images) == 0:
             return
 
         tasks = []
         for vit_dp_rank in range(self.vit_dp):
-            assigned_uuids = [uuids[i] for i in range(vit_dp_rank, len(uuids), self.vit_dp)]
-            if assigned_uuids:
+            assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.vit_dp)]
+            if assigned_images:
                 for vit_tp_rank in range(self.vit_tp):
-                    task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_uuids))
+                    task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_images))
                     tasks.append(task)
 
         await asyncio.gather(*tasks)
@@ -101,7 +101,7 @@ async def loop_for_fwd(self):
                 await asyncio.sleep(0.01)  # 10ms
             else:
                 processing_group_reqs = []
-                uuids_need_infer = []
+                images_need_infer = []
                 while len(self.waiting_reqs) > 0:
                     group_req_indexes = self.waiting_reqs.pop(0)
                     shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0])
@@ -117,27 +117,27 @@ async def loop_for_fwd(self):
                     multimodal_params = group_req_indexes.multimodal_params
 
                     for img in multimodal_params.images:
-                        # if not self.cache_client.root.get_item_embed(img.uuid):
-                        #     uuids_need_infer.append(img.uuid)
+                        if not self.cache_client.root.get_item_embed(img.uuid):
+                            images_need_infer.append(img)
 
-                        if len(multimodal_params.images) == self.infer_batch_size:
-                            await self.infer_imgs(multimodal_params.images)
-                            # uuids_need_infer = []
+                        if len(images_need_infer) == self.infer_batch_size:
+                            await self.infer_imgs(images_need_infer)
+                            images_need_infer = []
                             for _group_req_indexes in processing_group_reqs:
                                 self.send_to_router.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
                             processing_group_reqs = []
 
-                    if len(multimodal_params.images) == 0:
+                    if len(images_need_infer) == 0:
                         self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
                     else:
                         processing_group_reqs.append(group_req_indexes)
 
-                if len(multimodal_params.images) > 0:
-                    await self.infer_imgs(multimodal_params.images)
+                if len(images_need_infer) > 0:
+                    await self.infer_imgs(images_need_infer)
                     for _group_req_indexes in processing_group_reqs:
                         self.send_to_router.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
                     processing_group_reqs = []
-                    # uuids_need_infer = []
+                    images_need_infer = []
 
     async def loop_for_netio_req(self):
         while True:
```
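Two behaviors worth noting in this rewrite: loop_for_fwd re-enables the embed-cache check and batches only uncached images (images_need_infer) instead of every image in the request, and infer_imgs shards its batch round-robin across the ViT data-parallel ranks. A standalone sketch of that round-robin assignment (vit_dp=2 is illustrative):

```python
from typing import List

def shard_round_robin(images: List[str], vit_dp: int) -> List[List[str]]:
    # Rank r takes images r, r + vit_dp, r + 2 * vit_dp, ... as in infer_imgs.
    return [
        [images[i] for i in range(rank, len(images), vit_dp)]
        for rank in range(vit_dp)
    ]

print(shard_round_robin(["a", "b", "c", "d", "e"], vit_dp=2))
# -> [['a', 'c', 'e'], ['b', 'd']]
```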
