1010import datetime
1111import pickle
1212from frozendict import frozendict
13+ import concurrent .futures
1314
1415asyncio .set_event_loop_policy (uvloop .EventLoopPolicy ())
1516from typing import Union , List , Tuple , Dict , Optional
3233from lightllm .utils .statics_utils import MovingAverage
3334from lightllm .utils .config_utils import get_vocab_size
3435from lightllm .utils .envs_utils import get_unique_server_name
36+ from lightllm .utils .infer_utils import calculate_cpu_time_async , calculate_cpu_time_sync
3537from rpyc .utils .classic import obtain
3638
3739logger = init_logger (__name__ )
@@ -112,13 +114,19 @@ def __init__(
112114 # If the time mark is not updated within a preset period, a probe request will be sent to the backend.
113115 self .latest_success_infer_time_mark = SharedInt (f"{ get_unique_server_name ()} _latest_success_infer_time_mark" )
114116 self .latest_success_infer_time_mark .set_value (int (time .time ()))
117+
118+ # 线程池用于创建multimodal resource alloc
119+ self .enable_concurrent_alloc = self .args .enable_concurrent_alloc
120+ self .max_concurrent = self .args .concurrent_alloc_workers * 48
121+ if self .enable_concurrent_alloc :
122+ self .executor = concurrent .futures .ThreadPoolExecutor (max_workers = self .args .concurrent_alloc_workers )
115123 return
116124
117125 async def _alloc_resource (self , items , md5sums , token_nums , datas ):
118-
119126 while True :
127+ t1 = time .time ()
120128 records = obtain (self .cache_client .root .alloc (md5sums , token_nums ))
121-
129+ logger . info ( f"cache manager batch alloc time: { ( time . time () - t1 ) * 1000 } ms" )
122130 if records is None :
123131 await asyncio .sleep (0.1 )
124132 continue
@@ -142,37 +150,139 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas):
142150 self .cache_client .root .set_items_data (update_data_ids )
143151 return
144152
async def _alloc_resource_v2(self, items, md5sums, token_nums, datas):
    """Allocate cache records for a batch of items and upload payloads that are not cached yet.

    The batch is sent to the cache server as one pickled request. While the
    server answers ``None`` the request is retried every 100 ms. On success
    each item receives its ``uuid``, ``token_id`` and ``token_num`` from the
    matching record, and every payload the server reports as not ready is
    passed to ``create_shm`` on ``self.executor``; those ids are then
    reported back through ``set_items_data_v2``.

    Args:
        items: Multimodal items to update in place; parallel to the other lists.
        md5sums: Cache key of each item.
        token_nums: Token count requested for each item.
        datas: Raw payload of each item.

    Returns:
        None.
    """
    # The request does not change between retries, so serialize it only once.
    req_blob = pickle.dumps(list(zip(md5sums, token_nums)))

    while True:
        t1 = time.time()
        res_blob = self.cache_client.root.alloc_v2(req_blob)
        # NOTE(review): pickle.loads executes arbitrary code if the peer is not
        # trusted; this assumes the cache server is an internal, trusted process.
        records = pickle.loads(res_blob)
        logger.info(f"cache manager batch alloc time: {(time.time() - t1) * 1000} ms")
        if records is not None:
            break
        # Nothing was allocated this round; back off briefly and ask again.
        await asyncio.sleep(0.1)

    uid_list = []
    for item, rec in zip(items, records):
        item.uuid = rec["id"]
        item.token_id = rec["token_id"]
        item.token_num = rec["token_num"]
        uid_list.append(rec["id"])

    ready_flags = pickle.loads(self.cache_client.root.get_items_data_v2(pickle.dumps(uid_list)))

    # Only payloads the server does not hold yet have to be written.
    pending = [(uid, data) for uid, ready, data in zip(uid_list, ready_flags, datas) if not ready]
    if not pending:
        return

    # get_running_loop() is the documented way to reach the loop from a coroutine.
    loop = asyncio.get_running_loop()
    # Cap the number of executor jobs that are in flight at the same time.
    semaphore = asyncio.Semaphore(min(len(items), self.max_concurrent))

    async def create_shm_with_limit(uid, data):
        async with semaphore:
            return await loop.run_in_executor(self.executor, create_shm, get_shm_name_data(uid), data)

    t_shm = time.time()
    await asyncio.gather(*(create_shm_with_limit(uid, data) for uid, data in pending))
    logger.info(f"concurrent create shm time: {(time.time() - t_shm) * 1000} ms")

    # Tell the cache server which ids now have their data in place.
    self.cache_client.root.set_items_data_v2(pickle.dumps([uid for uid, _ in pending]))
    return
201+
@calculate_cpu_time_async(show=True)
async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams):
    """Allocate cache resources for the images and audios of one request.

    Does nothing unless ``self.pd_mode.is_P_or_NORMAL()`` is true. Otherwise
    the work is delegated, under ``self._resource_lock``, to the concurrent
    (v2) or the sequential (v1) implementation depending on
    ``self.enable_concurrent_alloc``.
    """
    # Only P and NORMAL nodes actually need to manage multimodal resources.
    if not self.pd_mode.is_P_or_NORMAL():
        return

    # The lock prevents several multi-image requests from jointly requesting
    # more records than cache_capacity, which would deadlock. Without it, if
    # request 1 and request 2 each carry 6 images and cache_capacity is 10,
    # shm could end up holding 5 images of each request, and both would then
    # wait forever for the remaining slots.
    async with self._resource_lock:
        if self.enable_concurrent_alloc:
            allocate = self._alloc_multimodal_resources_v2
        else:
            allocate = self._alloc_multimodal_resources_v1
        await allocate(multimodal_params, sampling_params)

    return
175216
async def _alloc_multimodal_resources_v1(
    self, multimodal_params: MultimodalParams, sampling_params: SamplingParams
):
    """Sequentially prepare every image and audio, then allocate them in one batch.

    For each item the tokenizer initializes its extra params, the payload is
    read, the token length is queried and a cache key is built from the
    payload MD5 plus the hash of the item's ``extra_params``. The collected
    lists are then handed to ``self._alloc_resource``.
    """
    items, md5sums, tokens_nums, datas = [], [], [], []

    def _collect(entry, payload, token_count):
        # Cache key: "<md5 of payload>_<hash of the frozen extra params>".
        digest = hashlib.md5(payload).hexdigest()
        key = digest + "_" + str(hash(frozendict(entry.extra_params)))
        md5sums.append(key)
        tokens_nums.append(token_count)
        datas.append(payload)
        items.append(entry)

    for image in multimodal_params.images:
        self.tokenizer.init_imageitem_extral_params(image, multimodal_params, sampling_params)
        payload = image.read()
        # The token length must be queried after init_imageitem_extral_params.
        _collect(image, payload, self.tokenizer.get_image_token_length(image))

    for clip in multimodal_params.audios:
        self.tokenizer.init_audioitem_extral_params(clip, multimodal_params, sampling_params)
        payload = clip.read()
        _collect(clip, payload, self.tokenizer.get_audio_token_length(clip))

    await self._alloc_resource(items, md5sums, tokens_nums, datas)
242+
async def _alloc_multimodal_resources_v2(
    self, multimodal_params: MultimodalParams, sampling_params: SamplingParams
):
    """Prepare images and audios on the thread pool, then allocate them chunk by chunk.

    Items are handled in chunks of ``self.max_concurrent``. Within a chunk,
    extra-param initialization, payload reading and cache-key computation run
    on ``self.executor``; token lengths are then computed on the event loop
    and the chunk is passed to ``self._alloc_resource_v2``.

    Raises:
        TypeError: If an item is neither an ``ImageItem`` nor an ``AudioItem``.
    """
    all_items = multimodal_params.images + multimodal_params.audios
    if not all_items:
        return
    # get_running_loop() is the documented way to reach the loop from a coroutine.
    loop = asyncio.get_running_loop()

    def _process_item(item):
        """Initialize the item's extra params, read its payload and build its cache key."""
        if isinstance(item, ImageItem):
            self.tokenizer.init_imageitem_extral_params(item, multimodal_params, sampling_params)
        elif isinstance(item, AudioItem):
            self.tokenizer.init_audioitem_extral_params(item, multimodal_params, sampling_params)

        data = item.read()
        md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(item.extra_params)))
        return data, md5sum

    def _token_length(item):
        """Return the token length of the item, rejecting unsupported item types."""
        if isinstance(item, ImageItem):
            return self.tokenizer.get_image_token_length(item)
        if isinstance(item, AudioItem):
            return self.tokenizer.get_audio_token_length(item)
        # Without this branch an unknown item would silently reuse the token
        # count of the previous item (or fail with UnboundLocalError).
        raise TypeError(f"unsupported multimodal item type: {type(item).__name__}")

    chunk_size = self.max_concurrent  # adjust as needed
    for start in range(0, len(all_items), chunk_size):
        chunk = all_items[start : start + chunk_size]

        # Process every item of the chunk concurrently on the executor.
        chunk_results = await asyncio.gather(
            *(loop.run_in_executor(self.executor, _process_item, item) for item in chunk)
        )

        chunk_md5sums, chunk_tokens_nums, chunk_datas = [], [], []
        for item, (data, md5sum) in zip(chunk, chunk_results):
            # Token length is queried only after the extra params were initialized above.
            chunk_tokens_nums.append(_token_length(item))
            chunk_md5sums.append(md5sum)
            chunk_datas.append(data)

        await self._alloc_resource_v2(chunk, chunk_md5sums, chunk_tokens_nums, chunk_datas)
285+
176286 async def _release_multimodal_resources (self , multimodal_params : MultimodalParams ):
177287 # 只有 P 和 NORMAL 节点需要真的管理多模态资源
178288 if self .pd_mode .is_P_or_NORMAL ():
@@ -193,7 +303,11 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam
193303 audio .token_id = None
194304 audio .token_num = None
195305 if ids_to_release :
196- self .cache_client .root .release (ids_to_release )
306+ if self .enable_concurrent_alloc :
307+ release_id_blobs = pickle .dumps (ids_to_release )
308+ self .cache_client .root .release_v2 (release_id_blobs )
309+ else :
310+ self .cache_client .root .release (ids_to_release )
197311 return
198312
199313 def tokens (self , prompt , multimodal_params , samping_params : SamplingParams , kwargs = None ):
@@ -341,7 +455,6 @@ async def generate(
341455 return
342456
343457 async def _log_req_header (self , request_headers , group_request_id : int ):
344-
345458 x_request_id = request_headers .get ("X-Request-Id" , "" )
346459 x_session_id = request_headers .get ("X-Session-Id" , "" )
347460
@@ -436,7 +549,6 @@ async def transfer_to_next_module(
436549 self ,
437550 group_req_objs : Optional [GroupReqObjs ] = None ,
438551 ):
439-
440552 if self .pd_mode == NodeRole .P :
441553 if self .enable_multimodal :
442554 self .send_to_visual .send_pyobj (
@@ -483,7 +595,6 @@ async def _wait_to_token_package(
483595 req_status : "ReqStatus" ,
484596 request : Request ,
485597 ):
486-
487598 event = req_status .event
488599 unfinished_count = sampling_params .best_of
489600 out_token_counter = 0
@@ -589,7 +700,6 @@ async def recycle_resource_loop(self):
589700 pre_time_mark = time .time ()
590701
591702 while True :
592-
593703 try :
594704 await asyncio .wait_for (self .recycle_event .wait (), timeout = 0.02 )
595705 except asyncio .TimeoutError :
@@ -660,7 +770,6 @@ async def handle_loop(self):
660770
661771 for _ in range (read_token_count ):
662772 if not req .out_tokens_queue .is_empty ():
663-
664773 text , src_index , special , count_output_tokens = req .out_tokens_queue .peek ()
665774 req .cumlogprob += float (req .shm_logprobs .arr [src_index ])
666775 metadata = {
0 commit comments