
Commit 86df27c

add dp balancer for dp
1 parent 8eda8ab commit 86df27c

5 files changed: +169 −18 lines


lightllm/server/router/batch.py

Lines changed: 9 additions & 0 deletions
@@ -40,6 +40,15 @@ def get_req_list_for_dp(self, dp_index: int):
             req_list.append(req)
         return req_list
 
+    def get_all_dp_req_num(self) -> List[int]:
+        if self.dp_size_in_node == 1:
+            return [len(self.reqs)]
+
+        all_dp_req_num = [0 for _ in range(self.dp_size_in_node)]
+        for req in self.reqs:
+            all_dp_req_num[req.sample_params.suggested_dp_index] += 1
+        return all_dp_req_num
+
     def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
         unfinished_req_ids = []
         for req in self.reqs:
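
For reference, a minimal standalone sketch (not part of the commit) of the tallying that get_all_dp_req_num performs; _FakeParams and _FakeReq are hypothetical stand-ins for the real Req objects in lightllm.server.router.batch.

# Hypothetical stand-ins; only the suggested_dp_index field is modeled.
from dataclasses import dataclass
from typing import List

@dataclass
class _FakeParams:
    suggested_dp_index: int

@dataclass
class _FakeReq:
    sample_params: _FakeParams

def get_all_dp_req_num(reqs: List[_FakeReq], dp_size_in_node: int) -> List[int]:
    # one counter per dp rank; each request is charged to its assigned rank
    if dp_size_in_node == 1:
        return [len(reqs)]
    all_dp_req_num = [0 for _ in range(dp_size_in_node)]
    for req in reqs:
        all_dp_req_num[req.sample_params.suggested_dp_index] += 1
    return all_dp_req_num

reqs = [_FakeReq(_FakeParams(i)) for i in (0, 0, 1, 3)]
print(get_all_dp_req_num(reqs, dp_size_in_node=4))  # [2, 1, 0, 1]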
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+from .dp_base_balancer import RoundRobinDpBalancer
+from typing import List
+from lightllm.server.router.req_queue.base_queue import BaseQueue
+from .dp_balancer_for_pd import DpBalancerForPd
+
+
+def get_dp_balancer(args, dp_size_in_node: int, inner_queues: List[BaseQueue]):
+    if args.run_mode in ["prefill", "decode"]:
+        return DpBalancerForPd(dp_size_in_node, inner_queues)
+    if args.dp_balancer == "round_robin":
+        return RoundRobinDpBalancer(dp_size_in_node, inner_queues)
+    else:
+        raise ValueError(f"Invalid dp balancer: {args.dp_balancer}")
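
A hedged usage sketch of the factory above: the attribute names run_mode and dp_balancer come from the diff, while the SimpleNamespace objects are only stand-ins for lightllm's parsed server arguments.

# Illustrative only; real callers pass the parsed server args and the per-dp queues.
from types import SimpleNamespace

normal_args = SimpleNamespace(run_mode="normal", dp_balancer="round_robin")
# get_dp_balancer(normal_args, 4, queues) -> RoundRobinDpBalancer

pd_args = SimpleNamespace(run_mode="prefill", dp_balancer="round_robin")
# get_dp_balancer(pd_args, 4, queues) -> DpBalancerForPd, which balances per-rank load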
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+from typing import List, Union
+from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.batch import Batch, Req
+from lightllm.utils.log_utils import init_logger
+from .dp_base_balancer import DpBalancer
+
+logger = init_logger(__name__)
+
+
+class DpBalancerForPd(DpBalancer):
+    """
+    This balancer mainly balances the batch size across dp ranks.
+    In dp mode, if a dp rank has no request to run, a padding request is
+    inserted for it, which wastes GPU compute resources.
+    """
+
+    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
+        super().__init__(dp_size_in_node, inner_queues)
+
+    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None:
+        if len(reqs_waiting_for_dp_index) == 0:
+            return
+        # calculate the total load of each dp rank
+        if current_batch is not None:
+            all_dp_req_num = current_batch.get_all_dp_req_num()
+            total_load_per_dp = [
+                all_dp_req_num[i] + len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node)
+            ]
+        else:
+            total_load_per_dp = [len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node)]
+        for req_group in reqs_waiting_for_dp_index:
+            # calculate the length of this request group
+            if isinstance(req_group, list):
+                req_length = len(req_group)
+            else:
+                req_length = 1
+
+            # find the dp ranks with the minimum load
+            min_load = min(total_load_per_dp)
+            select_dp_indexes = [i for i in range(self.dp_size_in_node) if total_load_per_dp[i] == min_load]
+
+            # select the dp rank with the minimum load
+            if len(select_dp_indexes) == 1:
+                suggested_dp_index = select_dp_indexes[0]
+            else:
+                # if multiple dp ranks have the same minimum load, randomly select one
+                import random
+
+                suggested_dp_index = random.choice(select_dp_indexes)
+
+            # assign the requests to the dp rank and update the load count
+            if not isinstance(req_group, list):
+                req_group = [req_group]
+
+            for req in req_group:
+                req.sample_params.suggested_dp_index = suggested_dp_index
+                self.inner_queues[suggested_dp_index].append(req)
+
+            # update the load count for this dp rank
+            total_load_per_dp[suggested_dp_index] += req_length
+
+        reqs_waiting_for_dp_index.clear()
+        return
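
A standalone sketch (not from the commit) of the least-loaded placement that assign_reqs_to_dp implements; _StubQueue and assign_least_loaded are hypothetical stand-ins, and the per-rank load is, as in the diff, running requests plus queued requests.

import random
from typing import List

class _StubQueue:
    def __init__(self) -> None:
        self.waiting_req_list: List[str] = []

    def append(self, req: str) -> None:
        self.waiting_req_list.append(req)

def assign_least_loaded(running_per_dp: List[int], queues: List[_StubQueue], groups: List[List[str]]) -> None:
    # total load = requests already running on the rank + requests waiting in its queue
    total_load = [running_per_dp[i] + len(queues[i].waiting_req_list) for i in range(len(queues))]
    for group in groups:
        min_load = min(total_load)
        candidates = [i for i, load in enumerate(total_load) if load == min_load]
        dp_index = candidates[0] if len(candidates) == 1 else random.choice(candidates)
        for req in group:
            queues[dp_index].append(req)
        total_load[dp_index] += len(group)  # keep the estimate current for later groups

queues = [_StubQueue() for _ in range(4)]
assign_least_loaded(running_per_dp=[3, 0, 1, 2], queues=queues, groups=[["a", "b"], ["c"], ["d"]])
print([len(q.waiting_req_list) for q in queues])  # e.g. [0, 2, 2, 0]; ties are broken at random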
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+import random
+from abc import ABC, abstractmethod
+from typing import List, Union
+from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.batch import Batch, Req
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class DpBalancer(ABC):
+    """
+    Base class for DP load balancers.
+    Defines the load-balancing strategy interface; subclasses implement different balancing algorithms.
+    """
+
+    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
+        self.dp_size_in_node = dp_size_in_node
+        self.inner_queues = inner_queues
+        self.pre_select_dp_index = self.dp_size_in_node - 1
+
+    @abstractmethod
+    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None:
+        pass
+
+
+class RoundRobinDpBalancer(DpBalancer):
+    """
+    Round-robin load balancer.
+    Performs round-robin selection among the dp ranks whose waiting queues are shortest.
+    """
+
+    def get_suggest_dp_index(
+        self,
+    ) -> int:
+        min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues)
+        select_dp_indexes = [
+            i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length
+        ]
+
+        # if no index is selectable, pick one at random
+        if not select_dp_indexes:
+            self.pre_select_dp_index = random.randint(0, self.dp_size_in_node - 1)
+            return self.pre_select_dp_index
+
+        # round-robin selection
+        for i in range(self.dp_size_in_node):
+            next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node
+            if next_dp_index in select_dp_indexes:
+                self.pre_select_dp_index = next_dp_index
+                return self.pre_select_dp_index
+
+        self.pre_select_dp_index = random.choice(select_dp_indexes)
+        return self.pre_select_dp_index
+
+    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None:
+        for req_group in reqs_waiting_for_dp_index:
+            suggested_dp_index = self.get_suggest_dp_index()
+            if not isinstance(req_group, list):
+                req_group = [req_group]
+            for req in req_group:
+                req.sample_params.suggested_dp_index = suggested_dp_index
+                self.inner_queues[suggested_dp_index].append(req)
+        reqs_waiting_for_dp_index.clear()
+        return
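
A small sketch (not from the commit) of the selection rule in RoundRobinDpBalancer.get_suggest_dp_index: restrict to the dp ranks with the shortest waiting list, then continue round robin from the rank picked last time. round_robin_pick and the plain length list are hypothetical stand-ins for the queue objects.

from typing import List

def round_robin_pick(queue_lengths: List[int], pre_select: int) -> int:
    dp_size = len(queue_lengths)
    min_length = min(queue_lengths)
    candidates = [i for i, n in enumerate(queue_lengths) if n == min_length]
    # walk forward from the previously selected rank until a shortest-queue rank is hit
    for step in range(dp_size):
        next_index = (pre_select + step + 1) % dp_size
        if next_index in candidates:
            return next_index
    return candidates[0]

pre = 3  # mirrors the base-class default pre_select_dp_index = dp_size_in_node - 1
for _ in range(4):
    pre = round_robin_pick([2, 0, 0, 2], pre)
    print(pre)  # prints 1, 2, 1, 2: only the two shortest queues are visited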

lightllm/server/router/req_queue/dp_base_queue.py

Lines changed: 19 additions & 18 deletions
@@ -20,6 +20,7 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None:
             base_queue_class(args, router, dp_index, dp_size_in_node) for dp_index in range(self.dp_size_in_node)
         ]
         self.dp_balancer = get_dp_balancer(args, dp_size_in_node, self.inner_queues)
+        self.reqs_waiting_for_dp_index = []
         return
 
     def get_dp_queue(self, dp_index: int):
@@ -31,10 +32,16 @@ def get_wait_req_num(self):
 
     # @calculate_time(show=True, min_cost_ms=10)
     def generate_new_batch(self, current_batch: Batch):
-        batches = [
-            self.inner_queues[dp_index].generate_new_batch(current_batch) for dp_index in range(self.dp_size_in_node)
-        ]
-        return self._merge_batch(batches)
+        try:
+            self.dp_balancer.assign_reqs_to_dp(current_batch, self.reqs_waiting_for_dp_index)
+            batches = [
+                self.inner_queues[dp_index].generate_new_batch(current_batch)
+                for dp_index in range(self.dp_size_in_node)
+            ]
+            return self._merge_batch(batches)
+        except Exception as e:
+            logger.error(f"generate new batch failed: {e}")
+            raise e
 
     def _merge_batch(self, dp_batches: List[Batch]):
         merged_batch: Batch = None
@@ -48,26 +55,20 @@ def _merge_batch(self, dp_batches: List[Batch]):
     def append(self, req: Req):
        suggested_dp_index = req.sample_params.suggested_dp_index
         if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
-            logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
-            suggested_dp_index = self.dp_balancer.get_suggest_dp_index()
-            req.sample_params.suggested_dp_index = suggested_dp_index
-            self.inner_queues[suggested_dp_index].append(req)
+            # the dp index is assigned uniformly at scheduling time
+            self.reqs_waiting_for_dp_index.append(req)
         else:
             self.inner_queues[suggested_dp_index].append(req)
         return
 
     def extend(self, req_group: List[Req]):
-        # requests in the same group are assigned to the same dp for best efficiency
-        index = self.dp_balancer.get_suggest_dp_index()
-        for req in req_group:
-            suggested_dp_index = req.sample_params.suggested_dp_index
-            if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
-                logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
-                req.sample_params.suggested_dp_index = index
-                self.inner_queues[index].append(req)
-            else:
+        suggested_dp_index = req_group[0].sample_params.suggested_dp_index
+        if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
+            # requests in the same group must be assigned to the same dp
+            self.reqs_waiting_for_dp_index.append(req_group)
+        else:
+            for req in req_group:
                 self.inner_queues[suggested_dp_index].append(req)
-
         return
 
     def is_busy(self):
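
Taken together, the queue no longer pins a request to a dp rank on arrival: requests (or whole request groups) without a valid suggested_dp_index are parked in reqs_waiting_for_dp_index, and the balancer assigns them only when generate_new_batch runs, so placement can take the current per-rank load into account. A minimal sketch of that deferral, using plain lists as hypothetical stand-ins for the queue objects:

# Stand-ins: plain lists play the roles of reqs_waiting_for_dp_index and inner_queues.
waiting_for_dp_index = []
queues = [[] for _ in range(4)]

def append(req, suggested_dp_index):
    if suggested_dp_index >= len(queues) or suggested_dp_index < 0:
        waiting_for_dp_index.append(req)        # defer: the balancer decides later
    else:
        queues[suggested_dp_index].append(req)  # honor the caller's valid choice

append("req-a", -1)
append("req-b", 2)
# later, inside generate_new_batch, the balancer drains waiting_for_dp_index into the
# per-dp queues it selects, and only then does each queue build its batch
print(waiting_for_dp_index, [len(q) for q in queues])  # ['req-a'] [0, 0, 1, 0]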
