【Feature】 dp chuncked prefill balance. #937
Merged
Changes from all commits (8 commits)
@@ -1,18 +1,13 @@
import os
import asyncio
import numpy as np
import rpyc
import torch
import socket
from datetime import timedelta
from typing import Dict, List, Tuple, Callable, Optional
from typing import List, Tuple, Callable, Optional
from transformers.configuration_utils import PretrainedConfig
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
from lightllm.utils.log_utils import init_logger
from lightllm.models import get_model
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams
from lightllm.server.router.model_infer.infer_batch import InferReq
from lightllm.server.router.token_load import TokenLoad
from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock
from lightllm.common.basemodel.basemodel import TpPartBaseModel

@@ -26,12 +21,18 @@
from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
from lightllm.utils.dist_utils import get_dp_rank_in_node
from lightllm.distributed import dist_group_manager
from .chuncked_prefill_state import ChunkedPrefillState
import torch.distributed as dist


class ModeBackend:
    def __init__(self) -> None:
        self.shm_req_manager = ShmReqManager()

        # When a subclass runs in a chunked-prefill-related mode, this state object drives the
        # chunked-prefill scheduling; see the ChunkedPrefillBackend class for how it is used.
        # In non-chunked-prefill modes this state has no effect.
        self.chunked_prefill_state = ChunkedPrefillState()
        pass

    def init_model(self, kvargs):

@@ -56,7 +57,6 @@ def init_model(self, kvargs):
        self.eos_id: List[int] = kvargs.get("eos_id", [2])
        self.disable_cudagraph = self.args.disable_cudagraph

        self.cache = {}
        self.logger = init_logger(__name__)

        self.weight_dir = kvargs["weight_dir"]

@@ -140,6 +140,14 @@ def init_model(self, kvargs):
            vocab_size=self.model.vocab_size,
        )

        # Initialize the communication tensors used in DP mode; non-DP modes never touch them.
        if self.dp_size > 1:
            self.dp_reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
            self.dp_gather_item_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
            self.dp_all_gather_tensor = torch.tensor(
                [0 for _ in range(self.global_world_size)], dtype=torch.int32, device="cuda", requires_grad=False
            )

        self.init_custom()
        return

@@ -446,6 +454,27 @@ def _update_reqs_mtp_gen_token_ids(self, reqs: List[InferReq], mtp_draft_next_to
            req.mtp_gen_token_ids.append(token_id)
        return

    def _dp_all_gather_prefill_req_num(self, prefill_reqs: List[InferReq]) -> Tuple[np.ndarray, int]:
        """
        Gather the number of prefill requests across all DP ranks.
        """
        current_dp_prefill_num = len(prefill_reqs)
        self.dp_gather_item_tensor.fill_(current_dp_prefill_num)
        dist.all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
        dp_prefill_req_nums = self.dp_all_gather_tensor.cpu().numpy()
        max_prefill_num = np.max(dp_prefill_req_nums)
        return dp_prefill_req_nums, max_prefill_num

    def _dp_all_reduce_decode_req_num(self, decode_reqs: List[InferReq]) -> int:
        """
        Reduce the number of decode requests across all DP ranks.
        """
        current_dp_decode_num = len(decode_reqs)
        self.dp_reduce_tensor.fill_(current_dp_decode_num)
        dist.all_reduce(self.dp_reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False)
        max_decode_num = self.dp_reduce_tensor.item()
        return max_decode_num

    def preload_prompt_cache_kv_buffer(self, model_cfg):
        self.logger.info("Preload prompt cache kv buffer.")
        cur_rank = dist.get_rank()

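The two helpers above follow a standard torch.distributed pattern: each rank writes its local request count into a small preallocated tensor (set up in init_model when dp_size > 1) and then runs a collective, so every DP rank ends up with the same view of the cluster-wide load. The sketch below is a minimal, self-contained illustration of that pattern and is not lightllm code: it runs on CPU with the gloo backend and uses dist.all_gather instead of the CUDA all_gather_into_tensor call used in the PR, and the per-rank request counts are made-up values.

```python
# Minimal sketch (not lightllm code): synchronize per-rank request counts with
# torch.distributed collectives, mirroring _dp_all_gather_prefill_req_num (all-gather)
# and _dp_all_reduce_decode_req_num (max all-reduce).
import os
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Made-up load: rank 0 holds 2 prefill requests, the others hold none.
    local_prefill_num = 2 if rank == 0 else 0
    local_decode_num = rank  # arbitrary per-rank decode counts for the example

    # All-gather: every rank learns the prefill count of every other rank.
    item = torch.tensor([local_prefill_num], dtype=torch.int32)
    gathered = [torch.zeros(1, dtype=torch.int32) for _ in range(world_size)]
    dist.all_gather(gathered, item)
    dp_prefill_req_nums = np.array([t.item() for t in gathered])

    # Max all-reduce: every rank learns the largest decode count across ranks.
    reduce_t = torch.tensor([local_decode_num], dtype=torch.int32)
    dist.all_reduce(reduce_t, op=dist.ReduceOp.MAX)

    if rank == 0:
        print("prefill counts per rank:", dp_prefill_req_nums)  # e.g. [2 0 0 0]
        print("max decode count:", reduce_t.item())

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(4,), nprocs=4)
```

The gathered per-rank prefill counts are what ChunkedPrefillState.dp_need_prefill, added in the new file below, consumes to decide whether a prefill step is worth running now.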
lightllm/server/router/model_infer/mode_backend/chuncked_prefill_state.py (89 additions, 0 deletions)

@@ -0,0 +1,89 @@
import dataclasses
import numpy as np
from typing import List
from lightllm.utils.envs_utils import get_env_start_args
from ..infer_batch import InferReq
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


@dataclasses.dataclass
class ChunkedPrefillState:
    """
    Holds the control state for chunked-prefill inference scheduling. Different deployment
    scenarios weigh time-to-first-token against inter-token latency differently, so a state
    object is needed to steer that trade-off. The args.router_max_wait_tokens parameter
    mainly tunes how aggressively prefill is scheduled, balancing first-token latency
    against inter-token latency.
    """

    prefill_wait_step: int = 0
    need_prefill_count: int = 0
    current_wait_step: int = 0

    # wait-step parameters for dp chunked prefill
    dp_prefill_wait_step: int = 0
    dp_current_wait_step: int = 0

    # world_size
    _global_world_size: int = 0

    def __post_init__(self):
        args = get_env_start_args()
        self.prefill_wait_step = args.router_max_wait_tokens
        self.dp_prefill_wait_step = args.dp_prefill_wait_step
        self._global_world_size = args.tp
        return

    def need_prefill(self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]) -> bool:
        no_decode_reqs = len(decode_reqs) == 0
        step_ok = self.current_wait_step >= self.prefill_wait_step
        need_prefill = self.need_prefill_count > 0

        if no_decode_reqs or step_ok or need_prefill:
            if need_prefill:
                self.need_prefill_count -= 1

            self.current_wait_step = 0
            if prefill_reqs:
                return True
            else:
                return False
        else:
            if prefill_reqs:
                self.current_wait_step += 1
            return False

    def dp_need_prefill(
        self,
        prefill_reqs: List[InferReq],
        decode_reqs: List[InferReq],
        dp_prefill_req_nums: np.ndarray,
        dp_max_prefill_num: int,
    ) -> bool:
        """
        dp_need_prefill gates chunked prefill in DP mode, where the real number of running
        requests on every DP rank has to be considered. Take a scenario with 8 DP ranks whose
        prefill request counts are [1, 1, 0, 0, 0, 0, 0, 0]: the ranks with 0 requests pad a
        fake request so they can take part in the computation, but communication
        synchronization then wastes a large amount of compute on those ranks, so the
        effective utilization is very low.
        Solution:
        Before deciding to prefill, first check whether the request counts across all DP
        ranks are balanced and whether the waste ratio is acceptable. If it is not, delay
        prefill until the waste across the DP ranks drops, with the maximum delay bounded by
        the dp_prefill_wait_step parameter.
        """
        assert dp_prefill_req_nums.shape[0] == self._global_world_size

        use_ratio = np.count_nonzero(dp_prefill_req_nums) / dp_prefill_req_nums.shape[0]
        step_ok = self.dp_current_wait_step >= self.dp_prefill_wait_step

        if dp_max_prefill_num > 0 and (use_ratio > 0.6 or step_ok):
            if use_ratio < 0.2:
                self.dp_current_wait_step = 0
                logger.info(f"dp chuncked prefill effective GPU Utilization Rate {use_ratio}")

            return True
        else:
            if dp_max_prefill_num > 0:
                self.dp_current_wait_step += 1
            else:
                self.dp_current_wait_step = 0
            return False
Review comment: This is a very long comment. Consider breaking it up into multiple shorter comments, or rephrasing it to be more concise. Long comments can be difficult to read and understand.
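To see how the dp_need_prefill gate behaves on the docstring's 8-rank example, here is a simplified, CPU-only re-implementation of the same waiting rule. ToyChunkedPrefillState and its constants are illustration-only (the logging call and the use_ratio < 0.2 reset special case are dropped); it is not the class added by this PR.

```python
import numpy as np


class ToyChunkedPrefillState:
    """Simplified stand-in for ChunkedPrefillState.dp_need_prefill (logging and the
    use_ratio < 0.2 reset special case are omitted)."""

    def __init__(self, dp_prefill_wait_step: int, global_world_size: int):
        self.dp_prefill_wait_step = dp_prefill_wait_step
        self.dp_current_wait_step = 0
        self._global_world_size = global_world_size

    def dp_need_prefill(self, dp_prefill_req_nums: np.ndarray, dp_max_prefill_num: int) -> bool:
        assert dp_prefill_req_nums.shape[0] == self._global_world_size
        # Fraction of DP ranks that have real (non-padded) prefill work.
        use_ratio = np.count_nonzero(dp_prefill_req_nums) / dp_prefill_req_nums.shape[0]
        step_ok = self.dp_current_wait_step >= self.dp_prefill_wait_step
        if dp_max_prefill_num > 0 and (use_ratio > 0.6 or step_ok):
            self.dp_current_wait_step = 0
            return True
        # Not balanced enough yet: keep waiting, counting only steps where work exists.
        if dp_max_prefill_num > 0:
            self.dp_current_wait_step += 1
        else:
            self.dp_current_wait_step = 0
        return False


# 8 DP ranks with prefill counts [1, 1, 0, 0, 0, 0, 0, 0] -> use_ratio = 0.25, so prefill
# is postponed until dp_prefill_wait_step scheduling steps have accumulated.
state = ToyChunkedPrefillState(dp_prefill_wait_step=3, global_world_size=8)
counts = np.array([1, 1, 0, 0, 0, 0, 0, 0])
for step in range(5):
    go = state.dp_need_prefill(counts, dp_max_prefill_num=int(counts.max()))
    print(f"step {step}: prefill={go}, waited={state.dp_current_wait_step}")
# Prints False for steps 0-2, True at step 3 (wait limit reached), then False again.
```

The trace makes the trade-off visible: prefill for the two busy ranks is postponed for dp_prefill_wait_step scheduling steps in the hope that other ranks accumulate work, and is then forced through so first-token latency does not grow without bound.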