
Commit ded28b7

update visual server mananger
1 parent 70bc956 commit ded28b7

File tree

6 files changed, +97 -191 lines changed

lightllm/server/api_cli.py

Lines changed: 7 additions & 3 deletions

@@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--run_mode",
         type=str,
-        choices=["normal", "prefill", "decode", "pd_master", "config_server", "visual_only", "llm_only"],
+        choices=["normal", "prefill", "decode", "pd_master", "config_server", "visual_only"],
         default="normal",
         help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode,
         config_server is for pd split mode used to register pd_master node, and get pd_master node list,
@@ -337,8 +337,6 @@ def make_argument_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument("--metric_gateway", type=str, default=None, help="address for collecting monitoring metrics")
     parser.add_argument("--job_name", type=str, default="lightllm", help="job name for monitor")
-    parser.add_argument("--visual_embed_path", type=str, default=None, help="path for vit embed")
-    parser.add_argument("--visual_only_port", type=int, default=18097, help="port for visual only server")
     parser.add_argument(
         "--grouping_key", action="append", default=[], help="grouping_key for the monitor in the form key=value"
     )
@@ -507,6 +505,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=0.03,
         help="""The interval of the schedule time, default is 30ms.""",
     )
+    parser.add_argument(
+        "--image_embed_dir",
+        type=str,
+        default=None,
+        help="path for vit embed",
+    )
     parser.add_argument(
         "--enable_remote_vit",
         action="store_true",

lightllm/server/api_start.py

Lines changed: 2 additions & 8 deletions

@@ -141,12 +141,6 @@ def check_and_set_args(args):
         assert args.mtp_step == 0

     args.enable_multimodal = is_multimodal_mode(args)
-    # visual_embed_path only needs to be set in visual_only mode
-    if args.visual_only_port is not None:
-        assert (
-            args.run_mode == "visual_only" or args.run_mode == "llm_only"
-        ), "only visual_only or llm_only mode need visual_only_port"
-
     # check that the number of GPUs is sufficient
     if args.visual_gpu_ids is None:
         args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp))
@@ -279,7 +273,7 @@ def normal_or_p_d_start(args):
             ],
             start_args=[(cache_port, args)],
         )
-        if args.enable_multimodal_audio and args.run_mode != "llm_only":
+        if args.enable_multimodal_audio and not args.enable_remote_vit:
            from .audioserver.manager import start_audio_process

            process_manager.start_submodule_processes(
@@ -299,7 +293,7 @@ def normal_or_p_d_start(args):
                 ],
             )

-        elif args.run_mode != "llm_only":
+        elif not args.enable_remote_vit:
            process_manager.start_submodule_processes(
                start_funcs=[
                    start_visual_process,
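The net effect of the two normal_or_p_d_start hunks: local audio/visual submodules are now skipped whenever ViT runs remotely, instead of being keyed off the removed llm_only mode. A condensed sketch with hypothetical stand-ins for the real launchers (the actual code goes through process_manager.start_submodule_processes):

```python
# Hypothetical stand-ins illustrating the new gating only; the real code
# launches subprocesses via process_manager.start_submodule_processes.
from types import SimpleNamespace

def start_multimodal_submodules(args, launch):
    if args.enable_multimodal_audio and not args.enable_remote_vit:
        launch("visual")  # visual server feeds the audio server, then the router
        launch("audio")
    elif not args.enable_remote_vit:
        launch("visual")  # visual server feeds the router directly
    # with --enable_remote_vit set, no local visual/audio process is started

started = []
start_multimodal_submodules(
    SimpleNamespace(enable_multimodal_audio=False, enable_remote_vit=True), started.append
)
assert started == []  # remote ViT: nothing is launched locally
```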

lightllm/server/embed_cache/manager.py

Lines changed: 9 additions & 1 deletion

@@ -4,6 +4,7 @@
 from typing import Union, Optional
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache
+from lightllm.server.embed_cache.impl.memory_cache_with_redis import MemoryCacheWithRedis
 from rpyc.utils.classic import obtain


@@ -53,11 +54,18 @@ def exposed_get_items_embed(self, ids: list[int]) -> list[bool]:
         return self._impl.get_items_embed(ids)


+def get_cache_manager(args):
+    if args.enable_remote_vit:
+        return MemoryCacheWithRedis(args)
+    else:
+        return InMemoryCache(args)
+
+
 def start_cache_manager(port: int, args, pipe_writer):
     # register the handler for graceful shutdown
     graceful_registry(inspect.currentframe().f_code.co_name)

-    manager = InMemoryCache(args)
+    manager = get_cache_manager(args)
     service = CacheServer(manager)
     from rpyc.utils.server import ThreadedServer
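A self-contained sketch of the selection the new get_cache_manager() factory performs, with stub classes standing in for the real InMemoryCache and MemoryCacheWithRedis implementations:

```python
# Stub classes stand in for the real cache implementations; only the
# enable_remote_vit branch mirrors the committed factory.
from types import SimpleNamespace

class InMemoryCache:
    def __init__(self, args):
        self.args = args

class MemoryCacheWithRedis:
    def __init__(self, args):
        self.args = args

def get_cache_manager(args):
    return MemoryCacheWithRedis(args) if args.enable_remote_vit else InMemoryCache(args)

assert isinstance(get_cache_manager(SimpleNamespace(enable_remote_vit=True)), MemoryCacheWithRedis)
assert isinstance(get_cache_manager(SimpleNamespace(enable_remote_vit=False)), InMemoryCache)
```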

lightllm/server/visualserver/manager.py

Lines changed: 42 additions & 143 deletions

@@ -41,63 +41,44 @@ def __init__(
         visual_model_rpc_ports,
     ):
         self.args = args
-        self.visual_only = True if self.args.run_mode == "visual_only" else False
-        context = zmq.Context(2)
-        self.id_gen = ReqIDGenerator()
-        self.recv_from_httpserver = context.socket(zmq.PULL)
-        if self.visual_only:
-            self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{self.args.visual_only_port}")
-        else:
-            self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}")
-        self.send_to_next_module = context.socket(zmq.PUSH)  # router or audio server (if --enable_multimodal_audio)
-        self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{next_module_port}")
-        self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
-
+        self.remote_vit = args.enable_remote_vit
         self.cache_port = cache_port
         self.memory_cache = MemoryCacheWithRedis(args)
-        self.waiting_reqs_from_httpserver: List[GroupReqIndexes] = []
-        self.waiting_reqs_visual_only: List[VisualOnlyReqIndexes] = []
-        self.model_weightdir = args.model_dir
-        self.tp_world_size = args.tp
-        self.vit_dp = args.visual_dp
-        self.vit_tp = args.visual_tp
+        self.waiting_reqs: List[GroupReqIndexes] = []
         self.infer_batch_size = args.visual_infer_batch_size
         self.trust_remote_code = args.trust_remote_code
         self.visual_model_rpc_ports = visual_model_rpc_ports
-        self.shm_req_manager = ShmReqManager()
-        self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code)
+        self._setup_connections()
+
+    def _setup_connections(self):
+        context = zmq.Context(2)
+        if self.remote_vit:
+            self.recv_from_httpserver.bind(f"tcp://*:{self.args.remote_vit_port}")
+        else:
+            self.recv_from_httpserver.bind(f"{self.args.zmq_mode}127.0.0.1:{self.visual_port}")
+        self.send_to_next_module = context.socket(zmq.PUSH)  # router or audio server (if --enable_multimodal_audio)
+        self.send_to_next_module.connect(f"{self.args.zmq_mode}127.0.0.1:{self.next_module_port}")
+        self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})

     async def wait_to_model_ready(self):
         # TODO: read the config_server to start multiple vit instances
         self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)]

-        for dp_rank_id in range(self.vit_dp):
+        for dp_rank_id in range(self.args.visual_dp):
             tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id]
-            for tp_rank_id in range(self.vit_tp):
-                device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id]
+            for tp_rank_id in range(self.args.visual_tp):
+                device_id = self.args.visual_gpu_ids[dp_rank_id * self.args.visual_tp + tp_rank_id]
                 rpc_model = await start_model_process(
-                    port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp, device_id=device_id
+                    port=tp_ports_each_dp[tp_rank_id], vit_tp=self.args.visual_tp, device_id=device_id
                 )
                 self.model_rpcs[dp_rank_id].append(rpc_model)

         init_model_ret = []
-        for dp_rank_id in range(self.vit_dp):  # async init model process
-            for tp_rank_id in range(self.vit_tp):
+        for dp_rank_id in range(self.args.visual_dp):  # async init model process
+            for tp_rank_id in range(self.args.visual_tp):
                 kvargs = {
-                    "weight_dir": self.model_weightdir,
-                    "trust_remote_code": self.trust_remote_code,
-                    "vit_dp": self.vit_dp,
-                    "vit_tp": self.vit_tp,
-                    "cache_port": self.cache_port,
                     "tp_rank_id": tp_rank_id,
                     "dp_rank_id": dp_rank_id,
-                    "vit_rank_id": dp_rank_id * self.vit_tp + tp_rank_id,
-                    "data_type": self.args.data_type,
-                    "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id],
-                    "visual_gpu_ids": self.args.visual_gpu_ids,
-                    "quant_type": self.args.vit_quant_type,
-                    "quant_cfg": self.args.vit_quant_cfg,
-                    "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1),
                 }
                 init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs))
         await asyncio.gather(*init_model_ret)
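For reference, a self-contained sketch of the PULL-bind / PUSH-connect topology _setup_connections() builds. The hunk itself never shows the PULL socket being created, so the sketch assumes it comes from the same context, and the ports here are made up:

```python
# Sketch only: endpoint addresses and ports are illustrative, and the PULL
# socket creation is assumed (the hunk does not show that line).
import zmq

context = zmq.Context(2)

recv_from_httpserver = context.socket(zmq.PULL)
recv_from_httpserver.bind("tcp://*:18097")  # remote-ViT style: listen on all interfaces

send_to_next_module = context.socket(zmq.PUSH)
send_to_next_module.connect("tcp://127.0.0.1:18098")  # router or audio server
```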
@@ -108,10 +89,10 @@ async def infer_imgs(self, images: List[ImageItem]):
             return

         tasks = []
-        for vit_dp_rank in range(self.vit_dp):
-            assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.vit_dp)]
+        for vit_dp_rank in range(self.args.visual_dp):
+            assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.args.visual_dp)]
             if assigned_images:
-                for vit_tp_rank in range(self.vit_tp):
+                for vit_tp_rank in range(self.args.visual_tp):
                     task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_images))
                     tasks.append(task)

@@ -120,13 +101,13 @@ async def infer_imgs(self, images: List[ImageItem]):

     async def loop_for_fwd(self):
         while True:
-            if len(self.waiting_reqs_from_httpserver) == 0:
+            if len(self.waiting_reqs) == 0:
                 await asyncio.sleep(0.01)  # 10ms
             else:
                 processing_group_reqs = []
                 images_need_infer = []
-                while len(self.waiting_reqs_from_httpserver) > 0:
-                    group_req_indexes = self.waiting_reqs_from_httpserver.pop(0)
+                while len(self.waiting_reqs) > 0:
+                    group_req_indexes = self.waiting_reqs.pop(0)
                     shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0])
                     is_aborted = shm_req.is_aborted
                     self.shm_req_manager.put_back_req_obj(shm_req)
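The strided indexing in infer_imgs() amounts to a round-robin split of the batch across visual data-parallel ranks; a tiny illustration with generic names:

```python
# Round-robin assignment of images to dp ranks, as infer_imgs() does with
# images[rank::dp_size]-style strided indexing.
def split_round_robin(images, dp_size):
    return [[images[i] for i in range(rank, len(images), dp_size)] for rank in range(dp_size)]

print(split_round_robin(["a", "b", "c", "d", "e"], 2))
# [['a', 'c', 'e'], ['b', 'd']] -- rank 0 gets the even indices, rank 1 the odd ones
```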
@@ -167,16 +148,31 @@ async def loop_for_fwd(self):
                     processing_group_reqs = []
                     images_need_infer = []

+    def _recv_reqs(self):
+        if self.remote_vit:
+            recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
+            for img in recv_req.multimodal_params.images:
+                data = img._preload_data
+                img._preload_data = None
+                md5sum = hashlib.md5(data).hexdigest()
+                uid = int(md5sum, 16)
+                # create_shm(get_shm_name_data(uid), data)
+                self.cache_client.root.set_items_data([uid])
+
+            return recv_req
+        else:
+            return self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
+
     async def loop_for_netio_req(self):
         if not hasattr(self, "visual_recv_max_count"):
             self.visual_recv_max_count = 64

         while True:
             try:
                 for _ in range(self.visual_recv_max_count):
-                    recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
+                    recv_req: GroupReqIndexes = self._recv_reqs()
                     if isinstance(recv_req, GroupReqIndexes):
-                        self.waiting_reqs_from_httpserver.append(recv_req)
+                        self.waiting_reqs.append(recv_req)
                     else:
                         assert False, f"Error Req Inf {recv_req}"
                 self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256)
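In the remote-ViT path, _recv_reqs() derives a cache key from the raw image bytes: the 128-bit md5 digest, read as an integer, becomes the uid registered with the cache. A minimal sketch of that derivation (the image bytes below are fabricated):

```python
# The uid is the md5 digest of the image bytes interpreted as a 128-bit int,
# so identical images map to the same key across processes. Bytes are fake.
import hashlib

def image_uid(data: bytes) -> int:
    return int(hashlib.md5(data).hexdigest(), 16)

uid = image_uid(b"\x89PNG\r\n fake image bytes")
print(uid)  # deterministic, so a Redis-backed cache can deduplicate by uid
```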
@@ -211,103 +207,6 @@ async def loop_for_fwd_visual_only(self):
         # release this image here, ref - 1
         logger.info(f"req-id {visual_req.group_req_id} has been release ok")

-    async def _initialize_multimodal_metadata(
-        self, multimodal_params: MultimodalParams, sampling_params: SamplingParams
-    ):
-        for img in multimodal_params.images:
-            self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params)
-            data = img.read()
-            # must after init_imageitem_extral_params
-            token_num = self.tokenizer.get_image_token_length(img)
-            md5sum = "{}_{}".format(
-                hashlib.md5(data).hexdigest(),
-                hashlib.md5(pickle.dumps(img.extra_params, protocol=4)).hexdigest(),
-            )
-            img.uuid = int(md5sum, 16)
-            img.token_num = token_num
-
-    async def _log_req_header(self, request_headers, group_request_id: int, image_count: int):
-
-        x_request_id = request_headers.get("X-Request-Id", "")
-        x_session_id = request_headers.get("X-Session-Id", "")
-
-        format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S")
-        logger.info(
-            f"recieved req X-Request-Id:{x_request_id} "
-            f"X-Session-Id:{x_session_id} start_time:{format_in_time} "
-            f"lightllm_req_id:{group_request_id} "
-            f"image_count:{image_count}"
-        )
-        return
-
-    def alloc_req_id(self, sampling_params, is_health_req: bool = False):
-        # The request id can be supplied externally or generated internally, but an external
-        # caller must itself guarantee global uniqueness, otherwise errors will occur.
-        # For now NORMAL mode always substitutes an internal id; P and D modes set it as needed.
-        # health requests carry a negative request_id and return directly
-        if is_health_req:
-            return sampling_params.group_request_id
-        group_request_id = self.id_gen.generate_id()
-
-        sampling_params.group_request_id = group_request_id
-        return group_request_id
-
-    # async def generate(
-    #     self,
-    #     sampling_params: SamplingParams,
-    #     multimodal_params: MultimodalParams,
-    #     request: Request,
-    #     is_health_req: bool = False,
-    # ) -> Tuple[int, str, dict, FinishStatus]:
-
-    #     request_headers = request.headers if request is not None else {}
-    #     group_request_id = self.alloc_req_id(sampling_params, is_health_req)
-
-    #     try:
-    #         await multimodal_params.verify_and_preload(request)
-    #         image_count = len(multimodal_params.images)
-    #         # record information about the arriving request
-    #         await self._log_req_header(request_headers, group_request_id, image_count)
-    #         assert (
-    #             len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity
-    #         ), "too many multimodal items!"
-
-    #         await self._initialize_multimodal_metadata(multimodal_params, sampling_params)
-
-    #         visual_req_status = VisualOnlyReqIndexes(group_req_id=group_request_id,
-    #                                                 multimodal_params=multimodal_params)
-    #         self.waiting_reqs_visual_only.append(visual_req_status)
-
-    #     except Exception as e:
-    #         logger.error(f"group_request_id: {group_request_id} has exception {str(e)}")
-    #         await self.abort(group_request_id, multimodal_params)
-    #         raise e
-    #     return
-
-    async def abort(self, group_req_id: int, multimodal_params: MultimodalParams):
-        logger.warning(f"aborted group_request_id {group_req_id}")
-        for img in multimodal_params.images:
-            img.is_abort = True
-        return
-
-    async def loop_for_netio_req(self):
-        if not hasattr(self, "visual_recv_max_count"):
-            self.visual_recv_max_count = 64
-
-        while True:
-            try:
-                for _ in range(self.visual_recv_max_count):
-                    recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
-                    print(f"recv_req is {recv_req}")
-                    if isinstance(recv_req, GroupReqIndexes):
-                        self.waiting_reqs_from_httpserver.append(recv_req)
-                    else:
-                        assert False, f"Error Req Inf {recv_req}"
-                self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256)
-            except zmq.ZMQError:
-                # once the queue has started to drain, scale the per-iteration receive count back down
-                self.visual_recv_max_count = 64
-                await asyncio.sleep(0.01)
-
     def clean_up(self):
         for model_rpc in self.model_rpcs:
             model_rpc.rpc_server_process.kill()
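Both the kept and the deleted netio loops use the same adaptive receive budget: grow it 1.3x per fully drained iteration, cap it at 256, and fall back to 64 once the non-blocking recv raises zmq.ZMQError (queue empty). A small standalone rendering of that policy:

```python
# Standalone rendering of the adaptive receive budget; the real loop calls
# recv_pyobj(zmq.NOBLOCK) and catches zmq.ZMQError instead of taking a flag.
def next_budget(current: float, queue_drained: bool) -> float:
    if queue_drained:  # non-blocking recv raised: the queue is empty
        return 64
    return min(current * 1.3, 256)

budget = 64.0
for _ in range(6):  # sustained load lets the budget climb toward the cap
    budget = next_budget(budget, queue_drained=False)
print(budget)  # 256.0, since 64 * 1.3**6 ≈ 309 exceeds the cap
```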
