Skip to content

Commit 8eda8ab

Browse files
committed
dp balancer abstract
1 parent 1c16247 commit 8eda8ab

File tree

2 files changed

+11
-25
lines changed

2 files changed

+11
-25
lines changed

lightllm/server/api_cli.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
112112
help="tool call parser type",
113113
)
114114
parser.add_argument(
115-
"--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time"
115+
"--running_max_req_size", type=int, default=2048, help="the max size for forward requests in the same time"
116116
)
117117
parser.add_argument("--nnodes", type=int, default=1, help="the number of nodes")
118118
parser.add_argument("--node_rank", type=int, default=0, help="the rank of the current node")
@@ -137,6 +137,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
137137
using the deepseekv2 model, set dp to be equal to the tp parameter. In other cases, please
138138
do not set it and keep the default value as 1.""",
139139
)
140+
parser.add_argument(
141+
"--dp_balancer",
142+
type=str,
143+
default="round_robin",
144+
help="the dp balancer type, default is round_robin",
145+
)
140146
parser.add_argument(
141147
"--max_req_total_len", type=int, default=16384, help="the max value for req_input_len + req_output_len"
142148
)

lightllm/server/router/req_queue/dp_base_queue.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import List
33
from ..batch import Batch, Req
44
from lightllm.server.router.req_queue.base_queue import BaseQueue
5+
from lightllm.server.router.req_queue.dp_balancer import get_dp_balancer
56
from lightllm.common.basemodel.infer_lock import g_router_lock
67
from lightllm.utils.log_utils import init_logger
78

@@ -12,14 +13,13 @@ class DpQueue:
1213
def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None:
1314
self.dp_size_in_node = dp_size_in_node
1415
self.base_queue_class = base_queue_class
15-
self.pre_select_dp_index = self.dp_size_in_node - 1
1616
from lightllm.server.router.manager import RouterManager
1717

1818
self.router: RouterManager = router
1919
self.inner_queues: List[BaseQueue] = [
2020
base_queue_class(args, router, dp_index, dp_size_in_node) for dp_index in range(self.dp_size_in_node)
2121
]
22-
22+
self.dp_balancer = get_dp_balancer(args, dp_size_in_node, self.inner_queues)
2323
return
2424

2525
def get_dp_queue(self, dp_index: int):
@@ -49,8 +49,7 @@ def append(self, req: Req):
4949
suggested_dp_index = req.sample_params.suggested_dp_index
5050
if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
5151
logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
52-
suggested_dp_index = self._get_suggest_dp_index()
53-
self.pre_select_dp_index = suggested_dp_index
52+
suggested_dp_index = self.dp_balancer.get_suggest_dp_index()
5453
req.sample_params.suggested_dp_index = suggested_dp_index
5554
self.inner_queues[suggested_dp_index].append(req)
5655
else:
@@ -59,12 +58,11 @@ def append(self, req: Req):
5958

6059
def extend(self, req_group: List[Req]):
6160
# 同一个组的,要分配在同一个 dp 上,效率最高
62-
index = self._get_suggest_dp_index()
61+
index = self.dp_balancer.get_suggest_dp_index()
6362
for req in req_group:
6463
suggested_dp_index = req.sample_params.suggested_dp_index
6564
if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
6665
logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
67-
self.pre_select_dp_index = index
6866
req.sample_params.suggested_dp_index = index
6967
self.inner_queues[index].append(req)
7068
else:
@@ -87,21 +85,3 @@ def update_token_load(self, current_batch: Batch, force_update=False):
8785
self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, dp_index)
8886
self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, dp_index)
8987
return
90-
91-
def _get_suggest_dp_index(self):
92-
min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues)
93-
select_dp_indexes = [
94-
i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length
95-
]
96-
97-
# multi thread safe keep
98-
if not select_dp_indexes:
99-
return random.randint(0, self.dp_size_in_node - 1)
100-
101-
# round_robin select.
102-
for i in range(self.dp_size_in_node):
103-
next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node
104-
if next_dp_index in select_dp_indexes:
105-
return next_dp_index
106-
107-
return random.choice(select_dp_indexes)

0 commit comments

Comments
 (0)