@@ -110,6 +110,7 @@ def __init__(
110110 # connect cache server, calculate md5, alloc resource, return uuid
111111 async def _alloc_resource (self , img : ImageItem ):
112112 data = img .read ()
113+ # must be called after init_imageItem_extral_params
113114 num_tokens = self .tokenizer .get_image_token_length (img )
114115 md5sum = hashlib .md5 (data ).hexdigest () + "_" + str (hash (frozendict (img .extra_params )))
115116 wait_time = 1
@@ -127,11 +128,11 @@ async def _alloc_resource(self, img: ImageItem):
127128 await asyncio .sleep (wait_time )
128129 wait_time = min (wait_time + 2 , 9 )
129130
130- async def _alloc_multimodal_resources (self , multimodal_params : MultimodalParams , image_max_patch_num ):
131+ async def _alloc_multimodal_resources (self , multimodal_params : MultimodalParams , sampling_params : SamplingParams ):
131132 # Only P and NORMAL nodes actually need to manage multimodal resources
132133 if self .pd_mode .is_P_or_NORMAL ():
133134 for img in multimodal_params .images :
134- self .tokenizer .init_imageItem_extral_params (img , multimodal_params , image_max_patch_num )
135+ self .tokenizer .init_imageItem_extral_params (img , multimodal_params , sampling_params )
135136 record = await self ._alloc_resource (img )
136137 img .uuid = record ["id" ]
137138 img .token_id = record ["token_id" ]
@@ -151,15 +152,15 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam
151152 img .token_num = None
152153 return
153154
def tokens(self, prompt, multimodal_params, samping_params: "SamplingParams" = None, kwargs=None):
    """Return the token count of ``prompt`` plus all attached images.

    Args:
        prompt: Text passed to ``self.tokenizer.encode``.
        multimodal_params: Object whose ``images`` attribute is iterated.
        samping_params: Sampling parameters forwarded unchanged to
            ``self.tokenizer.init_imageItem_extral_params`` for every image.
            The (misspelled) name is kept so existing keyword callers keep
            working. Defaults to ``None``; previously the default was the
            ``SamplingParams`` class object itself, because ``=`` had been
            written where a ``:`` annotation was intended.
        kwargs: Optional extra keyword arguments for ``tokenizer.encode``.

    Returns:
        ``len(prompt_ids) + image_tokens + img_count``.
    """
    kwargs = {} if kwargs is None else kwargs
    prompt_ids = self.tokenizer.encode(prompt, None, **kwargs)
    image_tokens = 0
    img_count = 0
    for img in multimodal_params.images:
        img_count += 1
        # The image's extra params have to be initialised before its token
        # length is queried.
        self.tokenizer.init_imageItem_extral_params(img, multimodal_params, samping_params)
        image_tokens += self.tokenizer.get_image_token_length(img)
    # NOTE(review): one extra position is counted per image -- presumably a
    # per-image marker token; confirm against the tokenizer.
    return len(prompt_ids) + image_tokens + img_count
164165
165166 async def loop_for_request (self ):
@@ -307,9 +308,7 @@ async def _encode(
307308 if isinstance (prompt , str ):
308309 if self .enable_multimodal :
309310 assert len (multimodal_params .images ) <= self .args .cache_capacity , "too many images!"
310- await self ._alloc_multimodal_resources (
311- multimodal_params , image_max_patch_num = sampling_params .image_max_patch_num
312- )
311+ await self ._alloc_multimodal_resources (multimodal_params , sampling_params )
313312 prompt_ids = self .tokenizer .encode (
314313 prompt , multimodal_params , add_special_tokens = sampling_params .add_special_tokens
315314 )
0 commit comments