Skip to content

Commit 644e802

Browse files
authored
fix tp moe and improve dp router. (#897)
1 parent f7d7e41 commit 644e802

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import threading
44
from typing import Optional, Tuple, List, Dict, Any
55
from .base_weight import BaseWeight
6-
from lightllm.utils.dist_utils import get_global_rank, get_current_device_id
6+
from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id
77
from lightllm.common.quantization import Quantcfg
88

99

@@ -37,7 +37,7 @@ def __init__(
3737
self.n_routed_experts = n_routed_experts
3838
self.split_inter_size = split_inter_size
3939
self.data_type_ = data_type
40-
self.tp_rank_ = get_global_rank()
40+
self.tp_rank_ = get_current_rank_in_dp()
4141
self.experts_up_projs = [None] * self.n_routed_experts
4242
self.experts_gate_projs = [None] * self.n_routed_experts
4343
self.experts_up_proj_scales = [None] * self.n_routed_experts

lightllm/server/router/req_queue/dp_base_queue.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class DpQueue:
1212
def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None:
1313
self.dp_size_in_node = dp_size_in_node
1414
self.base_queue_class = base_queue_class
15-
self.round_robin_dp_id = 0
15+
self.pre_select_dp_index = self.dp_size_in_node - 1
1616
from lightllm.server.router.manager import RouterManager
1717

1818
self.router: RouterManager = router
@@ -52,8 +52,8 @@ def append(self, req: Req):
5252
suggested_dp_index = req.sample_params.suggested_dp_index
5353
if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
5454
logger.error(f"input req {req.request_id} dp index {suggested_dp_index} has error")
55-
suggested_dp_index = self.round_robin_dp_id
56-
self.round_robin_dp_id = (self.round_robin_dp_id + 1) % self.dp_size_in_node
55+
suggested_dp_index = self._get_suggest_dp_index()
56+
self.pre_select_dp_index = suggested_dp_index
5757
req.sample_params.suggested_dp_index = suggested_dp_index
5858
self.inner_queues[suggested_dp_index].append(req)
5959
else:
@@ -62,13 +62,12 @@ def append(self, req: Req):
6262

6363
def extend(self, req_group: List[Req]):
6464
# 同一个组的,要分配在同一个 dp 上,效率最高
65-
index = self.round_robin_dp_id
66-
self.round_robin_dp_id = (self.round_robin_dp_id + 1) % self.dp_size_in_node
65+
index = self._get_suggest_dp_index()
6766
for req in req_group:
6867
suggested_dp_index = req.sample_params.suggested_dp_index
6968
if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
7069
logger.error(f"input req {req.request_id} dp index {suggested_dp_index} has error")
71-
70+
self.pre_select_dp_index = index
7271
req.sample_params.suggested_dp_index = index
7372
self.inner_queues[index].append(req)
7473
else:
@@ -94,3 +93,21 @@ def update_token_load(self, current_batch: Batch, force_update=False):
9493
self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, dp_index)
9594
self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, dp_index)
9695
return
96+
97+
def _get_suggest_dp_index(self):
    """Pick the dp index whose waiting queue is currently the shortest.

    Ties between equally short queues are broken round-robin, starting
    from the slot right after ``self.pre_select_dp_index``. This method
    does not update ``self.pre_select_dp_index``; callers record the
    choice themselves.

    Returns:
        int: index into ``self.inner_queues`` of the selected dp queue.

    Raises:
        ValueError: if ``self.inner_queues`` is empty (raised by ``min``).
    """
    # Read every queue length exactly once. Computing the minimum and the
    # candidate list from two separate passes lets the lengths change in
    # between, which can leave no candidate matching the stale minimum.
    queue_lengths = [len(queue.waiting_req_list) for queue in self.inner_queues]
    min_length = min(queue_lengths)
    select_dp_indexes = [i for i, length in enumerate(queue_lengths) if length == min_length]
    candidates = set(select_dp_indexes)

    # Round-robin: walk the ring starting just after the previous pick and
    # return the first index that is among the shortest queues.
    for step in range(1, self.dp_size_in_node + 1):
        next_dp_index = (self.pre_select_dp_index + step) % self.dp_size_in_node
        if next_dp_index in candidates:
            return next_dp_index

    # Only reachable when dp_size_in_node does not cover every candidate
    # index; fall back to any of the shortest queues.
    return random.choice(select_dp_indexes)

0 commit comments

Comments
 (0)