
Commit 2c06cf4

Author: Weichao Luo (committed)

Commit message: support dp for pd_nixl
1 parent 3838dc8 commit 2c06cf4

File tree: 10 files changed, +352 -61 lines changed


lightllm/server/pd_io_struct.py

Lines changed: 8 additions & 0 deletions
@@ -127,6 +127,14 @@ class RemotePrefillServerInfo:
     prefill_server_ip: str
     prefill_server_port: int
 
+@dataclass
+class DistInfo:
+    world_size: int
+    nnodes: int
+    dp_size: int
+    dp_world_size: int
+    dp_size_in_node: int
+    node_world_size: int
 
 @dataclass
 class PDTransLeaveInfo:
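
Reading aid, not part of the commit: a minimal sketch of how the new DistInfo fields relate to each other, following the formulas used in lightllm/server/router/manager.py below. The deployment numbers are hypothetical.

from lightllm.server.pd_io_struct import DistInfo

# Hypothetical deployment: 2 nodes, 8 GPUs per node, data-parallel size 4.
nnodes = 2
node_world_size = 8
world_size = nnodes * node_world_size        # 16 ranks in total
dp_size = 4
dp_world_size = world_size // dp_size        # 4 ranks per dp group (same formula as manager.py)
dp_size_in_node = max(1, dp_size // nnodes)  # 2 dp groups hosted on each node

dist_info = DistInfo(world_size, nnodes, dp_size, dp_world_size, dp_size_in_node, node_world_size)
print(dist_info)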

lightllm/server/router/manager.py

Lines changed: 15 additions & 6 deletions
@@ -34,6 +34,7 @@
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
 from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.pd_io_struct import DistInfo
 
 logger = init_logger(__name__)
 
@@ -49,6 +50,7 @@ def __init__(self, args, router_port, detokenization_port, metric_port):
         self.dp_size = args.dp
         # Compatibility for the multi-node pure-TP mode, where 1 // 2 == 0 must be handled
         self.dp_size_in_node = max(1, args.dp // self.nnodes)
+        self.dp_world_size = self.world_size // self.dp_size
         self.is_multinode_tp = args.nnodes > 1 and args.dp == 1
         self.is_multinode_and_multidp = args.nnodes > 1 and args.dp > 1
         # Decide whether conservative scheduling is used; it never pauses reqs, but may hurt throughput in some scenarios
@@ -116,9 +118,9 @@ async def wait_to_model_ready(self):
         # Used to exchange task info between the kv move manager process and the inference processes.
         self.info_queue: mp.Queue = mp.Queue()
         self.mem_queues: List[torch.multiprocessing.Queue] = [
-            torch.multiprocessing.Queue() for _ in range(self.world_size)
+            torch.multiprocessing.Queue() for _ in range(self.node_world_size)
         ]
-        self.result_queues: List[mp.Queue] = [mp.Queue() for _ in range(self.world_size)]
+        self.result_queues: List[mp.Queue] = [mp.Queue() for _ in range(self.node_world_size)]
         self.rpc_event = multiprocessing.Event()
         self.rpc_finished_event = multiprocessing.Event()
 
@@ -134,8 +136,8 @@ async def wait_to_model_ready(self):
                 rpc_event=self.rpc_event,
                 rpc_finished_event=self.rpc_finished_event,
                 info_queue=self.info_queue,
-                result_queue=self.result_queues[rank_id],
-                mem_queue=self.mem_queues[rank_id],
+                result_queue=self.result_queues[rank_id % node_world_size],
+                mem_queue=self.mem_queues[rank_id % node_world_size],
                 router_lock=self.router_lock,
             )
             self.model_rpc_servers.append(rpc_model)
@@ -190,7 +192,7 @@ async def wait_to_model_ready(self):
             get_unique_server_name(),
             self.max_total_token_num,
             node_world_size=self.node_world_size,
-            dp_world_size=self.world_size // self.dp_size,
+            dp_world_size=self.dp_world_size,
         )
         self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node)
         logger.info(f"use req queue {self.req_queue.__class__.__name__}")
@@ -208,8 +210,12 @@ async def wait_to_model_ready(self):
                 start_pd_remote_prefill_server_process,
             )
 
+            dist_info = DistInfo(self.world_size, self.nnodes, self.dp_size,
+                                 self.dp_world_size, self.dp_size_in_node, self.node_world_size)
+
             start_pd_remote_prefill_server_process(
                 self.args.pd_node_id,
+                dist_info = dist_info,
                 http_server_port=self.args.pd_remote_prefill_http_port,
                 server_port=self.args.pd_remote_prefill_port,
                 from_backend_queue=self.info_queue,
@@ -229,9 +235,12 @@ async def wait_to_model_ready(self):
             from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import (
                 start_pd_remote_prefill_client_process,
             )
+            dist_info = DistInfo(self.world_size, self.nnodes, self.dp_size,
+                                 self.dp_world_size, self.dp_size_in_node, self.node_world_size)
 
             start_pd_remote_prefill_client_process(
                 self.args.pd_node_id,
+                dist_info,
                 from_backend_queue=self.info_queue,
                 to_backend_queues=self.result_queues,
                 agent_meta_queues=self.mem_queues,
@@ -246,7 +255,7 @@ def add_req(self, group_req_indexes: GroupReqIndexes):
             req.multimodal_params = group_req_indexes.multimodal_params
             req.start_time = group_req_indexes.time_mark
             if isinstance(req, PDChunkedPrefillReq):
-                req.dp_world_size = self.world_size
+                req.dp_world_size = self.dp_world_size
             req_group.append(req)
 
         logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s")

lightllm/server/router/model_infer/mode_backend/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,3 +14,5 @@
 from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode
 from .pd_nixl.impl_for_pd_prefill import PDNIXLBackendForPrefillNode
 from .pd_nixl.impl_for_pd_decode import PDNIXLBackendForDecodeNode
+from .pd_nixl.impl_for_pd_decode_dp import PDNIXLDPBackendForDecodeNode
+from .pd_nixl.impl_for_pd_prefill_dp import PDNIXLDPBackendForPrefillNode
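
Usage note, not part of the commit: with the two added re-exports, the DP-aware PD-NIXL backends can be imported directly from the mode_backend package alongside the existing non-DP ones.

from lightllm.server.router.model_infer.mode_backend import (
    PDNIXLBackendForPrefillNode,
    PDNIXLBackendForDecodeNode,
    PDNIXLDPBackendForPrefillNode,
    PDNIXLDPBackendForDecodeNode,
)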

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py

Lines changed: 3 additions & 3 deletions
@@ -41,7 +41,7 @@ def __init__(self, to_remote_queue: mp.Queue, from_remote_queue: mp.Queue, nixl_
         self.inflght_transfer_requests: ThreadSafeDict = ThreadSafeDict()
 
     def init_custom(self):
-        self.nixl_agent = NixlKVTransporter(self.args.pd_node_id, self.tp_rank)
+        self.nixl_agent = NixlKVTransporter(self.args.pd_node_id, self.rank_in_node)
         self.nixl_agent.register_kv_buffer(self.model.mem_manager.kv_buffer)
         self.nixl_meta_queue.put(
             (self.nixl_agent.agent_metadata, self.nixl_agent.num_tokens, self.nixl_agent.local_mem_desc)
@@ -243,11 +243,11 @@ def _prepare_remote_prefill_inputs(self, req_objs: List[InferReq]):
             nopad_b_start_loc.append(start_loc)  # last request
 
         input_ids = np.concatenate(input_ids, dtype=np.int64)
-        # g_infer_state_lock.acquire() # I don't think it's needed
+
         if g_infer_context.radix_cache is not None:
             g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0])
         mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0])
-        # g_infer_state_lock.release()
+
         kwargs = {
             "batch_size": len(run_reqs),
             "input_ids": input_ids,

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py

Lines changed: 0 additions & 1 deletion
@@ -81,7 +81,6 @@ def decode(self):
                 self.remote_prefilled_reqs[shm_req.group_req_id] = run_req
 
         if decode_reqs:
-            # print(f"decode req: {self.rank_in_dp}: {len(decode_reqs)}")
             kwargs, run_reqs = prepare_decode_inputs(decode_reqs)
             logits = self.model.forward(**kwargs)

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py (new file)

Lines changed: 121 additions & 0 deletions

@@ -0,0 +1,121 @@
import time
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from typing import List
from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq
from lightllm.server.core.objs.req import PDChunkedPrefillReq
from lightllm.utils.log_utils import init_logger
from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_decode_inputs

from .impl_for_pd_decode import PDNIXLBackendForDecodeNode

logger = init_logger(__name__)


class PDNIXLDPBackendForDecodeNode(PDNIXLBackendForDecodeNode):
    def __init__(self, prefill_task_queue: mp.Queue, prefill_done_queue: mp.Queue, nix_meta_queue: mp.Queue) -> None:
        super().__init__(prefill_task_queue, prefill_done_queue, nix_meta_queue)
        self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap

    def init_custom(self):
        super().init_custom()

        self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
        from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_prefill_inputs
        kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal)
        self.model.forward(**kwargs)
        assert len(run_reqs) == 0 and padded_req_num == 1

        return

    def decode(self):

        uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
            g_infer_context.infer_req_ids,
            no_decode=False,
        )
        # filter out remote prefilling reqs
        prefill_reqs, aborted_reqs, decode_reqs, _ = self._decode_filter_reqs(prefill_reqs, aborted_reqs, decode_reqs)

        self._filter_reqs(aborted_reqs)

        # allocate kv cache, do remote prefill
        if prefill_reqs:
            # TODO: we could allocate cache later after remote prefill done and get a signal from remote
            # but it will have a risk to not have enough cache for this request.
            kwargs, run_reqs = self._prepare_remote_prefill_inputs(prefill_reqs)
            for idx, run_req in enumerate(run_reqs):
                run_req: InferReq = run_req
                shm_req: PDChunkedPrefillReq = run_req.shm_req
                # forward each req to remote prefill
                # since the token index are the same across TPs, we only need to trigger prefill on master
                if self.is_master_in_dp:
                    run_req.remote_prefill_start = time.time()
                    self.to_remote_queue.put(self._build_remote_prefill_task(idx, kwargs, run_req))

                shm_req.set_pd_req_rank_state(self.rank_in_dp, 0)  # set in progress state
                run_req.in_prefill_or_transfer = True
                self.remote_prefilled_reqs[shm_req.group_req_id] = run_req

        self.reduce_tensor.fill_(len(decode_reqs))
        dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX)
        max_decode_num = self.reduce_tensor.item()
        if max_decode_num != 0:
            if not self.enable_decode_microbatch_overlap:
                self.normal_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
            else:
                self.overlap_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
        return

    def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs):

        kwargs, run_reqs, padded_req_num = padded_prepare_decode_inputs(
            decode_reqs, max_decode_num, is_multimodal=self.is_multimodal
        )
        logits = self.model.forward(**kwargs)
        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
        if len(run_reqs) != 0:
            logits = logits[0 : len(run_reqs), :]
            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
            next_token_ids = next_token_ids.detach().cpu().numpy()
            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
            self._post_handle(
                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
            )
        return

    def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs):
        from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import (
            padded_overlap_prepare_decode_inputs,
        )

        (
            micro_batch,
            run_reqs,
            padded_req_num,
            micro_batch1,
            run_reqs1,
            padded_req_num1,
        ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal)

        logits, logits1 = self.model.microbatch_overlap_decode(micro_batch, micro_batch1)
        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
        req_num, req_num1 = len(run_reqs), len(run_reqs1)
        all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device)

        all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True)
        all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True)

        all_run_reqs = run_reqs + run_reqs1
        if all_run_reqs:
            next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id)
            next_token_ids = next_token_ids.detach().cpu().numpy()
            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
            self._post_handle(
                all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
            )
        return
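
Not part of the commit: a condensed, standalone sketch of the MAX all-reduce batch-size sync used in decode() above. The real padding is done by lightllm's padded_prepare_* helpers; this stand-in only approximates the padded count.

import torch
import torch.distributed as dist

def synced_batch_size(local_batch: int, device: str = "cuda"):
    """Return (max_batch_across_ranks, padded_req_num) for this rank.

    Mirrors the reduce_tensor.fill_ / all_reduce(MAX) pattern in
    PDNIXLDPBackendForDecodeNode.decode: every dp rank must run the same
    number of forward steps, so ranks with fewer requests pad with dummies.
    """
    reduce_tensor = torch.tensor([local_batch], dtype=torch.int32, device=device)
    dist.all_reduce(reduce_tensor, op=dist.ReduceOp.MAX)
    max_batch = int(reduce_tensor.item())
    padded_req_num = max_batch - local_batch  # rough stand-in for the helper's padding
    return max_batch, padded_req_num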
lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py (new file)

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
import threading
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from typing import List, Tuple
from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
from lightllm.utils.log_utils import init_logger
from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_prefill_inputs
from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_prefill_inputs

from .impl_for_pd_base import PDNIXLBackendBase
from .impl_for_pd_prefill import PDNIXLBackendForPrefillNode

logger = init_logger(__name__)


class PDNIXLDPBackendForPrefillNode(PDNIXLBackendForPrefillNode):
    def __init__(self, transfer_task_queue: mp.Queue, transfer_done_queue: mp.Queue, nixl_meta_queue: mp.Queue) -> None:
        super().__init__(transfer_task_queue, transfer_done_queue, nixl_meta_queue)
        self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap

    def init_custom(self):
        super().init_custom()
        self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
        return

    def decode(self):
        uinit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
            g_infer_context.infer_req_ids,
            no_decode=True,
        )

        ok_finished_reqs, aborted_reqs, _ = self._prefill_filter_reqs(ok_finished_reqs, aborted_reqs)

        assert len(uinit_reqs) == 0
        assert len(decode_reqs) == 0

        self._prefill_abort_remote(aborted_reqs)
        self._filter_reqs(aborted_reqs)

        if ok_finished_reqs:
            for req in ok_finished_reqs:
                self._transfer_kv_to_remote(req)
            self._filter_reqs(ok_finished_reqs)
            ok_finished_reqs.clear()

        current_dp_prefill_num = len(prefill_reqs)
        self.reduce_tensor.fill_(current_dp_prefill_num)
        dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False)
        max_prefill_num = self.reduce_tensor.item()
        if max_prefill_num != 0:
            if not self.enable_prefill_microbatch_overlap:
                self.normal_prefill_reqs(prefill_reqs, max_prefill_num, uinit_reqs, ok_finished_reqs)
            else:
                self.overlap_prefill_reqs(prefill_reqs, max_prefill_num, uinit_reqs, ok_finished_reqs)

        self._overlap_req_init_and_filter(uninit_reqs=uinit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
        return

    def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs):

        kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs(
            prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal
        )
        logits = self.model.forward(**kwargs)
        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
        if len(run_reqs) != 0:
            logits = logits[0 : len(run_reqs), :]
            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
            next_token_ids = next_token_ids.detach().cpu().numpy()
            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
            self._post_handle(
                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False,
                extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req),
            )

    def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs):
        from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import (
            padded_overlap_prepare_prefill_inputs,
        )

        (
            micro_batch,
            run_reqs,
            padded_req_num,
            micro_batch1,
            run_reqs1,
            padded_req_num1,
        ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal)
        logits, logits1 = self.model.microbatch_overlap_prefill(micro_batch, micro_batch1)
        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
        req_num, req_num1 = len(run_reqs), len(run_reqs1)
        all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device)

        all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True)
        all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True)

        all_run_reqs = run_reqs + run_reqs1
        if all_run_reqs:
            next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id)
            next_token_ids = next_token_ids.detach().cpu().numpy()
            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
            self._post_handle(
                all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False,
                extra_post_req_handle_func=lambda req, _1, _2: self._transfer_kv_to_remote(req),
            )
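
Not part of the commit: both overlap paths merge the two micro-batch logits the same way before a single sampling pass. A self-contained sketch of that merge, with hypothetical shapes.

import torch

def merge_microbatch_logits(logits: torch.Tensor, logits1: torch.Tensor,
                            req_num: int, req_num1: int) -> torch.Tensor:
    """Concatenate the non-padded rows of two micro-batch outputs, as in
    overlap_decode / overlap_prefill_reqs, so sampling runs once over all reqs."""
    all_logits = torch.empty((req_num + req_num1, logits.shape[1]),
                             dtype=logits.dtype, device=logits.device)
    all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True)
    all_logits[req_num:req_num + req_num1, :].copy_(logits1[0:req_num1, :], non_blocking=True)
    return all_logits

# Example with CPU tensors: 3 real reqs in micro-batch 0, 2 in micro-batch 1 (rest is padding).
merged = merge_microbatch_logits(torch.randn(4, 32000), torch.randn(3, 32000), 3, 2)
assert merged.shape == (5, 32000)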
