@@ -57,6 +57,7 @@ def __init__(
5757 self .trust_remote_code = args .trust_remote_code
5858 self .args = args
5959 self .visual_model_rpc_ports = visual_model_rpc_ports
60+ self .send_batch_size = min (args .visual_send_batch_size , args .cache_capacity )
6061 self .shm_req_manager = ShmReqManager ()
6162
6263 async def wait_to_model_ready (self ):
@@ -117,6 +118,18 @@ async def loop_for_fwd(self):
117118 else :
118119 processing_group_reqs = []
119120 images_need_infer = []
121+ ready_to_send = []
122+
123+ def flush_ready (force : bool = False ):  # Send the groups buffered in ready_to_send; force=True skips the batch-size check below.
124+ if not ready_to_send :  # Nothing buffered: no-op.
125+ return
126+ if not force and len (ready_to_send ) < self .send_batch_size :  # Unless forced, hold back until send_batch_size groups have accumulated.
127+ return
128+
129+ for group_req_indexes in ready_to_send :  # Each buffered group is still sent as its own message, in buffer order.
130+ self .send_to_next_module .send_pyobj (group_req_indexes , protocol = pickle .HIGHEST_PROTOCOL )
131+ ready_to_send .clear ()  # Empty the shared buffer in place so the enclosing loop keeps appending to the same list.
132+
120133 while len (self .waiting_reqs ) > 0 :
121134 group_req_indexes = self .waiting_reqs .pop (0 )
122135 shm_req = self .shm_req_manager .get_req_obj_by_index (group_req_indexes .shm_req_indexes [0 ])
@@ -146,24 +159,25 @@ async def loop_for_fwd(self):
146159 if len (images_need_infer ) == self .infer_batch_size :
147160 await self .infer_imgs (images_need_infer )
148161 images_need_infer = []
149- for _group_req_indexes in processing_group_reqs :
150- self .send_to_next_module .send_pyobj (
151- _group_req_indexes , protocol = pickle .HIGHEST_PROTOCOL
152- )
162+ ready_to_send .extend (processing_group_reqs )
153163 processing_group_reqs = []
164+ flush_ready (force = False )
154165
155166 if len (images_need_infer ) == 0 :
156- self .send_to_next_module .send_pyobj (group_req_indexes , protocol = pickle .HIGHEST_PROTOCOL )
167+ ready_to_send .append (group_req_indexes )
168+ flush_ready (force = False )
157169 else :
158170 processing_group_reqs .append (group_req_indexes )
159171
160172 if len (images_need_infer ) > 0 :
161173 await self .infer_imgs (images_need_infer )
162- for _group_req_indexes in processing_group_reqs :
163- self .send_to_next_module .send_pyobj (_group_req_indexes , protocol = pickle .HIGHEST_PROTOCOL )
164- processing_group_reqs = []
165174 images_need_infer = []
166175
176+ # Groups whose images have just finished processing are now ready to send as well
177+ ready_to_send .extend (processing_group_reqs )
178+ processing_group_reqs = []
179+ flush_ready (force = True )
180+
167181 async def loop_for_netio_req (self ):
168182 if not hasattr (self , "visual_recv_max_count" ):
169183 self .visual_recv_max_count = 64
0 commit comments