
Commit 3a89cf0

add visual start

1 parent ded28b7 commit 3a89cf0

10 files changed: +55 -154 lines changed

lightllm/server/api_cli.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--run_mode",
         type=str,
-        choices=["normal", "prefill", "decode", "pd_master", "config_server", "visual_only"],
+        choices=["normal", "prefill", "decode", "pd_master", "config_server", "visual"],
         default="normal",
         help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode,
         config_server is for pd split mode used to register pd_master node, and get pd_master node list,
@@ -529,6 +529,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=6379,
         help="The port number for the redis service in config_server mode.",
     )
+    parser.add_argument(
+        "--redis_evict_fraction",
+        type=float,
+        default=0.3,
+        help="The evict fraction for the redis service in config_server mode.",
+    )
     parser.add_argument(
         "--start_redis",
        action="store_true",
```

lightllm/server/api_http.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -93,7 +93,7 @@ def set_args(self, args):
                 args,
                 metric_port=args.metric_port,
             )
-        elif args.run_mode == "visual_only":
+        elif args.run_mode == "visual":
             self.metric_client = MetricClient(args.metric_port)
         elif args.run_mode == "llm_only":
             init_tokenizer(args)  # for openai api
@@ -160,7 +160,7 @@ def get_model_name():
 @app.get("/health", summary="Check server health")
 @app.head("/health", summary="Check server health")
 async def healthcheck(request: Request):
-    if g_objs.args.run_mode in ["pd_master", "visual_only"]:
+    if g_objs.args.run_mode in ["pd_master", "visual"]:
         return JSONResponse({"message": "Ok"}, status_code=200)
 
     if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true":
@@ -367,7 +367,7 @@ async def startup_event():
     logger.info("server start up")
     loop = asyncio.get_event_loop()
     g_objs.set_args(get_env_start_args())
-    if g_objs.args.run_mode != "visual_only":
+    if g_objs.args.run_mode != "visual":
         loop.create_task(g_objs.httpserver_manager.handle_loop())
     logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}")
     return
```
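Since a `visual` node now short-circuits `/health` the same way a `pd_master` does (there is no httpserver `handle_loop` to exercise), an external liveness probe stays trivial. A minimal client-side check, assuming a reachable base URL (host and port here are deployment-specific placeholders):

```python
import requests  # any HTTP client works; requests is assumed for brevity

def node_is_healthy(base_url: str = "http://127.0.0.1:8000") -> bool:
    # pd_master and visual nodes answer /health with {"message": "Ok"} and HTTP 200.
    try:
        return requests.get(f"{base_url}/health", timeout=2).status_code == 200
    except requests.RequestException:
        return False
```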

lightllm/server/api_server.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -5,13 +5,13 @@
 torch.multiprocessing.set_start_method("spawn")  # this code will not be ok for settings to fork to subprocess
 parser = make_argument_parser()
 args = parser.parse_args()
-from .api_start import pd_master_start, normal_or_p_d_start, visual_only_start, config_server_start, llm_only_start
+from .api_start import pd_master_start, normal_or_p_d_start, visual_start, config_server_start
 
 if args.run_mode == "pd_master":
     pd_master_start(args)
 elif args.run_mode == "config_server":
     config_server_start(args)
-elif args.run_mode == "visual_only":
-    visual_only_start(args)
+elif args.run_mode == "visual":
+    visual_start(args)
 else:
     normal_or_p_d_start(args)
```

lightllm/server/api_start.py

Lines changed: 16 additions & 63 deletions

```diff
@@ -57,7 +57,8 @@ def signal_handler(sig, frame):
     signal.signal(signal.SIGINT, signal_handler)
 
     logger.info(f"start process pid {os.getpid()}")
-    logger.info(f"http server pid {http_server_process.pid}")
+    if http_server_process:
+        logger.info(f"http server pid {http_server_process.pid}")
     return
 
 
@@ -72,7 +73,7 @@ def check_and_set_args(args):
 
     enable_mps()
 
-    if args.run_mode not in ["normal", "prefill", "decode", "llm_only", "visual_only"]:
+    if args.run_mode not in ["normal", "prefill", "decode", "llm_only", "visual"]:
         return
 
     assert args.zmq_mode in ["tcp://", "ipc:///tmp/"]
@@ -420,11 +421,9 @@ def pd_master_start(args):
     http_server_process.wait()
 
 
-def visual_only_start(args):
+def visual_start(args):
     check_and_set_args(args)
-    if args.run_mode != "visual_only":
-        return
-    already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port]
+    already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.remote_vit_port]
     can_use_ports = alloc_can_use_network_port(
         num=5 + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
     )
@@ -437,6 +436,7 @@ def visual_only_start(args):
         metric_port,
     ) = can_use_ports[0:5]
     can_use_ports = can_use_ports[5:]
+    print(cache_port)
 
     visual_model_tp_ports = []
     for _ in range(args.visual_dp):
@@ -456,13 +456,6 @@ def visual_only_start(args):
 
     set_env_start_args(args)
 
-    process_manager.start_submodule_processes(
-        start_funcs=[
-            start_metric_manager,
-        ],
-        start_args=[(metric_port, args)],
-    )
-
     from .visualserver.manager import start_visual_process
 
     process_manager.start_submodule_processes(
@@ -476,58 +469,18 @@ def visual_only_start(args):
             start_visual_process,
         ],
         start_args=[
-            (args, audio_port, visual_port, cache_port, visual_model_tp_ports),
+            (args, router_port, visual_port, cache_port, visual_model_tp_ports),
        ],
     )
-    if args.enable_multimodal_audio:
-        from .audioserver.manager import start_audio_process
-
-        process_manager.start_submodule_processes(
-            start_funcs=[
-                start_audio_process,
-            ],
-            start_args=[
-                (args, router_port, audio_port, cache_port),
-            ],
-        )
-
-    # launch gunicorn
-    command = [
-        "gunicorn",
-        "--workers",
-        f"{args.httpserver_workers}",
-        "--worker-class",
-        "uvicorn.workers.UvicornWorker",
-        "--bind",
-        f"{args.host}:{args.port}",
-        "--log-level",
-        "info",
-        "--access-logfile",
-        "-",
-        "--error-logfile",
-        "-",
-        "lightllm.server.api_http:app",
-        "--timeout",
-        f"{get_lightllm_gunicorn_time_out_seconds()}",
-        "--keep-alive",
-        f"{get_lightllm_gunicorn_keep_alive()}",
-    ]
-
-    # launch the http server subprocess
-    http_server_process = subprocess.Popen(command)
-
-    if "s3://" in args.model_dir:
-        from lightllm.utils.petrel_helper import s3_model_clear
-
-        s3_model_clear(args.model_dir)
-
-    if args.health_monitor:
-        from lightllm.server.health_monitor.manager import start_health_check_process
-
-        process_manager.start_submodule_processes(start_funcs=[start_health_check_process], start_args=[(args,)])
-    setup_signal_handlers(http_server_process, process_manager)
-    http_server_process.wait()
-    return
+    setup_signal_handlers(None, process_manager)
+    try:
+        while True:
+            time.sleep(1)
+    except KeyboardInterrupt:
+        logger.info("Received keyboard interrupt, shutting down...")
+        process_manager.terminate_all_processes()
+        logger.info("All processes have been terminated gracefully.")
+        sys.exit(0)
 
 
 def config_server_start(args):
```
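With gunicorn gone from this path, `visual_start` no longer fronts an HTTP server; the parent only supervises submodule processes and idles until interrupted. A self-contained sketch of that pattern, using stdlib `multiprocessing` as a stand-in for lightllm's `process_manager`:

```python
import sys
import time
import multiprocessing
from typing import Callable, List

def supervise(start_funcs: List[Callable[[], None]]) -> None:
    # Spawn each submodule as a child process, then idle in the parent.
    workers = [multiprocessing.Process(target=fn) for fn in start_funcs]
    for w in workers:
        w.start()
    try:
        while True:
            time.sleep(1)  # nothing to serve here; just keep the parent alive
    except KeyboardInterrupt:
        # Mirror terminate_all_processes(): stop children, then exit cleanly.
        for w in workers:
            w.terminate()
            w.join()
        sys.exit(0)
```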

lightllm/server/core/objs/req.py

Lines changed: 0 additions & 24 deletions

```diff
@@ -161,30 +161,6 @@ def init(
 
         self.post_init()
 
-    def init_visual_only(
-        self,
-        request_id: int,
-    ):
-        # only here to give coding assistants better type hints
-        self.index_in_shm_mem: int = self.index_in_shm_mem
-        self.ref_count: int = self.ref_count
-
-        self.request_id = request_id
-        self.group_req_id = convert_sub_id_to_group_id(request_id)
-        self.is_paused = False
-        self.finish_status = FinishStatus()
-        self.is_aborted = False
-        self.router_aborted = False
-        self.shm_infer_released = False
-        self.shm_cur_kv_len = 0
-        self.shm_cur_output_len = 0
-        self.candetoken_out_len = 0
-        self.prompt_cache_len = 0
-        self.finish_token_index = -1
-        self.can_released_mark = False
-
-        self.post_init()
-
     def post_init(self):
         # subclasses can override this to do extra initialization
         pass
```

lightllm/server/embed_cache/impl/memory_cache_with_redis.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -21,7 +21,7 @@ def __init__(self, args) -> None:
         self.redis_cache = EmbedRefCountRedis(
             redis_url=redis_url,
             capacity=args.cache_capacity,
-            evict_fraction=args.evict_fraction,
+            evict_fraction=args.redis_evict_fraction,
             image_embed_dir=args.image_embed_dir,
         )
         # The cache is sized at cache * 2 because, in disaggregated mode, the cache service only updates redis state and maintains the token_ids of the image cache
```

lightllm/server/embed_cache/manager.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -55,7 +55,7 @@ def exposed_get_items_embed(self, ids: list[int]) -> list[bool]:
 
 
 def get_cache_manager(args):
-    if args.enable_remote_vit:
+    if args.enable_remote_vit or args.run_mode == "visual":
         return MemoryCacheWithRedis(args)
     else:
         return InMemoryCache(args)
```

lightllm/server/httpserver/manager.py

Lines changed: 2 additions & 37 deletions

```diff
@@ -13,7 +13,7 @@
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 from typing import Union, List, Tuple, Dict, Optional
-from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes, VisualOnlyReqIndexes
+from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
 from fastapi import Request
 from ..tokenizer import get_tokenizer
 from ..pd_io_struct import NodeRole
@@ -143,7 +143,7 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas):
             uid_list.append(rec["id"])
 
         # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server
-        if self.args.run_mode == "llm_only":
+        if self.enable_remote_vit:
             return
 
         ready_flags = obtain(self.cache_client.root.get_items_data(uid_list))
@@ -304,41 +304,6 @@ async def _initialize_multimodal_metadata(
             img.uuid = int(md5sum, 16)
             img.token_num = token_num
 
-    # async def get_image_embeding(
-    #     self,
-    #     sampling_params: SamplingParams,
-    #     multimodal_params: MultimodalParams,
-    #     request: Request,
-    #     is_health_req: bool = False,
-    # ) -> Tuple[int, str, dict, FinishStatus]:
-
-    #     request_headers = request.headers if request is not None else {}
-    #     group_request_id = self.alloc_req_id(sampling_params, is_health_req)
-
-    #     try:
-    #         await multimodal_params.verify_and_preload(request)
-    #         image_count = len(multimodal_params.images)
-    #         # log details about the arriving request
-    #         await self._log_req_header_for_visual_only(request_headers, group_request_id, image_count)
-    #         assert (
-    #             len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity
-    #         ), "too many multimodal items!"
-
-    #         await self._initialize_multimodal_metadata(multimodal_params, sampling_params)
-
-    #         visual_req_status = VisualOnlyReqIndexes(group_req_id=group_request_id, multimodal_params=multimodal_params)
-
-    #         self.send_to_visual.send_pyobj(
-    #             visual_req_status,
-    #             protocol=pickle.HIGHEST_PROTOCOL,
-    #         )
-
-    #     except Exception as e:
-    #         logger.error(f"group_request_id: {group_request_id} has exception {str(e)}")
-    #         await self.abort(group_request_id, multimodal_params)
-    #         raise e
-    #     return
-
     async def generate(
         self,
         prompt: Union[str, List[int]],
```

lightllm/server/visualserver/manager.py

Lines changed: 14 additions & 11 deletions

```diff
@@ -10,7 +10,7 @@
 import inspect
 from fastapi import Request
 from ..tokenizer import get_tokenizer
-from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes, VisualOnlyReqIndexes
+from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
 from lightllm.server.core.objs import ShmReqManager
 from lightllm.server.core.objs import SamplingParams
 from lightllm.server.core.objs import Req, FinishStatus
@@ -41,9 +41,8 @@ def __init__(
         visual_model_rpc_ports,
     ):
         self.args = args
-        self.remote_vit = args.enable_remote_vit
+        self.remote_vit = args.enable_remote_vit or args.run_mode == "visual"
         self.cache_port = cache_port
-        self.memory_cache = MemoryCacheWithRedis(args)
         self.waiting_reqs: List[GroupReqIndexes] = []
         self.infer_batch_size = args.visual_infer_batch_size
         self.trust_remote_code = args.trust_remote_code
@@ -53,29 +52,33 @@ def __init__(
     def _setup_connections(self):
         context = zmq.Context(2)
         if self.remote_vit:
-            self.recv_from_httpserver.bind(f"tcp://*:{self.args.remote_vit_port}")
+            self.recv_from_remote_llm = context.socket(zmq.PULL)
+            self.recv_from_remote_llm.bind(f"tcp://*:{self.args.remote_vit_port}")
         else:
+            self.recv_from_httpserver = context.socket(zmq.PULL)
             self.recv_from_httpserver.bind(f"{self.args.zmq_mode}127.0.0.1:{self.visual_port}")
         self.send_to_next_module = context.socket(zmq.PUSH)  # router or audio server (if --enable_multimodal_audio)
         self.send_to_next_module.connect(f"{self.args.zmq_mode}127.0.0.1:{self.next_module_port}")
         self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
 
     async def wait_to_model_ready(self):
         # TODO: read the config_server to start multiple vits
-        self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)]
+        visual_dp = self.args.visual_dp
+        visual_tp = self.args.visual_tp
+        self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(visual_dp)]
 
-        for dp_rank_id in range(self.args.visual_dp):
+        for dp_rank_id in range(visual_dp):
             tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id]
-            for tp_rank_id in range(self.args.visual_tp):
-                device_id = self.args.visual_gpu_ids[dp_rank_id * self.args.visual_tp + tp_rank_id]
+            for tp_rank_id in range(visual_tp):
+                device_id = self.args.visual_gpu_ids[dp_rank_id * visual_tp + tp_rank_id]
                 rpc_model = await start_model_process(
-                    port=tp_ports_each_dp[tp_rank_id], vit_tp=self.args.visual_tp, device_id=device_id
+                    port=tp_ports_each_dp[tp_rank_id], vit_tp=visual_tp, device_id=device_id
                 )
                 self.model_rpcs[dp_rank_id].append(rpc_model)
 
         init_model_ret = []
-        for dp_rank_id in range(self.args.visual_dp):  # async init model process
-            for tp_rank_id in range(self.args.visual_tp):
+        for dp_rank_id in range(visual_dp):  # async init model process
+            for tp_rank_id in range(visual_tp):
                 kvargs = {
                     "tp_rank_id": tp_rank_id,
                     "dp_rank_id": dp_rank_id,
```

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 8 additions & 10 deletions

```diff
@@ -40,20 +40,18 @@ def exposed_init_model(self, kvargs):
 
         self.args = get_env_start_args()
 
-        weight_dir = (self.args.model_dir,)
-        cache_port = (self.args.cache_port,)
-        data_type = (self.args.data_type,)
-        quant_type = (self.args.vit_quant_type,)
-        quant_cfg = (self.args.vit_quant_cfg,)
-        max_batch_size = (min(self.args.visual_infer_batch_size // self.args.visual_dp, 1),)
+        weight_dir = self.args.model_dir
+        cache_port = self.args.cache_port
+        data_type = self.args.data_type
+        quant_type = self.args.vit_quant_type
+        quant_cfg = self.args.vit_quant_cfg
+        max_batch_size = min(self.args.visual_infer_batch_size // self.args.visual_dp, 1)
 
         self.dp_rank_id = kvargs["dp_rank_id"]
         self.tp_rank_id = kvargs["tp_rank_id"]
         kvargs["vit_rank_id"] = self.dp_rank_id * self.args.visual_tp + self.tp_rank_id
-
-        if self.args.run_mode != "visual_only":
-            self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
-        self.visual_only = True if self.args.run_mode == "visual_only" else False
+        print(cache_port)
+        self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
 
         init_vision_distributed_env(kvargs)
         model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
```
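The cleanup above also fixes a classic Python pitfall: a trailing comma on the right-hand side of an assignment silently builds a 1-tuple, so `weight_dir`, `cache_port`, and the other locals were tuples rather than the scalars downstream code expects. A short illustration:

```python
weight_dir = ("/models/vit",)  # trailing comma: a 1-element tuple, not a string
assert isinstance(weight_dir, tuple)

weight_dir = "/models/vit"     # without the comma: a plain string, as intended
assert isinstance(weight_dir, str)
```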
