Skip to content

Commit ccbeea9

Browse files
authored
fix
1 parent c48e657 commit ccbeea9

File tree

3 files changed

+7
-13
lines changed

3 files changed

+7
-13
lines changed

lightllm/server/api_http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def set_args(self, args):
101101
enable_multimodal=args.enable_multimodal,
102102
metric_port=args.metric_port,
103103
)
104-
dp_size_in_node = max(1, args.dp // args.nnodes) # 兼容多机存粹tp的运行模式,这时候 1 // 2 == 0, 需要兼容
104+
dp_size_in_node = max(1, args.dp // args.nnodes) # 兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容
105105
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node)
106106

107107

lightllm/server/router/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,11 +389,11 @@ def _can_decode(self, batch: Batch):
389389
# p d 分离模式下,目前只能使用保守调度,保证请求放入进行decode的时候
390390
# 显存token肯定是够用的。
391391
# deepseekv2 dp 模式下,采用保守调度,也肯定够用
392-
if self.is_pd_run_mode or self.dp_size > 1 or self.is_safe_schedule:
392+
if self.is_pd_run_mode or self.dp_size_in_node > 1 or self.is_safe_schedule:
393393
return True
394394

395395
# 下面的判定条件,只在 dp 为 1 的情况下启用
396-
assert self.dp_size == 1
396+
assert self.dp_size_in_node == 1
397397
return batch.get_batch_decode_need_tokens()[0] + self.get_used_tokens(0) <= self.max_total_token_num
398398

399399
def get_used_tokens(self, dp_index):

lightllm/server/router/req_queue/dp_base_queue.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,8 @@ def _merge_batch(self, dp_batches: List[Batch]):
4444

4545
def append(self, req: Req):
4646
suggested_dp_index = req.sample_params.suggested_dp_index
47-
if suggested_dp_index is None or suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
48-
if suggested_dp_index is not None and (
49-
suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0
50-
):
51-
logger.error(f"input req {req.request_id} dp index {suggested_dp_index} has error")
47+
if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
48+
logger.error(f"input req {req.request_id} dp index {suggested_dp_index} has error")
5249
suggested_dp_index = random.randint(0, self.dp_size_in_node - 1)
5350
req.sample_params.suggested_dp_index = suggested_dp_index
5451
self.inner_queues[suggested_dp_index].append(req)
@@ -61,11 +58,8 @@ def extend(self, req_group: List[Req]):
6158
index = random.randint(0, self.dp_size_in_node - 1)
6259
for req in req_group:
6360
suggested_dp_index = req.sample_params.suggested_dp_index
64-
if suggested_dp_index is None or suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
65-
if suggested_dp_index is not None and (
66-
suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0
67-
):
68-
logger.error(f"input req {req.request_id} dp index {suggested_dp_index} has error")
61+
if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
62+
logger.error(f"input req {req.request_id} dp index {suggested_dp_index} has error")
6963
req.sample_params.suggested_dp_index = index
7064
self.inner_queues[index].append(req)
7165
else:

0 commit comments

Comments
 (0)