
Commit 2ec76a6

Merge remote-tracking branch 'origin' into autotuner2

2 parents 2853168 + a771412

File tree

12 files changed: +230 -116 lines changed

12 files changed

+230
-116
lines changed

lightllm/common/req_manager.py

Lines changed: 1 addition & 0 deletions

@@ -159,6 +159,7 @@ def init_req_sampling_params(self, req):
         token_id_counter(
             prompt_ids=prompt_ids, out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx]
         )
+        torch.cuda.current_stream().synchronize()

         return
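The added synchronize presumably (the diff gives no rationale) ensures the asynchronously launched token_id_counter kernel has finished before its output counter is consumed on the host. The generic pattern, as a standalone sketch:

import torch

if torch.cuda.is_available():
    x = torch.ones(1024, device="cuda")
    y = x * 2                                  # kernel is enqueued asynchronously
    torch.cuda.current_stream().synchronize()  # block until all queued work completes
    print(y.sum().item())                      # now safe to consume on the host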

lightllm/server/api_cli.py

Lines changed: 7 additions & 0 deletions

@@ -144,6 +144,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
         using the deepseekv2 model, set dp to be equal to the tp parameter. In other cases, please
         do not set it and keep the default value as 1.""",
     )
+    parser.add_argument(
+        "--dp_balancer",
+        type=str,
+        default="bs_balancer",
+        choices=["round_robin", "bs_balancer"],
+        help="the dp balancer type, default is bs_balancer",
+    )
     parser.add_argument(
         "--max_req_total_len", type=int, default=16384, help="the max value for req_input_len + req_output_len"
     )
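A minimal sketch of how the new flag behaves under argparse (a standalone parser carrying only this option):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dp_balancer",
    type=str,
    default="bs_balancer",
    choices=["round_robin", "bs_balancer"],
    help="the dp balancer type, default is bs_balancer",
)

print(parser.parse_args([]).dp_balancer)                                # -> "bs_balancer"
print(parser.parse_args(["--dp_balancer", "round_robin"]).dp_balancer)  # -> "round_robin"
# Any other value exits with an "invalid choice" error from argparse.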

lightllm/server/router/batch.py

Lines changed: 9 additions & 0 deletions

@@ -40,6 +40,15 @@ def get_req_list_for_dp(self, dp_index: int):
             req_list.append(req)
         return req_list

+    def get_all_dp_req_num(self) -> List[int]:
+        if self.dp_size_in_node == 1:
+            return [len(self.reqs)]
+
+        all_dp_req_num = [0 for _ in range(self.dp_size_in_node)]
+        for req in self.reqs:
+            all_dp_req_num[req.sample_params.suggested_dp_index] += 1
+        return all_dp_req_num
+
     def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
         unfinished_req_ids = []
         for req in self.reqs:
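A small worked example of the counting that get_all_dp_req_num performs (SimpleNamespace objects stand in for the real Req instances):

from types import SimpleNamespace

reqs = [
    SimpleNamespace(sample_params=SimpleNamespace(suggested_dp_index=i))
    for i in (0, 0, 2, 1, 0)
]

dp_size_in_node = 4
all_dp_req_num = [0 for _ in range(dp_size_in_node)]
for req in reqs:
    all_dp_req_num[req.sample_params.suggested_dp_index] += 1

print(all_dp_req_num)  # -> [3, 1, 1, 0]: rank 3 is idle, the case the new balancer tries to avoid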

lightllm/server/router/manager.py

Lines changed: 0 additions & 4 deletions

@@ -197,10 +197,6 @@ async def wait_to_model_ready(self):
         return

     def _get_schedule_time_interval(self):
-        if self.running_batch is None:
-            # When there is no running batch, trigger request scheduling every 10 ms.
-            return 0.01
-
         # In dp mode, a longer scheduling interval is needed for better balancing,
         # so that more requests can be collected.
         return self.schedule_time_interval

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 5 additions & 0 deletions

@@ -75,6 +75,7 @@ def init_model(self, kvargs):
         self.chunked_prefill_size = self.args.chunked_prefill_size
         self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs
         self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache
+        self.batch_max_tokens = self.args.batch_max_tokens
         self.eos_id: List[int] = kvargs.get("eos_id", [2])
         self.disable_cudagraph = self.args.disable_cudagraph
         self.is_multinode_tp = self.args.nnodes > 1 and self.args.dp == 1

@@ -395,6 +396,7 @@ def _get_classed_reqs(
         # requests; that logic is not suitable here.
         pause_max_req_num = 2
         wait_pause_count = 0
+        prefill_tokens = 0

         # The radix cache and mem_manager counters are used below,
         # so lock protection is required.

@@ -443,7 +445,10 @@ def _get_classed_reqs(
                 wait_pause_count += 1
             else:
                 token_num = req_obj.prefill_need_token_num(is_chuncked_prefill=not self.disable_chunked_prefill)
+                if prefill_tokens + token_num > self.batch_max_tokens:
+                    continue
                 if token_num <= can_alloc_token_num:
+                    prefill_tokens += token_num
                     prefill_reqs.append(req_obj)
                     can_alloc_token_num -= token_num
                 else:
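A minimal sketch of the new batch_max_tokens cap (plain integers stand in for the request objects; pausing and the lock are omitted):

batch_max_tokens = 100    # from --batch_max_tokens
can_alloc_token_num = 80  # tokens the allocator can still hand out

waiting = [30, 50, 40, 20]  # prefill_need_token_num for each waiting request
prefill_tokens = 0
prefill_reqs = []

for token_num in waiting:
    # Skip any request that would push the prefill batch past the cap.
    if prefill_tokens + token_num > batch_max_tokens:
        continue
    if token_num <= can_alloc_token_num:
        prefill_tokens += token_num
        prefill_reqs.append(token_num)
        can_alloc_token_num -= token_num

print(prefill_reqs)  # -> [30, 50]: 40 busts the cap, 20 fits the cap but not the allocator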

lightllm/server/router/req_queue/base_queue.py

Lines changed: 0 additions & 5 deletions

@@ -26,11 +26,6 @@ def __init__(self, args, router, dp_index, dp_size_in_node) -> None:
         self.router_token_ratio = args.router_token_ratio  # ratio to determine whether the router is busy
         self.router_max_new_token_len = args.router_max_new_token_len

-    def append(self, req: Req):
-        req.sample_params.suggested_dp_index = self.dp_index
-        self.waiting_req_list.append(req)
-        return
-
     def extend(self, req_group: List[Req]):
         for req in req_group:
             req.sample_params.suggested_dp_index = self.dp_index

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+from .roundrobin import RoundRobinDpBalancer
+from typing import List
+from lightllm.server.router.req_queue.base_queue import BaseQueue
+from .bs import DpBsBalancer
+
+
+def get_dp_balancer(args, dp_size_in_node: int, inner_queues: List[BaseQueue]):
+    if args.dp_balancer == "round_robin":
+        return RoundRobinDpBalancer(dp_size_in_node, inner_queues)
+    elif args.dp_balancer == "bs_balancer":
+        return DpBsBalancer(dp_size_in_node, inner_queues)
+    else:
+        raise ValueError(f"Invalid dp balancer: {args.dp_balancer}")
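
A hypothetical usage sketch of the factory (assumes get_dp_balancer from the new module above is in scope; stub queue entries suffice because the constructors only store the list):

from argparse import Namespace

args = Namespace(dp_balancer="round_robin")
balancer = get_dp_balancer(args, dp_size_in_node=2, inner_queues=[None, None])
print(type(balancer).__name__)  # -> "RoundRobinDpBalancer"
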
Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+import random
+from abc import ABC, abstractmethod
+from typing import List, Union
+from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.batch import Batch, Req
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class DpBalancer(ABC):
+    """
+    Base class for DP load balancers.
+    Defines the load-balancing strategy interface; subclasses can implement
+    different balancing algorithms.
+    """
+
+    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
+        self.dp_size_in_node = dp_size_in_node
+        self.inner_queues = inner_queues
+
+    @abstractmethod
+    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[List[Req]]) -> None:
+        pass

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+import random
+from typing import List, Union
+from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.batch import Batch, Req
+from lightllm.utils.log_utils import init_logger
+from .base import DpBalancer
+
+logger = init_logger(__name__)
+
+
+class DpBsBalancer(DpBalancer):
+    """
+    This balancer mainly balances the batch size across dp ranks.
+    In dp mode, if any dp rank has no request, a padding request is created
+    for it, which wastes GPU compute resources.
+    """
+
+    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
+        super().__init__(dp_size_in_node, inner_queues)
+
+    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[List[Req]]) -> None:
+        if len(reqs_waiting_for_dp_index) == 0:
+            return
+        # calculate the total load of each dp rank
+        all_dp_req_num = [0 for _ in range(self.dp_size_in_node)]
+        if current_batch is not None:
+            all_dp_req_num = current_batch.get_all_dp_req_num()
+        total_load_per_dp = [
+            all_dp_req_num[i] + len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node)
+        ]
+        for req_group in reqs_waiting_for_dp_index:
+            # find the dp rank with the minimum load
+            min_load = min(total_load_per_dp)
+            select_dp_indexes = [i for i in range(self.dp_size_in_node) if total_load_per_dp[i] == min_load]
+            suggested_dp_index = random.choice(select_dp_indexes)
+
+            # assign the request group to that dp rank and update the load count
+            for req in req_group:
+                req.sample_params.suggested_dp_index = suggested_dp_index
+            self.inner_queues[suggested_dp_index].extend(req_group)
+            total_load_per_dp[suggested_dp_index] += len(req_group)
+
+        reqs_waiting_for_dp_index.clear()
+        return
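
A toy trace (hypothetical load numbers, not from the diff) of the min-load selection DpBsBalancer performs:

import random

dp_size_in_node = 3
total_load_per_dp = [4, 2, 3]  # running reqs + waiting reqs per dp rank
req_group_sizes = [2, 1]       # sizes of incoming request groups

for group_size in req_group_sizes:
    min_load = min(total_load_per_dp)
    candidates = [i for i, load in enumerate(total_load_per_dp) if load == min_load]
    chosen = random.choice(candidates)  # minima are unique here, so no real tie-break
    total_load_per_dp[chosen] += group_size
    print(chosen, total_load_per_dp)    # -> 1 [4, 4, 3], then 2 [4, 4, 4]: loads equalize
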
Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+import random
+from typing import List, Union
+from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.batch import Batch, Req
+from lightllm.utils.log_utils import init_logger
+from .base import DpBalancer
+
+logger = init_logger(__name__)
+
+
+class RoundRobinDpBalancer(DpBalancer):
+    """
+    Round-robin load balancer.
+    Cycles through the dp ranks whose queues are currently the shortest.
+    """
+
+    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
+        super().__init__(dp_size_in_node, inner_queues)
+        self.pre_select_dp_index = self.dp_size_in_node - 1
+
+    def get_suggest_dp_index(
+        self,
+    ) -> int:
+        min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues)
+        select_dp_indexes = [
+            i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length
+        ]
+
+        # If there is no candidate index, pick one at random.
+        if not select_dp_indexes:
+            self.pre_select_dp_index = random.randint(0, self.dp_size_in_node - 1)
+            return self.pre_select_dp_index
+
+        # Round-robin selection.
+        for i in range(self.dp_size_in_node):
+            next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node
+            if next_dp_index in select_dp_indexes:
+                self.pre_select_dp_index = next_dp_index
+                return self.pre_select_dp_index
+
+        self.pre_select_dp_index = random.choice(select_dp_indexes)
+        return self.pre_select_dp_index
+
+    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[List[Req]]) -> None:
+        for req_group in reqs_waiting_for_dp_index:
+            suggested_dp_index = self.get_suggest_dp_index()
+            for req in req_group:
+                req.sample_params.suggested_dp_index = suggested_dp_index
+            self.inner_queues[suggested_dp_index].extend(req_group)
+        reqs_waiting_for_dp_index.clear()
+        return
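
A standalone sketch of the pointer walk in get_suggest_dp_index (plain lists stand in for the queues; the random fallbacks are omitted):

queue_lengths = [0, 3, 0, 0]  # len(waiting_req_list) per dp rank
pre = 3                       # pre_select_dp_index starts at dp_size_in_node - 1

def next_index(pre, queue_lengths):
    min_len = min(queue_lengths)
    candidates = [i for i, n in enumerate(queue_lengths) if n == min_len]
    for i in range(len(queue_lengths)):
        nxt = (pre + i + 1) % len(queue_lengths)
        if nxt in candidates:
            return nxt

for _ in range(4):
    pre = next_index(pre, queue_lengths)
    print(pre)  # -> 0, 2, 3, 0: rank 1 is skipped while its queue is longest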
