Skip to content

Commit 3fd6e48

Browse files
committed
fix multinode abort
1 parent b3de424 commit 3fd6e48

File tree

1 file changed

+40
-21
lines changed

1 file changed

+40
-21
lines changed

lightllm/server/httpserver/manager.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,12 @@ def __init__(
5050
context = zmq.asyncio.Context(2)
5151
self.send_to_router = context.socket(zmq.PUSH)
5252
self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}")
53-
53+
5454
self.multinode_req_manager = None
5555
self.child_node_events = {}
5656
self.waiting_objs = []
5757
self.child_node_lock = asyncio.Lock()
58+
self.nnodes = args.nnodes
5859
if args.nnodes > 1:
5960
if args.node_rank == 0:
6061
self.multinode_req_manager = []
@@ -64,12 +65,16 @@ def __init__(
6465
context = zmq.asyncio.Context(2)
6566
self.multinode_req_manager.append(context.socket(zmq.PUSH))
6667
self.multinode_req_manager[-1].connect(f"tcp://{child_ip}:{args.multinode_httpmanager_port}")
67-
logger.info(f"HttpServerManager connected to child node at {child_ip}:{args.multinode_httpmanager_port}")
68+
logger.info(
69+
f"HttpServerManager connected to child node at {child_ip}:{args.multinode_httpmanager_port}"
70+
)
6871
else:
6972
context = zmq.asyncio.Context(2)
7073
self.multinode_req_manager = context.socket(zmq.PULL)
7174
self.multinode_req_manager.bind(f"tcp://*:{args.multinode_httpmanager_port}")
72-
logger.info(f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}")
75+
logger.info(
76+
f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}"
77+
)
7378

7479
self.enable_multimodal = enable_multimodal
7580
if self.enable_multimodal:
@@ -145,16 +150,26 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam
145150
def tokens(self, prompt):
146151
prompt_ids = self.tokenizer.encode(prompt)
147152
return len(prompt_ids)
148-
153+
149154
async def loop_for_request(self):
150155
assert self.args.node_rank > 0
151156
tasks = []
152157
while True:
153-
request_id, prompt, sampling_params, multimodal_params, request_headers = await self.multinode_req_manager.recv_pyobj()
154-
results_generator = self.generate(prompt, sampling_params, multimodal_params, None, request_headers, request_id)
158+
(
159+
request_id,
160+
prompt,
161+
sampling_params,
162+
multimodal_params,
163+
request_headers,
164+
) = await self.multinode_req_manager.recv_pyobj()
165+
results_generator = self.generate(
166+
prompt, sampling_params, multimodal_params, None, request_headers, request_id
167+
)
168+
155169
async def generate_wrapper(results_generator):
156170
async for _, _, _, _ in results_generator:
157171
pass
172+
158173
tasks.append(asyncio.create_task(generate_wrapper(results_generator)))
159174
# cleanup
160175
while len(tasks) > 0 and tasks[0].done():
@@ -166,7 +181,7 @@ async def generate(
166181
sampling_params: SamplingParams,
167182
multimodal_params: MultimodalParams,
168183
request: Request,
169-
request_headers = None,
184+
request_headers=None,
170185
multinode_remote_request_id: Optional[int] = None,
171186
) -> Tuple[int, str, dict, FinishStatus]:
172187
start_time = time.time()
@@ -178,7 +193,10 @@ async def generate(
178193
if multinode_remote_request_id is None:
179194
group_request_id = self.id_gen.generate_id()
180195
for sender in self.multinode_req_manager:
181-
sender.send_pyobj((group_request_id, prompt, sampling_params, multimodal_params, request_headers), protocol=pickle.HIGHEST_PROTOCOL)
196+
sender.send_pyobj(
197+
(group_request_id, prompt, sampling_params, multimodal_params, request_headers),
198+
protocol=pickle.HIGHEST_PROTOCOL,
199+
)
182200
else:
183201
group_request_id = multinode_remote_request_id
184202
sampling_params.group_request_id = group_request_id
@@ -238,8 +256,12 @@ async def generate(
238256
await self.transfer_to_next_module(req_status.group_req_objs)
239257

240258
results_generator = self._wait_to_token_package(
241-
start_time, prompt_ids, group_request_id, sampling_params, req_status, # request,
242-
request_headers,
259+
start_time,
260+
prompt_ids,
261+
group_request_id,
262+
sampling_params,
263+
req_status,
264+
request,
243265
)
244266
async for sub_req_id, request_output, metadata, finish_status in results_generator:
245267
# p d 模式下,将 token 数据放入到转发队列中
@@ -368,8 +390,7 @@ async def _wait_to_token_package(
368390
group_request_id: int,
369391
sampling_params: SamplingParams,
370392
req_status: "ReqStatus",
371-
request_headers,
372-
# request: Request,
393+
request: Request,
373394
):
374395

375396
event = req_status.event
@@ -385,10 +406,9 @@ async def _wait_to_token_package(
385406
except asyncio.TimeoutError:
386407
pass
387408

388-
# TODO: abort() for multinode
389-
# if request is not None and await request.is_disconnected():
390-
# await self.abort(group_request_id)
391-
# raise Exception(f"req_id {group_request_id} disconnected")
409+
if request is not None and await request.is_disconnected() and self.nnodes == 1:
410+
await self.abort(group_request_id)
411+
raise Exception(f"req_id {group_request_id} disconnected")
392412

393413
async with req_status.lock:
394414
event.clear()
@@ -416,13 +436,12 @@ async def _wait_to_token_package(
416436
unfinished_count -= 1
417437

418438
# 所有子请求完成后,就删除占用的资源
419-
if unfinished_count == 0:
439+
if unfinished_count == 0 and request is not None:
420440
total_cost_time_ms = (time.time() - start_time) * 1000
421441
mean_per_token_cost_time_ms = (total_cost_time_ms - first_token_cost_ms) / out_token_counter
422442
self.per_token_costs.add(mean_per_token_cost_time_ms)
423-
x_request_id = request_headers.get("X-Request-Id", "") if request_headers is not None else ""
424-
x_session_id = request_headers.get("X-Session-Id", "") if request_headers is not None else ""
425-
443+
x_request_id = request.headers.get("X-Request-Id", "") if request is not None else ""
444+
x_session_id = request.headers.get("X-Session-Id", "") if request is not None else ""
426445
prompt_cache_ratio = prompt_cache_len / prompt_tokens
427446
self.metric_client.histogram_observe("lightllm_cache_length", prompt_cache_len)
428447
self.metric_client.histogram_observe("lightllm_cache_ratio", prompt_cache_ratio)
@@ -506,7 +525,7 @@ async def handle_loop(self):
506525
if self.pd_mode.is_P_or_D():
507526
self.forwarding_queue = AsyncQueue()
508527
asyncio.create_task(self.pd_handle_loop())
509-
528+
510529
if self.args.node_rank > 0:
511530
asyncio.create_task(self.loop_for_request())
512531

0 commit comments

Comments
 (0)