|
| 1 | +import os |
1 | 2 | import zmq |
2 | 3 | import zmq.asyncio |
3 | 4 | import asyncio |
@@ -57,6 +58,7 @@ def __init__( |
57 | 58 | self.trust_remote_code = args.trust_remote_code |
58 | 59 | self.args = args |
59 | 60 | self.visual_model_rpc_ports = visual_model_rpc_ports |
| 61 | + self.send_batch_size = min(args.visual_send_batch_size, args.cache_capacity) |
60 | 62 | self.shm_req_manager = ShmReqManager() |
61 | 63 |
|
62 | 64 | async def wait_to_model_ready(self): |
@@ -117,6 +119,18 @@ async def loop_for_fwd(self): |
117 | 119 | else: |
118 | 120 | processing_group_reqs = [] |
119 | 121 | images_need_infer = [] |
| 122 | + ready_to_send = [] |
| 123 | + |
# flush_ready: nested helper (closure over `ready_to_send` and `self`) that
# forwards the buffered group requests to the next module.
#   force=False -> send only once at least `self.send_batch_size` items are buffered.
#   force=True  -> send whatever is buffered, regardless of how many items there are.
| 124 | + def flush_ready(force: bool = False): |
# Nothing buffered: no-op.
| 125 | + if not ready_to_send: |
| 126 | + return |
# Not forced and still below the batch threshold: keep accumulating.
| 127 | + if not force and len(ready_to_send) < self.send_batch_size: |
| 128 | + return |
| 129 | + |
# Each buffered item is still sent as its own message, in buffer order;
# the shared list is then emptied in place so the enclosing loop can reuse it.
| 130 | + for group_req_indexes in ready_to_send: |
| 131 | + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) |
| 132 | + ready_to_send.clear() |
| 133 | + |
120 | 134 | while len(self.waiting_reqs) > 0: |
121 | 135 | group_req_indexes = self.waiting_reqs.pop(0) |
122 | 136 | shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) |
@@ -146,24 +160,25 @@ async def loop_for_fwd(self): |
146 | 160 | if len(images_need_infer) == self.infer_batch_size: |
147 | 161 | await self.infer_imgs(images_need_infer) |
148 | 162 | images_need_infer = [] |
149 | | - for _group_req_indexes in processing_group_reqs: |
150 | | - self.send_to_next_module.send_pyobj( |
151 | | - _group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL |
152 | | - ) |
| 163 | + ready_to_send.extend(processing_group_reqs) |
153 | 164 | processing_group_reqs = [] |
| 165 | + flush_ready(force=False) |
154 | 166 |
|
155 | 167 | if len(images_need_infer) == 0: |
156 | | - self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) |
| 168 | + ready_to_send.append(group_req_indexes) |
| 169 | + flush_ready(force=False) |
157 | 170 | else: |
158 | 171 | processing_group_reqs.append(group_req_indexes) |
159 | 172 |
|
160 | 173 | if len(images_need_infer) > 0: |
161 | 174 | await self.infer_imgs(images_need_infer) |
162 | | - for _group_req_indexes in processing_group_reqs: |
163 | | - self.send_to_next_module.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) |
164 | | - processing_group_reqs = [] |
165 | 175 | images_need_infer = [] |
166 | 176 |
|
| 177 | + # The groups whose images have just been processed are now ready to send as well |
| 178 | + ready_to_send.extend(processing_group_reqs) |
| 179 | + processing_group_reqs = [] |
| 180 | + flush_ready(force=True) |
| 181 | + |
167 | 182 | async def loop_for_netio_req(self): |
168 | 183 | if not hasattr(self, "visual_recv_max_count"): |
169 | 184 | self.visual_recv_max_count = 64 |
|
0 commit comments