@@ -108,8 +108,9 @@ def __init__(
         return

     # connect cache server, calculate md5, alloc resource, return uuid
-    async def _alloc_resource(self, img:ImageItem, num_tokens):
+    async def _alloc_resource(self, img: ImageItem):
         data = img.read()
+        num_tokens = self.tokenizer.get_image_token_length(img)
         md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params)))
         wait_time = 1
         while True:
@@ -126,16 +127,12 @@ async def _alloc_resource(self, img:ImageItem, num_tokens)
             await asyncio.sleep(wait_time)
             wait_time = min(wait_time + 2, 9)

-    async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams):
+    async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, image_max_patch_num):
         # only P and NORMAL nodes actually need to manage multimodal resources
         if self.pd_mode.is_P_or_NORMAL():
-            num_images = len(multimodal_params.images)
             for img in multimodal_params.images:
-                self.tokenizer.init_imageItem_extral_params(img, num_images)
-                num_tokens = self.tokenizer.get_image_token_length(img)
-                record = await self._alloc_resource(
-                    img, num_tokens
-                )
+                self.tokenizer.init_imageItem_extral_params(img, multimodal_params, image_max_patch_num)
+                record = await self._alloc_resource(img)
                 img.uuid = record["id"]
                 img.token_id = record["token_id"]
                 img.token_num = record["token_num"]
@@ -234,9 +231,7 @@ async def generate(
         await self._log_req_header(request_headers, group_request_id)
         # monitoring

-        prompt_ids = await self._encode(
-            prompt, multimodal_params, add_special_tokens=sampling_params.add_special_tokens
-        )
+        prompt_ids = await self._encode(prompt, multimodal_params, sampling_params)
         prompt_tokens = len(prompt_ids)
         # monitoring
         if group_request_id > 0:
@@ -307,15 +302,19 @@ async def _log_req_header(self, request_headers, group_request_id: int):
         return

     async def _encode(
-        self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams, add_special_tokens: bool
+        self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams, sampling_params: SamplingParams
     ):
         if isinstance(prompt, str):
             if self.enable_multimodal:
                 assert len(multimodal_params.images) <= self.args.cache_capacity, "too many images!"
-                await self._alloc_multimodal_resources(multimodal_params)
-                prompt_ids = self.tokenizer.encode(prompt, multimodal_params, add_special_tokens=add_special_tokens)
+                await self._alloc_multimodal_resources(
+                    multimodal_params, image_max_patch_num=sampling_params.image_max_patch_num
+                )
+                prompt_ids = self.tokenizer.encode(
+                    prompt, multimodal_params, add_special_tokens=sampling_params.add_special_tokens
+                )
             else:
-                prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+                prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=sampling_params.add_special_tokens)
         return prompt_ids

     # the validation here is not thorough enough for multimodal, to do
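
Taken together, the change moves the per-image token count lookup into _alloc_resource and threads sampling_params through _encode, so callers only hand over the image plus the sampling parameters. Below is a minimal, self-contained sketch of the resulting allocation pattern; the CacheClient.alloc coroutine, StubTokenizer, and Img classes are illustrative stand-ins for this note, not the project's real API.

import asyncio
import hashlib


class Img:
    # stand-in image item holding raw bytes
    def __init__(self, data: bytes):
        self._data = data

    def read(self) -> bytes:
        return self._data


class StubTokenizer:
    # stand-in: the real tokenizer derives this from the image's patch layout
    def get_image_token_length(self, img: Img) -> int:
        return 64


class CacheClient:
    # stand-in: the real client talks to the cache server
    async def alloc(self, md5sum: str, num_tokens: int):
        return {"id": md5sum[:8], "token_id": 0, "token_num": num_tokens}


async def alloc_resource(tokenizer: StubTokenizer, cache: CacheClient, img: Img):
    # the token count is computed here now, not passed in by the caller
    num_tokens = tokenizer.get_image_token_length(img)
    md5sum = hashlib.md5(img.read()).hexdigest()
    wait_time = 1
    while True:
        record = await cache.alloc(md5sum, num_tokens)
        if record:
            return record
        # cache full: retry with a slowly growing delay, capped at 9 seconds
        await asyncio.sleep(wait_time)
        wait_time = min(wait_time + 2, 9)


if __name__ == "__main__":
    record = asyncio.run(alloc_resource(StubTokenizer(), CacheClient(), Img(b"raw image bytes")))
    print(record)

Keeping the token-count lookup next to the cache allocation means the backoff loop and the sizing logic stay in one place, and _alloc_multimodal_resources no longer needs to know how image tokens are counted.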