Skip to content

Commit 6f01afd

Browse files
author
sangchengmeng
committed
add visual_send_bs args
1 parent 8fd19c2 commit 6f01afd

File tree

4 files changed

+31
-9
lines changed

4 files changed

+31
-9
lines changed

lightllm/models/qwen2_vl/vision_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def rescale_and_normalize(
162162

163163
def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
164164
image_arr = np.asarray(image, dtype=np.uint8)
165-
image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True)
165+
image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous()
166166
grouped_images, grouped_images_index = group_images_by_shape(
167167
[image_data], disable_grouping=self.disable_grouping
168168
)

lightllm/server/api_cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
355355
parser.add_argument(
356356
"--visual_infer_batch_size", type=int, default=1, help="number of images to process in each inference batch"
357357
)
358+
parser.add_argument(
359+
"--visual_send_batch_size",
360+
type=int,
361+
default=1,
362+
help="number of image embeddings to send to the llm process in each batch",
363+
)
358364
parser.add_argument(
359365
"--visual_gpu_ids", nargs="+", type=int, default=None, help="List of GPU IDs to use, e.g., 0 1 2"
360366
)

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class StartArgs:
7777
grouping_key: List[str] = field(default_factory=list)
7878
push_interval: int = field(default=10)
7979
visual_infer_batch_size: int = field(default=1)
80+
visual_send_batch_size: int = field(default=10)
8081
visual_gpu_ids: List[int] = field(default_factory=lambda: [0])
8182
visual_tp: int = field(default=1)
8283
visual_dp: int = field(default=1)

lightllm/server/visualserver/manager.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import zmq
23
import zmq.asyncio
34
import asyncio
@@ -57,6 +58,7 @@ def __init__(
5758
self.trust_remote_code = args.trust_remote_code
5859
self.args = args
5960
self.visual_model_rpc_ports = visual_model_rpc_ports
61+
self.send_batch_size = min(args.visual_send_batch_size, args.cache_capacity)
6062
self.shm_req_manager = ShmReqManager()
6163

6264
async def wait_to_model_ready(self):
@@ -117,6 +119,18 @@ async def loop_for_fwd(self):
117119
else:
118120
processing_group_reqs = []
119121
images_need_infer = []
122+
ready_to_send = []
123+
124+
def flush_ready(force: bool = False):
125+
if not ready_to_send:
126+
return
127+
if not force and len(ready_to_send) < self.send_batch_size:
128+
return
129+
130+
for group_req_indexes in ready_to_send:
131+
self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
132+
ready_to_send.clear()
133+
120134
while len(self.waiting_reqs) > 0:
121135
group_req_indexes = self.waiting_reqs.pop(0)
122136
shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0])
@@ -146,24 +160,25 @@ async def loop_for_fwd(self):
146160
if len(images_need_infer) == self.infer_batch_size:
147161
await self.infer_imgs(images_need_infer)
148162
images_need_infer = []
149-
for _group_req_indexes in processing_group_reqs:
150-
self.send_to_next_module.send_pyobj(
151-
_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL
152-
)
163+
ready_to_send.extend(processing_group_reqs)
153164
processing_group_reqs = []
165+
flush_ready(force=False)
154166

155167
if len(images_need_infer) == 0:
156-
self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
168+
ready_to_send.append(group_req_indexes)
169+
flush_ready(force=False)
157170
else:
158171
processing_group_reqs.append(group_req_indexes)
159172

160173
if len(images_need_infer) > 0:
161174
await self.infer_imgs(images_need_infer)
162-
for _group_req_indexes in processing_group_reqs:
163-
self.send_to_next_module.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
164-
processing_group_reqs = []
165175
images_need_infer = []
166176

177+
# groups whose images have finished processing are now ready to send as well
178+
ready_to_send.extend(processing_group_reqs)
179+
processing_group_reqs = []
180+
flush_ready(force=True)
181+
167182
async def loop_for_netio_req(self):
168183
if not hasattr(self, "visual_recv_max_count"):
169184
self.visual_recv_max_count = 64

0 commit comments

Comments
 (0)