Skip to content

Commit d77ba8e

Browse files
author
sangchengmeng
committed
add visual_send_bs args
1 parent 8fd19c2 commit d77ba8e

File tree

4 files changed

+30
-9
lines changed

4 files changed

+30
-9
lines changed

lightllm/models/qwen2_vl/vision_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def rescale_and_normalize(
162162

163163
def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
164164
image_arr = np.asarray(image, dtype=np.uint8)
165-
image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True)
165+
image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous()
166166
grouped_images, grouped_images_index = group_images_by_shape(
167167
[image_data], disable_grouping=self.disable_grouping
168168
)

lightllm/server/api_cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
355355
parser.add_argument(
356356
"--visual_infer_batch_size", type=int, default=1, help="number of images to process in each inference batch"
357357
)
358+
parser.add_argument(
359+
"--visual_send_batch_size",
360+
type=int,
361+
default=1,
362+
help="number of images embedding to send to llm process in each batch",
363+
)
358364
parser.add_argument(
359365
"--visual_gpu_ids", nargs="+", type=int, default=None, help="List of GPU IDs to use, e.g., 0 1 2"
360366
)

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class StartArgs:
7777
grouping_key: List[str] = field(default_factory=list)
7878
push_interval: int = field(default=10)
7979
visual_infer_batch_size: int = field(default=1)
80+
visual_send_batch_size: int = field(default=1)
8081
visual_gpu_ids: List[int] = field(default_factory=lambda: [0])
8182
visual_tp: int = field(default=1)
8283
visual_dp: int = field(default=1)

lightllm/server/visualserver/manager.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
self.trust_remote_code = args.trust_remote_code
5858
self.args = args
5959
self.visual_model_rpc_ports = visual_model_rpc_ports
60+
self.send_batch_size = min(args.visual_send_batch_size, args.cache_capacity)
6061
self.shm_req_manager = ShmReqManager()
6162

6263
async def wait_to_model_ready(self):
@@ -117,6 +118,18 @@ async def loop_for_fwd(self):
117118
else:
118119
processing_group_reqs = []
119120
images_need_infer = []
121+
ready_to_send = []
122+
123+
def flush_ready(force: bool = False):
124+
if not ready_to_send:
125+
return
126+
if not force and len(ready_to_send) < self.send_batch_size:
127+
return
128+
129+
for group_req_indexes in ready_to_send:
130+
self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
131+
ready_to_send.clear()
132+
120133
while len(self.waiting_reqs) > 0:
121134
group_req_indexes = self.waiting_reqs.pop(0)
122135
shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0])
@@ -146,24 +159,25 @@ async def loop_for_fwd(self):
146159
if len(images_need_infer) == self.infer_batch_size:
147160
await self.infer_imgs(images_need_infer)
148161
images_need_infer = []
149-
for _group_req_indexes in processing_group_reqs:
150-
self.send_to_next_module.send_pyobj(
151-
_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL
152-
)
162+
ready_to_send.extend(processing_group_reqs)
153163
processing_group_reqs = []
164+
flush_ready(force=False)
154165

155166
if len(images_need_infer) == 0:
156-
self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
167+
ready_to_send.append(group_req_indexes)
168+
flush_ready(force=False)
157169
else:
158170
processing_group_reqs.append(group_req_indexes)
159171

160172
if len(images_need_infer) > 0:
161173
await self.infer_imgs(images_need_infer)
162-
for _group_req_indexes in processing_group_reqs:
163-
self.send_to_next_module.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
164-
processing_group_reqs = []
165174
images_need_infer = []
166175

176+
# 这些处理完 image 的 group 也 ready 了
177+
ready_to_send.extend(processing_group_reqs)
178+
processing_group_reqs = []
179+
flush_ready(force=True)
180+
167181
async def loop_for_netio_req(self):
168182
if not hasattr(self, "visual_recv_max_count"):
169183
self.visual_recv_max_count = 64

0 commit comments

Comments
 (0)