Skip to content

Commit 27ef8f3

Browse files
committed
add vit manager for vit-llm disaggregation
1 parent d4de040 commit 27ef8f3

File tree

7 files changed

+168
-135
lines changed

7 files changed

+168
-135
lines changed

lightllm/server/api_cli.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,17 @@ def make_argument_parser() -> argparse.ArgumentParser:
506506
default=0.03,
507507
help="""The interval of the schedule time, default is 30ms.""",
508508
)
509+
parser.add_argument(
510+
"--enable_remote_vit",
511+
action="store_true",
512+
help="Whether to enable remote vit for multimodal service.",
513+
)
514+
parser.add_argument(
515+
"--remote_vit_port",
516+
type=int,
517+
default=12346,
518+
help="The port number for the remote vit service.",
519+
)
509520
# redis for vit llm disaggregation
510521
parser.add_argument(
511522
"--redis_port",

lightllm/server/api_server.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,5 @@
1313
config_server_start(args)
1414
elif args.run_mode == "visual_only":
1515
visual_only_start(args)
16-
elif args.run_mode == "llm_only":
17-
llm_only_start(args)
1816
else:
1917
normal_or_p_d_start(args)

lightllm/server/api_start.py

Lines changed: 4 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import subprocess
66
import signal
77
from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker
8-
from lightllm.utils.start_utils import process_manager, kill_recursive
8+
from lightllm.utils.start_utils import process_manager, kill_recursive, is_multimodal_mode
99
from .metrics.manager import start_metric_manager
1010
from .embed_cache.manager import start_cache_manager
1111
from lightllm.utils.log_utils import init_logger
@@ -157,11 +157,13 @@ def check_and_set_args(args):
157157
assert args.mtp_draft_model_dir is None
158158
assert args.mtp_step == 0
159159

160+
args.enable_multimodal = is_multimodal_mode(args)
160161
# visual_only模式下才需要设置visual_embed_path
161162
if args.visual_embed_path is not None:
162163
assert (
163164
args.run_mode == "visual_only" or args.run_mode == "llm_only"
164165
), "only visual_only or llm_only mode need visual_embed_path"
166+
165167
# 检查GPU数量是否足够
166168
if args.visual_gpu_ids is None:
167169
args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp))
@@ -174,13 +176,11 @@ def check_and_set_args(args):
174176
args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus]
175177

176178
# 检查visual_nccl_port数量是否足够
177-
if len(args.visual_nccl_ports) < args.visual_dp:
179+
if args.visual_nccl_ports is not None and len(args.visual_nccl_ports) < args.visual_dp:
178180
raise ValueError(
179181
f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, "
180182
f"but got ({len(args.visual_nccl_ports)})."
181183
)
182-
else:
183-
args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
184184

185185
if args.visual_dp <= 0:
186186
raise ValueError("visual_dp must be a positive integer.")
@@ -287,7 +287,6 @@ def normal_or_p_d_start(args):
287287
logger.info(f"all start args:{args}")
288288

289289
ports_locker.release_port()
290-
291290
if args.enable_multimodal:
292291
from .visualserver.manager import start_visual_process
293292

@@ -381,105 +380,6 @@ def normal_or_p_d_start(args):
381380
return
382381

383382

384-
def llm_only_start(args):
385-
386-
check_and_set_args(args)
387-
already_uesd_ports = [args.nccl_port, args.port]
388-
389-
# 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能
390-
# 捕获到端口设置冲突的问题
391-
ports_locker = PortLocker(already_uesd_ports)
392-
ports_locker.lock_port()
393-
394-
node_world_size = args.tp // args.nnodes
395-
can_use_ports = alloc_can_use_network_port(num=4 + node_world_size, used_nccl_ports=already_uesd_ports)
396-
logger.info(f"alloced ports: {can_use_ports}")
397-
(
398-
router_port,
399-
detokenization_port,
400-
detokenization_pub_port,
401-
metric_port,
402-
) = can_use_ports[0:4]
403-
can_use_ports = can_use_ports[4:]
404-
405-
# 将申请好的端口放入args参数中
406-
args.router_port = router_port
407-
args.detokenization_port = detokenization_port
408-
args.detokenization_pub_port = detokenization_pub_port
409-
args.metric_port = metric_port
410-
411-
# 申请在 p d 分离模式下,会用的端口
412-
args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size]
413-
# p d 分离模式下用于标识节点的id
414-
args.pd_node_id = uuid.uuid4().int
415-
# p 节点用来建立torch kv 传输分布组的可用端口范围
416-
args.pd_p_allowed_port_min = 20000
417-
args.pd_p_allowed_port_max = 30000
418-
419-
# p d 分离模式下,decode节点的调度间隙是0
420-
if args.run_mode == "decode":
421-
args.router_max_wait_tokens = 0
422-
423-
send_and_receive_node_ip(args) # 多机用于收发node ip
424-
set_env_start_args(args)
425-
logger.info(f"all start args:{args}")
426-
427-
ports_locker.release_port()
428-
429-
process_manager.start_submodule_processes(
430-
start_funcs=[
431-
start_metric_manager,
432-
],
433-
start_args=[(metric_port, args)],
434-
)
435-
436-
process_manager.start_submodule_processes(
437-
start_funcs=[start_router_process, start_detokenization_process],
438-
start_args=[
439-
(args, router_port, detokenization_port, metric_port),
440-
(args, detokenization_port, detokenization_pub_port),
441-
],
442-
)
443-
444-
# 启动 gunicorn
445-
command = [
446-
"gunicorn",
447-
"--workers",
448-
f"{args.httpserver_workers}",
449-
"--worker-class",
450-
"uvicorn.workers.UvicornWorker",
451-
"--bind",
452-
f"{args.host}:{args.port}",
453-
"--log-level",
454-
"info",
455-
"--access-logfile",
456-
"-",
457-
"--error-logfile",
458-
"-",
459-
"lightllm.server.api_http:app",
460-
"--timeout",
461-
f"{get_lightllm_gunicorn_time_out_seconds()}",
462-
"--keep-alive",
463-
f"{get_lightllm_gunicorn_keep_alive()}",
464-
]
465-
466-
# 启动子进程
467-
http_server_process = subprocess.Popen(command)
468-
469-
if "s3://" in args.model_dir:
470-
from lightllm.utils.petrel_helper import s3_model_clear
471-
472-
s3_model_clear(args.model_dir)
473-
474-
if args.health_monitor:
475-
from lightllm.server.health_monitor.manager import start_health_check_process
476-
477-
process_manager.start_submodule_processes(start_funcs=[start_health_check_process], start_args=[(args,)])
478-
setup_signal_handlers(http_server_process, process_manager)
479-
http_server_process.wait()
480-
return
481-
482-
483383
def pd_master_start(args):
484384
set_unique_server_name(args)
485385
if args.run_mode != "pd_master":

lightllm/server/core/objs/io_objs/group_req.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,6 @@
44
from ..req import Req
55

66

7-
@dataclass
8-
class VisualOnlyReqIndexes:
9-
group_req_id: int
10-
multimodal_params: MultimodalParams
11-
12-
137
@dataclass
148
class GroupReqIndexes:
159
group_req_id: int

lightllm/server/httpserver/manager.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,12 @@ def __init__(
8181
)
8282

8383
self.enable_multimodal = enable_multimodal
84-
if self.enable_multimodal and self.args.run_mode != "llm_only":
84+
if self.enable_multimodal:
8585
self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
86-
self.send_to_visual = context.socket(zmq.PUSH)
87-
self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}")
86+
# 初始化VIT连接管理器
87+
from .vit_loop import VITConnectionManager
88+
89+
self.vit_manager = VITConnectionManager(args, context, visual_port)
8890

8991
self.token_id_range_start = 100000000
9092
self.token_id_range_end = 2 ** 63 - 1
@@ -406,10 +408,7 @@ async def _encode(
406408
), "too many multimodal items!"
407409
if multimodal_params.audios:
408410
assert self.args.enable_multimodal_audio, "audio multimodal not enabled"
409-
if self.args.run_mode == "llm_only":
410-
await self._get_image_embedding_from_afs(multimodal_params, sampling_params)
411-
else:
412-
await self._alloc_multimodal_resources(multimodal_params, sampling_params)
411+
await self._alloc_multimodal_resources(multimodal_params, sampling_params)
413412
prompt_ids = self.tokenizer.encode(
414413
prompt, multimodal_params, add_special_tokens=sampling_params.add_special_tokens
415414
)
@@ -483,9 +482,9 @@ async def transfer_to_next_module(
483482
group_req_objs: Optional[GroupReqObjs] = None,
484483
):
485484

486-
if self.pd_mode == NodeRole.P:
487-
if self.enable_multimodal and self.args.run_mode != "llm_only":
488-
self.send_to_visual.send_pyobj(
485+
if self.pd_mode.is_P_or_NORMAL():
486+
if self.enable_multimodal:
487+
await self.vit_manager.send_to_vit(
489488
group_req_objs.to_group_req_index(),
490489
protocol=pickle.HIGHEST_PROTOCOL,
491490
)
@@ -504,19 +503,6 @@ async def transfer_to_next_module(
504503
)
505504
return
506505

507-
if self.pd_mode == NodeRole.NORMAL or self.pd_mode == NodeRole.LLM_ONLY:
508-
if self.enable_multimodal and self.args.run_mode != "llm_only":
509-
self.send_to_visual.send_pyobj(
510-
group_req_objs.to_group_req_index(),
511-
protocol=pickle.HIGHEST_PROTOCOL,
512-
)
513-
else:
514-
self.send_to_router.send_pyobj(
515-
group_req_objs.to_group_req_index(),
516-
protocol=pickle.HIGHEST_PROTOCOL,
517-
)
518-
return
519-
520506
assert False, "dead code path"
521507
return
522508

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import asyncio
2+
import zmq
3+
import zmq.asyncio
4+
import time
5+
import pickle
6+
from typing import Dict, List, Optional, Any
7+
from lightllm.utils.log_utils import init_logger
8+
import httpx
9+
import base64
10+
from dataclasses import dataclass
11+
12+
logger = init_logger(__name__)
13+
14+
15+
@dataclass
16+
class VIT_Obj:
17+
node_id: int
18+
host_ip_port: str
19+
20+
def to_log_str(self):
21+
return f"VIT host_ip_port: {self.host_ip_port} node_id: {self.node_id}"
22+
23+
24+
class VITConnectionManager:
    """Manage the connection(s) used to push requests to VIT instances.

    Two modes are supported, selected by ``args.enable_remote_vit``:

    * local mode: a single zmq PUSH socket connected to the visual server
      listening on ``local_visual_port`` on 127.0.0.1;
    * remote mode: one zmq PUSH socket per remote VIT node discovered through
      the config server, selected round-robin for every request.
    """

    def __init__(self, args, context, local_visual_port: int):
        """Store the configuration and set up the VIT connection(s).

        Args:
            args: start-up arguments; ``enable_remote_vit``, ``remote_vit_port``,
                ``zmq_mode``, ``config_server_host`` and ``config_server_port``
                are read by this class.
            context: zmq context used to create the PUSH sockets.
            local_visual_port: port of the local visual server (local mode).
        """
        self.args = args
        self.context = context
        self.local_visual_port = local_visual_port

        self.send_to_visual = None
        # node id -> zmq PUSH socket. This must be a dict: the discovery loop
        # adds and removes entries by node id.
        self.remote_vit_instances = {}
        self.current_vit_index = 0
        self.remote_vit = args.enable_remote_vit
        self.remote_vit_port = args.remote_vit_port
        # Strong reference to the discovery task so it is not garbage-collected.
        self._vit_loop_task = None

        self._setup_vit_connections()

    def _setup_vit_connections(self):
        """Set up either the remote discovery loop or the local socket."""
        if self.remote_vit:
            self._setup_remote_vit_connections()
        else:
            self._setup_local_vit_connection()

    def _setup_local_vit_connection(self):
        """Create the PUSH socket connected to the local visual server."""
        self.send_to_visual = self.context.socket(zmq.PUSH)
        self.send_to_visual.connect(f"{self.args.zmq_mode}127.0.0.1:{self.local_visual_port}")
        logger.info(f"Connected to local VIT instance at {self.args.zmq_mode}127.0.0.1:{self.local_visual_port}")

    def _setup_remote_vit_connections(self):
        """Start the background discovery loop for remote VIT instances.

        This must not block: the discovery task runs on the same event loop as
        the caller, so a blocking wait here would prevent the task from ever
        running. ``send_to_vit`` waits (asynchronously) for the first instance
        to appear instead.
        """
        self._ensure_vit_loop_started()

    def _ensure_vit_loop_started(self):
        """Start ``vit_handle_loop`` once, if an event loop is running."""
        if self._vit_loop_task is not None and not self._vit_loop_task.done():
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop yet (e.g. constructed from synchronous code);
            # the task is started lazily by the first send_to_vit call.
            return
        self._vit_loop_task = loop.create_task(self.vit_handle_loop())

    def _get_vit_instance(self):
        """Return the socket to use for the next request.

        In local mode this is always the local socket. In remote mode the
        known remote sockets are used in round-robin order.

        Raises:
            RuntimeError: remote mode is enabled but no remote instance is
                currently connected.
        """
        if not self.remote_vit:
            return self.send_to_visual

        instances = list(self.remote_vit_instances.values())
        if not instances:
            raise RuntimeError("no remote VIT instance is available")

        # Simple round-robin load balancing.
        index = (self.current_vit_index + 1) % len(instances)
        self.current_vit_index = index
        return instances[index]

    def _describe_instance(self, instance) -> str:
        """Return a short label for ``instance`` for use in log messages."""
        if not self.remote_vit:
            return f"local:{self.local_visual_port}"
        for node_id, sock in self.remote_vit_instances.items():
            if sock is instance:
                return f"remote node_id={node_id}"
        return "remote (unknown node)"

    async def send_to_vit(self, data, protocol=pickle.HIGHEST_PROTOCOL):
        """Send ``data`` to a VIT instance (local or remote).

        In remote mode this waits until at least one remote instance has been
        discovered.

        Raises:
            Exception: the underlying ``send_pyobj`` call failed; the original
                error is chained as the cause.
        """
        if self.remote_vit:
            self._ensure_vit_loop_started()
            if not self.remote_vit_instances:
                logger.warning("no remote VIT instance available yet, waiting for discovery")
            while not self.remote_vit_instances:
                await asyncio.sleep(1)

        instance = self._get_vit_instance()
        try:
            # NOTE(review): the result of send_pyobj is not awaited, matching
            # how the other zmq sends in this code base are issued - confirm
            # this is intended if the context is a zmq.asyncio context.
            instance.send_pyobj(data, protocol=protocol)
        except Exception as e:
            target = self._describe_instance(instance)
            logger.error(f"Failed to send to VIT instance {target}: {e}")
            raise Exception(f"Failed to send to VIT instance {target}: {e}") from e

    async def vit_handle_loop(self):
        """Periodically synchronise the remote sockets with the config server.

        Sockets of nodes that are no longer registered are closed and removed;
        sockets for newly registered nodes are created and connected. The loop
        runs forever, refreshing every 30 seconds and backing off 10 seconds
        after a failure.
        """
        while True:
            try:
                id_to_vit_obj = await self._get_vit_objs()
                if id_to_vit_obj is None:
                    # Discovery failed: keep the current connections and retry.
                    await asyncio.sleep(10)
                    continue
                logger.info(f"get vit_objs {id_to_vit_obj}")

                # Iterate over a snapshot of the keys: entries are removed
                # from the dict inside the loop.
                for node_id in list(self.remote_vit_instances):
                    if node_id in id_to_vit_obj:
                        continue
                    stale_socket = self.remote_vit_instances.pop(node_id)
                    try:
                        stale_socket.close()
                    except Exception as e:
                        # Best effort: a failing close must not stop the sync.
                        logger.warning(f"close remote vit {node_id} failed: {e}")
                    logger.info(f"remote vit {node_id} closed")

                for node_id, vit_obj in id_to_vit_obj.items():
                    if node_id not in self.remote_vit_instances:
                        new_socket = self.context.socket(zmq.PUSH)
                        # NOTE(review): if host_ip_port already contains a port,
                        # this address ends up with two - confirm the format.
                        new_socket.connect(f"tcp://{vit_obj.host_ip_port}:{self.args.remote_vit_port}")
                        self.remote_vit_instances[node_id] = new_socket
                await asyncio.sleep(30)
            except Exception as e:
                logger.exception(str(e))
                await asyncio.sleep(10)

    async def _get_vit_objs(self) -> "Optional[Dict[int, VIT_Obj]]":
        """Fetch all registered remote VIT nodes from the config server.

        Returns:
            A mapping of node id to ``VIT_Obj``, or ``None`` when the request
            fails or returns a non-200 status.
        """
        # Use the config_server service to discover all registered VIT nodes.
        uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_vit"

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(uri)
                if response.status_code == 200:
                    base64data = response.json()["data"]
                    # SECURITY NOTE(review): pickle.loads runs arbitrary code
                    # if the payload is attacker-controlled; the config server
                    # must be a trusted endpoint.
                    id_to_vit_obj = pickle.loads(base64.b64decode(base64data))
                    return id_to_vit_obj
                else:
                    logger.error(f"get vit_objs error {response.status_code}")
                    return None
        except Exception as e:
            logger.exception(str(e))
            await asyncio.sleep(10)
            return None

0 commit comments

Comments
 (0)