@@ -41,63 +41,44 @@ def __init__(
         visual_model_rpc_ports,
     ):
         self.args = args
-        self.visual_only = True if self.args.run_mode == "visual_only" else False
-        context = zmq.Context(2)
-        self.id_gen = ReqIDGenerator()
-        self.recv_from_httpserver = context.socket(zmq.PULL)
-        if self.visual_only:
-            self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{self.args.visual_only_port}")
-        else:
-            self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}")
-        self.send_to_next_module = context.socket(zmq.PUSH)  # router or audio server (if --enable_multimodal_audio)
-        self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{next_module_port}")
-        self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
-
+        self.remote_vit = args.enable_remote_vit
+        self.next_module_port = next_module_port
+        self.visual_port = visual_port
         self.cache_port = cache_port
         self.memory_cache = MemoryCacheWithRedis(args)
-        self.waiting_reqs_from_httpserver: List[GroupReqIndexes] = []
-        self.waiting_reqs_visual_only: List[VisualOnlyReqIndexes] = []
-        self.model_weightdir = args.model_dir
-        self.tp_world_size = args.tp
-        self.vit_dp = args.visual_dp
-        self.vit_tp = args.visual_tp
+        self.waiting_reqs: List[GroupReqIndexes] = []
         self.infer_batch_size = args.visual_infer_batch_size
         self.trust_remote_code = args.trust_remote_code
         self.visual_model_rpc_ports = visual_model_rpc_ports
-        self.shm_req_manager = ShmReqManager()
-        self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code)
+        self._setup_connections()
+
+    def _setup_connections(self):
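+        # Set up the ZMQ sockets and the cache connection:
+        # - a PULL socket for incoming requests, bound on all interfaces over TCP
+        #   when remote ViT mode is enabled, otherwise on the local zmq_mode address,
+        # - a PUSH socket towards the next module (router, or the audio server when
+        #   --enable_multimodal_audio is set),
+        # - an rpyc client for the local cache service.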
+        context = zmq.Context(2)
+        self.recv_from_httpserver = context.socket(zmq.PULL)
+        if self.remote_vit:
+            self.recv_from_httpserver.bind(f"tcp://*:{self.args.remote_vit_port}")
+        else:
+            self.recv_from_httpserver.bind(f"{self.args.zmq_mode}127.0.0.1:{self.visual_port}")
+        self.send_to_next_module = context.socket(zmq.PUSH)  # router or audio server (if --enable_multimodal_audio)
+        self.send_to_next_module.connect(f"{self.args.zmq_mode}127.0.0.1:{self.next_module_port}")
+        self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
 
     async def wait_to_model_ready(self):
         # TODO: read the config_server to start multiple ViT instances
-        self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)]
+        self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.args.visual_dp)]
 
-        for dp_rank_id in range(self.vit_dp):
+        for dp_rank_id in range(self.args.visual_dp):
             tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id]
-            for tp_rank_id in range(self.vit_tp):
-                device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id]
+            for tp_rank_id in range(self.args.visual_tp):
+                device_id = self.args.visual_gpu_ids[dp_rank_id * self.args.visual_tp + tp_rank_id]
                 rpc_model = await start_model_process(
-                    port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp, device_id=device_id
+                    port=tp_ports_each_dp[tp_rank_id], vit_tp=self.args.visual_tp, device_id=device_id
                 )
                 self.model_rpcs[dp_rank_id].append(rpc_model)
 
         init_model_ret = []
-        for dp_rank_id in range(self.vit_dp):  # async init model process
-            for tp_rank_id in range(self.vit_tp):
+        for dp_rank_id in range(self.args.visual_dp):  # async init model process
+            for tp_rank_id in range(self.args.visual_tp):
                 kvargs = {
-                    "weight_dir": self.model_weightdir,
-                    "trust_remote_code": self.trust_remote_code,
-                    "vit_dp": self.vit_dp,
-                    "vit_tp": self.vit_tp,
-                    "cache_port": self.cache_port,
                     "tp_rank_id": tp_rank_id,
                     "dp_rank_id": dp_rank_id,
-                    "vit_rank_id": dp_rank_id * self.vit_tp + tp_rank_id,
-                    "data_type": self.args.data_type,
-                    "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id],
-                    "visual_gpu_ids": self.args.visual_gpu_ids,
-                    "quant_type": self.args.vit_quant_type,
-                    "quant_cfg": self.args.vit_quant_cfg,
-                    "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1),
                 }
                 init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs))
         await asyncio.gather(*init_model_ret)
@@ -108,10 +89,10 @@ async def infer_imgs(self, images: List[ImageItem]):
             return
 
         tasks = []
-        for vit_dp_rank in range(self.vit_dp):
-            assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.vit_dp)]
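+        # distribute images round-robin across the visual data-parallel ranks;
+        # each tp worker in a dp group encodes the same assigned slice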
+        for vit_dp_rank in range(self.args.visual_dp):
+            assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.args.visual_dp)]
             if assigned_images:
-                for vit_tp_rank in range(self.vit_tp):
+                for vit_tp_rank in range(self.args.visual_tp):
                     task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_images))
                     tasks.append(task)
 
@@ -120,13 +101,13 @@ async def infer_imgs(self, images: List[ImageItem]):
 
     async def loop_for_fwd(self):
         while True:
-            if len(self.waiting_reqs_from_httpserver) == 0:
+            if len(self.waiting_reqs) == 0:
                 await asyncio.sleep(0.01)  # 10ms
             else:
                 processing_group_reqs = []
                 images_need_infer = []
-                while len(self.waiting_reqs_from_httpserver) > 0:
-                    group_req_indexes = self.waiting_reqs_from_httpserver.pop(0)
+                while len(self.waiting_reqs) > 0:
+                    group_req_indexes = self.waiting_reqs.pop(0)
                     shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0])
                     is_aborted = shm_req.is_aborted
                     self.shm_req_manager.put_back_req_obj(shm_req)
@@ -167,16 +148,31 @@ async def loop_for_fwd(self):
                     processing_group_reqs = []
                     images_need_infer = []
 
+    def _recv_reqs(self):
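+        # Pull the next request from the PULL socket (non-blocking). In remote ViT
+        # mode the image bytes travel inline with the request, so hash each image's
+        # preloaded data into a uid and register it with the cache service before
+        # the request is queued; otherwise hand the received object through as-is.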
+        if self.remote_vit:
+            recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
+            for img in recv_req.multimodal_params.images:
+                data = img._preload_data
+                img._preload_data = None
+                md5sum = hashlib.md5(data).hexdigest()
+                uid = int(md5sum, 16)
+                # create_shm(get_shm_name_data(uid), data)
+                self.cache_client.root.set_items_data([uid])
+
+            return recv_req
+        else:
+            return self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
+
     async def loop_for_netio_req(self):
         if not hasattr(self, "visual_recv_max_count"):
             self.visual_recv_max_count = 64
 
         while True:
             try:
                 for _ in range(self.visual_recv_max_count):
-                    recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
+                    recv_req: GroupReqIndexes = self._recv_reqs()
                     if isinstance(recv_req, GroupReqIndexes):
-                        self.waiting_reqs_from_httpserver.append(recv_req)
+                        self.waiting_reqs.append(recv_req)
                     else:
                         assert False, f"Error Req Inf {recv_req}"
-                self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256)
+                self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256))
@@ -211,103 +207,6 @@ async def loop_for_fwd_visual_only(self):
                 # release this image here, ref - 1
                 logger.info(f"req-id {visual_req.group_req_id} has been release ok")
 
-    async def _initialize_multimodal_metadata(
-        self, multimodal_params: MultimodalParams, sampling_params: SamplingParams
-    ):
-        for img in multimodal_params.images:
-            self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params)
-            data = img.read()
-            # must after init_imageitem_extral_params
-            token_num = self.tokenizer.get_image_token_length(img)
-            md5sum = "{}_{}".format(
-                hashlib.md5(data).hexdigest(),
-                hashlib.md5(pickle.dumps(img.extra_params, protocol=4)).hexdigest(),
-            )
-            img.uuid = int(md5sum, 16)
-            img.token_num = token_num
-
-    async def _log_req_header(self, request_headers, group_request_id: int, image_count: int):
-
-        x_request_id = request_headers.get("X-Request-Id", "")
-        x_session_id = request_headers.get("X-Session-Id", "")
-
-        format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S")
-        logger.info(
-            f"recieved req X-Request-Id:{x_request_id} "
-            f"X-Session-Id:{x_session_id} start_time:{format_in_time} "
-            f"lightllm_req_id:{group_request_id} "
-            f"image_count:{image_count} "
-        )
-        return
-
-    def alloc_req_id(self, sampling_params, is_health_req: bool = False):
-        # The request id can be supplied externally or generated internally, but an externally
-        # supplied id must be globally unique, otherwise it causes exceptions. Currently NORMAL
-        # mode always replaces it with an internal id; P and D modes set it as needed.
-        # Health requests use a negative request_id, so return it directly.
-        if is_health_req:
-            return sampling_params.group_request_id
-        group_request_id = self.id_gen.generate_id()
-
-        sampling_params.group_request_id = group_request_id
-        return group_request_id
-
254- # async def generate(
255- # self,
256- # sampling_params: SamplingParams,
257- # multimodal_params: MultimodalParams,
258- # request: Request,
259- # is_health_req: bool = False,
260- # ) -> Tuple[int, str, dict, FinishStatus]:
261-
262- # request_headers = request.headers if request is not None else {}
263- # group_request_id = self.alloc_req_id(sampling_params, is_health_req)
264-
265- # try:
266- # await multimodal_params.verify_and_preload(request)
267- # image_count = len(multimodal_params.images)
268- # # 记录请求到达的相关信息
269- # await self._log_req_header(request_headers, group_request_id, image_count)
270- # assert (
271- # len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity
272- # ), "too many multimodal items!"
273-
274- # await self._initialize_multimodal_metadata(multimodal_params, sampling_params)
275-
276- # visual_req_status = VisualOnlyReqIndexes(group_req_id=group_request_id,
277- # multimodal_params=multimodal_params)
278- # self.waiting_reqs_visual_only.append(visual_req_status)
279-
280- # except Exception as e:
281- # logger.error(f"group_request_id: {group_request_id} has exception {str(e)}")
282- # await self.abort(group_request_id, multimodal_params)
283- # raise e
284- # return
285-
-    async def abort(self, group_req_id: int, multimodal_params: MultimodalParams):
-        logger.warning(f"aborted group_request_id {group_req_id}")
-        for img in multimodal_params.images:
-            img.is_abort = True
-        return
-
-    async def loop_for_netio_req(self):
-        if not hasattr(self, "visual_recv_max_count"):
-            self.visual_recv_max_count = 64
-
-        while True:
-            try:
-                for _ in range(self.visual_recv_max_count):
-                    recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
-                    print(f"recv_req is {recv_req}")
-                    if isinstance(recv_req, GroupReqIndexes):
-                        self.waiting_reqs_from_httpserver.append(recv_req)
-                    else:
-                        assert False, f"Error Req Inf {recv_req}"
-                self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256)
-            except zmq.ZMQError:
-                # when the queue has started to drain, lower the per-iteration receive count
-                self.visual_recv_max_count = 64
-                await asyncio.sleep(0.01)
-
     def clean_up(self):
         for model_rpc in self.model_rpcs:
             model_rpc.rpc_server_process.kill()
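Note: below is a minimal sketch of a sender for the new remote ViT path, assuming pyzmq on the sending side and the same send_pyobj/recv_pyobj pickling the server uses; the function name and the host/port arguments are illustrative only, not part of this commit.

    # illustrative only, not part of this commit
    import pickle
    import zmq

    def send_to_remote_vit(host: str, port: int, group_req_indexes) -> None:
        # the payload must be a real GroupReqIndexes instance (loop_for_netio_req
        # asserts on its type); in remote ViT mode each image should carry its bytes
        # in _preload_data, since the server hashes that field to register the item
        # with the cache service
        ctx = zmq.Context()
        sock = ctx.socket(zmq.PUSH)
        sock.connect(f"tcp://{host}:{port}")  # matches the server bind tcp://*:<remote_vit_port>
        sock.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
        sock.close()
        ctx.term()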