Skip to content

Commit c99bb46

Browse files
committed
fix vit manager
1 parent a566580 commit c99bb46

File tree

9 files changed

+39
-54
lines changed

9 files changed

+39
-54
lines changed

lightllm/server/api_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
353353
"--visual_nccl_ports",
354354
nargs="+",
355355
type=int,
356-
default=[29500],
356+
default=None,
357357
help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502",
358358
)
359359
parser.add_argument(

lightllm/server/api_http.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -95,21 +95,6 @@ def set_args(self, args):
9595
)
9696
elif args.run_mode == "visual":
9797
self.metric_client = MetricClient(args.metric_port)
98-
elif args.run_mode == "llm_only":
99-
init_tokenizer(args) # for openai api
100-
SamplingParams.load_generation_cfg(args.model_dir)
101-
self.metric_client = MetricClient(args.metric_port)
102-
self.httpserver_manager = HttpServerManager(
103-
args,
104-
router_port=args.router_port,
105-
cache_port=None,
106-
detokenization_pub_port=args.detokenization_pub_port,
107-
visual_port=None,
108-
enable_multimodal=args.enable_multimodal,
109-
metric_port=args.metric_port,
110-
)
111-
dp_size_in_node = max(1, args.dp // args.nnodes) # 兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容
112-
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node)
11398
else:
11499
init_tokenizer(args) # for openai api
115100
SamplingParams.load_generation_cfg(args.model_dir)
@@ -365,9 +350,10 @@ async def shutdown():
365350
@app.on_event("startup")
366351
async def startup_event():
367352
logger.info("server start up")
353+
if g_objs.httpserver_manager is None:
354+
return
368355
loop = asyncio.get_event_loop()
369356
g_objs.set_args(get_env_start_args())
370-
if g_objs.args.run_mode != "visual":
371-
loop.create_task(g_objs.httpserver_manager.handle_loop())
357+
loop.create_task(g_objs.httpserver_manager.handle_loop())
372358
logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}")
373359
return

lightllm/server/api_start.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,9 @@ def check_and_set_args(args):
208208
def normal_or_p_d_start(args):
209209

210210
check_and_set_args(args)
211-
already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port]
211+
already_uesd_ports = [args.nccl_port, args.port]
212212
if args.run_mode == "decode":
213-
already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port, args.pd_decode_rpyc_port]
213+
already_uesd_ports = [args.nccl_port, args.port, args.pd_decode_rpyc_port]
214214

215215
# 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能
216216
# 捕获到端口设置冲突的问题
@@ -219,7 +219,7 @@ def normal_or_p_d_start(args):
219219

220220
node_world_size = args.tp // args.nnodes
221221
can_use_ports = alloc_can_use_network_port(
222-
num=7 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
222+
num=7 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp, used_nccl_ports=already_uesd_ports
223223
)
224224
logger.info(f"alloced ports: {can_use_ports}")
225225
(
@@ -239,6 +239,9 @@ def normal_or_p_d_start(args):
239239
can_use_ports = can_use_ports[args.visual_tp :]
240240
visual_model_tp_ports.append(tp_ports_for_dp)
241241

242+
args.visual_nccl_ports = can_use_ports[0 : args.visual_dp]
243+
can_use_ports = can_use_ports[args.visual_dp :]
244+
242245
# 将申请好的端口放入args参数中
243246
args.router_port = router_port
244247
args.detokenization_port = detokenization_port
@@ -436,7 +439,6 @@ def visual_start(args):
436439
metric_port,
437440
) = can_use_ports[0:5]
438441
can_use_ports = can_use_ports[5:]
439-
print(cache_port)
440442

441443
visual_model_tp_ports = []
442444
for _ in range(args.visual_dp):

lightllm/server/config_server/api_http.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from typing import Dict, List
99
from fastapi.responses import JSONResponse
1010
from lightllm.utils.log_utils import init_logger
11-
from ..pd_io_struct import PD_Master_Obj, Visual_Server_Obj
11+
from lightllm.server.visualserver.vit_connect import VIT_Obj
12+
from ..pd_io_struct import PD_Master_Obj
1213
from .nccl_tcp_store import start_tcp_store_server
1314
from lightllm.utils.envs_utils import get_env_start_args
1415
from lightllm.utils.process_check import start_parent_check_thread
@@ -18,7 +19,7 @@
1819
app = FastAPI()
1920

2021
registered_pd_master_objs: Dict[str, PD_Master_Obj] = {}
21-
registered_visual_server_obj: Dict[str, Visual_Server_Obj] = {}
22+
registered_visual_server_objs: Dict[str, VIT_Obj] = {}
2223
registered_pd_master_obj_lock = Lock()
2324
registered_visual_server_obj_lock = Lock()
2425

@@ -73,15 +74,15 @@ async def websocket_endpoint(websocket: WebSocket):
7374
return
7475

7576

76-
@app.websocket("/visual_server_register")
77+
@app.websocket("/visual_register")
7778
async def visual_websocket_endpoint(websocket: WebSocket):
7879
await websocket.accept()
7980
client_ip, client_port = websocket.client
8081
logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}")
81-
registered_visual_server_obj: Visual_Server_Obj = pickle.loads(await websocket.receive_bytes())
82+
registered_visual_server_obj: VIT_Obj = pickle.loads(await websocket.receive_bytes())
8283
logger.info(f"recieved registered_visual_server_obj {registered_visual_server_obj}")
8384
with registered_visual_server_obj_lock:
84-
registered_visual_server_obj_lock[registered_visual_server_obj.node_id] = registered_visual_server_obj
85+
registered_visual_server_objs[registered_visual_server_obj.node_id] = registered_visual_server_obj
8586

8687
try:
8788
while True:
@@ -93,7 +94,7 @@ async def visual_websocket_endpoint(websocket: WebSocket):
9394
finally:
9495
logger.error(f"registered_visual_server_obj {registered_visual_server_obj} removed")
9596
with registered_visual_server_obj_lock:
96-
registered_visual_server_obj.pop(registered_visual_server_obj.node_id, None)
97+
registered_visual_server_objs.pop(registered_visual_server_obj.node_id, None)
9798
return
9899

99100

@@ -105,10 +106,10 @@ async def get_registered_objects():
105106
return {"data": base64_encoded}
106107

107108

108-
@app.get("/registered_visual_server_objects")
109+
@app.get("/registered_visual_objects")
109110
async def get_vit_registered_objects():
110111
with registered_visual_server_obj_lock:
111-
serialized_data = pickle.dumps(registered_visual_server_obj)
112+
serialized_data = pickle.dumps(registered_visual_server_objs)
112113
base64_encoded = base64.b64encode(serialized_data).decode("utf-8")
113114
return {"data": base64_encoded}
114115

lightllm/server/httpserver/manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,9 @@ async def handle_loop(self):
697697

698698
asyncio.create_task(pd_handle_loop(self))
699699

700+
if self.enable_multimodal:
701+
asyncio.create_task(self.vit_manager.vit_handle_loop())
702+
700703
while True:
701704
try:
702705
await asyncio.wait_for(self.recv_from_detokenization.recv_pyobj(), timeout=0.05)

lightllm/server/pd_io_struct.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,6 @@ def to_log_str(self):
7373
return f"PD_MASTER host_ip_port: {self.host_ip_port} node_id: {self.node_id}"
7474

7575

76-
@dataclass
77-
class Visual_Server_Obj:
78-
node_id: int
79-
host_ip_port: str
80-
81-
def to_log_str(self):
82-
return f"Visual_Server host_ip_port: {self.host_ip_port} node_id: {self.node_id}"
83-
84-
8576
@dataclass
8677
class UpKVStatus:
8778
type: str = "kv_move_status"

lightllm/server/visualserver/manager.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ def __init__(
5353
def _setup_connections(self):
5454
context = zmq.Context(2)
5555
if self.remote_vit:
56-
self.recv_from_remote_llm = context.socket(zmq.PULL)
57-
self.recv_from_remote_llm.bind(f"tcp://*:{self.args.remote_vit_port}")
56+
self.vit_receiver = context.socket(zmq.PULL)
57+
self.vit_receiver.bind(f"tcp://*:{self.args.remote_vit_port}")
5858
else:
59-
self.recv_from_httpserver = context.socket(zmq.PULL)
60-
self.recv_from_httpserver.bind(f"{self.args.zmq_mode}127.0.0.1:{self.visual_port}")
59+
self.vit_receiver = context.socket(zmq.PULL)
60+
self.vit_receiver.bind(f"{self.args.zmq_mode}127.0.0.1:{self.visual_port}")
6161
self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio)
6262
self.send_to_next_module.connect(f"{self.args.zmq_mode}127.0.0.1:{self.next_module_port}")
6363
self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
@@ -153,7 +153,7 @@ async def loop_for_fwd(self):
153153

154154
def _recv_reqs(self):
155155
if self.remote_vit:
156-
recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
156+
recv_req: GroupReqIndexes = self.vit_receiver.recv_pyobj(zmq.NOBLOCK)
157157
for img in recv_req.multimodal_params.images:
158158
image_patch = self.tokenizer.get_image_patch_func(img)
159159
data = img._preload_data
@@ -164,7 +164,7 @@ def _recv_reqs(self):
164164
self.cache_client.root.set_items_data([md5])
165165
return recv_req
166166
else:
167-
return self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
167+
return self.vit_receiver.recv_pyobj(zmq.NOBLOCK)
168168

169169
async def loop_for_netio_req(self):
170170
if not hasattr(self, "visual_recv_max_count"):
@@ -173,7 +173,7 @@ async def loop_for_netio_req(self):
173173
while True:
174174
try:
175175
for _ in range(self.visual_recv_max_count):
176-
recv_req: GroupReqIndexes = self._recv_reqs()
176+
recv_req: GroupReqIndexes = self.vit_receiver.recv_pyobj(zmq.NOBLOCK)
177177
if isinstance(recv_req, GroupReqIndexes):
178178
self.waiting_reqs.append(recv_req)
179179
else:
@@ -182,6 +182,9 @@ async def loop_for_netio_req(self):
182182
except zmq.ZMQError:
183183
# 当队列已经开始清空的时候,将一次接受数量下调
184184
self.visual_recv_max_count = 64
185+
except Exception as e:
186+
logger.exception(f"Error in loop_for_netio_req: {e}")
187+
raise e
185188
await asyncio.sleep(0.01)
186189

187190
# code for visual only mode
@@ -249,9 +252,6 @@ def handle_exception(loop, context):
249252
loop = asyncio.new_event_loop()
250253
loop.set_exception_handler(handle_exception)
251254
asyncio.set_event_loop(loop)
252-
if args.run_mode == "visual":
253-
loop.create_task(visualserver.loop_for_fwd_visual_only())
254-
else:
255-
loop.create_task(visualserver.loop_for_fwd())
255+
create_forward_loop(args, visualserver, loop)
256256
loop.run_until_complete(visualserver.loop_for_netio_req())
257257
return

lightllm/server/visualserver/register_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ async def register_loop(args):
2020
while True:
2121

2222
try:
23-
uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_server_register"
23+
uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_register"
2424
async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket:
2525

2626
sock = websocket.transport.get_extra_info("socket")
@@ -33,7 +33,7 @@ async def register_loop(args):
3333

3434
while True:
3535
await websocket.send("heartbeat")
36-
await asyncio.sleep(60)
36+
await asyncio.sleep(40)
3737

3838
except Exception as e:
3939
logger.error("connetion to config_server has error")

lightllm/server/visualserver/vit_connect.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def _setup_local_vit_connection(self):
5757
logger.info(f"Connected to local VIT instance at {self.args.zmq_mode}127.0.0.1:{self.local_visual_port}")
5858

5959
def _setup_remote_vit_connections(self):
60+
print("_setup_remote_vit_connections", "fdakpgdakgjadpgkjadk")
6061
asyncio.create_task(self.vit_handle_loop())
6162

6263
# wait for remote vit instances
@@ -89,6 +90,7 @@ async def send_to_vit(self, data, protocol=pickle.HIGHEST_PROTOCOL):
8990
raise Exception(f"Failed to send to VIT instance {instance.host_ip_port}: {e}")
9091

9192
async def vit_handle_loop(self):
93+
print("vit_handle_loop", "fdakpgdakgjadpgkjadk")
9294
while True:
9395
try:
9496
id_to_vit_obj = await self._get_vit_objs()
@@ -118,8 +120,8 @@ async def _get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]:
118120
get_vit_objs 主要负责从 config_server 获取所有的vit远程服务。
119121
"""
120122
# 使用 config_server 服务来发现所有的 pd_master 节点。
121-
uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_vit"
122-
123+
uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects"
124+
print("uri", uri)
123125
try:
124126
async with httpx.AsyncClient() as client:
125127
response = await client.get(uri)

0 commit comments

Comments
 (0)