diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 601b2a48a..3f3eaf96f 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -198,6 +198,14 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""aggressive schedule can lead to frequent prefill interruptions during decode. disabling it allows the router_max_wait_tokens parameter to work more effectively.""", ) + parser.add_argument( + "--dp_prefill_wait_step", + type=int, + default=0, + help="""dp_prefill_wait_step is used to control the pacing of dp chunked prefill mode, aiming to reduce + computational waste during prefill. However, higher values can negatively impact the + first token latency. It is generally recommended to set this value between 0 and 12.""", + ) parser.add_argument( "--use_dynamic_prompt_cache", action="store_true", help="This argument is deprecated and no longer in use." diff --git a/lightllm/server/core/objs/out_token_circlequeue.py b/lightllm/server/core/objs/out_token_circlequeue.py index 6bcee16ce..b3cb65ef5 100644 --- a/lightllm/server/core/objs/out_token_circlequeue.py +++ b/lightllm/server/core/objs/out_token_circlequeue.py @@ -3,7 +3,7 @@ from typing import Tuple LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 128)) -LIGHTLLM_OUT_TOKEN_QUEUE_SIZE = int(os.getenv("LIGHTLLM_OUT_TOKEN_QUEUE_SIZE", 6)) +LIGHTLLM_OUT_TOKEN_QUEUE_SIZE = int(os.getenv("LIGHTLLM_OUT_TOKEN_QUEUE_SIZE", 8)) class QueueItem(ctypes.Structure): diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 7bf5ee239..4e1971ef9 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -283,7 +283,14 @@ class ChunkedPrefillReq(Req): def get_tuple_tokens(self, is_busy, router_max_new_token_len): args = get_env_start_args() - max_waiting_token = args.router_max_wait_tokens + # During chunked prefill inference, several modes delay prefill by a number of steps, either to get + # better inter-token latency or to raise prefill efficiency in dp mode. When estimating token memory + # usage, a chunked request must therefore account for the longer lifetime those delays give it. Concretely, + # in the b_len computation, the xxx * (max_waiting_token + 1) term simulates a longer output length so that + # the request's lifetime is extended during estimation. max_waiting_token is computed conservatively: the + # maximum number of delayed steps per chunked prefill is the sum of the two modes. Since this does not + # significantly inflate the estimated token usage, it is safe to use. + max_waiting_token = args.router_max_wait_tokens + args.dp_prefill_wait_step has_out_len = self.shm_cur_output_len if self.sample_params.ignore_eos: cur_max_new_token_len = self.sample_params.max_new_tokens diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 8a43d983d..ec1eb427e 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -43,6 +43,7 @@ class StartArgs: router_token_ratio: float = field(default=0.0) router_max_new_token_len: int = field(default=1024) router_max_wait_tokens: int = field(default=6) + dp_prefill_wait_step: int = field(default=0) disable_aggressive_schedule: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) chunked_prefill_size: int = field(default=8192) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 8def96733..fa455c225 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -22,6 +22,7 @@ from .async_queue import AsyncQueue from lightllm.server.core.objs import Req, FinishStatus from lightllm.server.core.objs import SamplingParams +from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE from 
lightllm.server.core.objs.io_objs import GroupReqObjs from lightllm.server.core.objs.shm_req_manager import ShmReqManager from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem @@ -281,8 +282,12 @@ async def generate( alloced_req_indexes = [] while len(alloced_req_indexes) < sampling_params.n: alloc_req_index = await self.shm_req_manager.async_alloc_req_index() + sleep_time = 0.1 while alloc_req_index is None: - await asyncio.sleep(0.1) + await asyncio.sleep(sleep_time) + sleep_time *= 1.1 + sleep_time = min(1, sleep_time) + alloc_req_index = await self.shm_req_manager.async_alloc_req_index() alloced_req_indexes.append(alloc_req_index) req_objs = [] @@ -648,31 +653,38 @@ async def handle_loop(self): token_list = [] for req in req_status.group_req_objs.shm_req_objs: req_id = req.request_id - if not req.out_tokens_queue.is_empty(): - - text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() - req.cumlogprob += float(req.shm_logprobs.arr[src_index]) - metadata = { - "id": int(req.shm_prompt_ids.arr[src_index]), - "logprob": float(req.shm_logprobs.arr[src_index]), - "cumlogprob": float(req.cumlogprob) / count_output_tokens, - "special": special, - "count_output_tokens": count_output_tokens, - "prompt_cache_len": req.prompt_cache_len, - "mtp_accepted_token_num": req.mtp_accepted_token_num, - } - if self.args.return_all_prompt_logprobs: - metadata.update(req.get_all_prompt_metadata()) - if self.args.use_reward_model: - metadata["score"] = float(req.reward_score) - - req.out_tokens_queue.pop_no_ret() - - if req.finish_token_index != src_index: - token_list.append((req_id, text, metadata, FinishStatus())) + read_token_count = 1 + if req.out_tokens_queue.is_full(): + read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE + + for _ in range(read_token_count): + if not req.out_tokens_queue.is_empty(): + + text, src_index, special, count_output_tokens = req.out_tokens_queue.peek() + req.cumlogprob += float(req.shm_logprobs.arr[src_index]) + metadata = { + "id": int(req.shm_prompt_ids.arr[src_index]), + "logprob": float(req.shm_logprobs.arr[src_index]), + "cumlogprob": float(req.cumlogprob) / count_output_tokens, + "special": special, + "count_output_tokens": count_output_tokens, + "prompt_cache_len": req.prompt_cache_len, + "mtp_accepted_token_num": req.mtp_accepted_token_num, + } + if self.args.return_all_prompt_logprobs: + metadata.update(req.get_all_prompt_metadata()) + if self.args.use_reward_model: + metadata["score"] = float(req.reward_score) + + req.out_tokens_queue.pop_no_ret() + + if req.finish_token_index != src_index: + token_list.append((req_id, text, metadata, FinishStatus())) + else: + finish_status = FinishStatus(req.finish_status.status) + token_list.append((req_id, text, metadata, finish_status)) else: - finish_status = FinishStatus(req.finish_status.status) - token_list.append((req_id, text, metadata, finish_status)) + break async with req_status.lock: req_status.out_token_info_list.extend(token_list) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 45e82ff3d..b452b604c 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -110,7 +110,6 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por # 调度和推理进行折叠使用的线程池 self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) self.schedule_task = None - self.overlap_event = threading.Event() return async def wait_to_model_ready(self): @@ -288,8 +287,12 @@ 
async def loop_for_fwd( async def get_schedule_result(self, running_batch: Batch): if self.schedule_task is None: + _start_time = time.time() def get_new_batch(): + if time.time() - _start_time < 0.001: + time.sleep(0.003) + limit_router_queue_length = None if self.is_multinode_tp: # use all_reduce to get the minimum value @@ -300,9 +303,6 @@ def get_new_batch(): dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group) limit_router_queue_length = limit_router_queue_length_tensor.item() - self.overlap_event.wait(timeout=0.020) - self.overlap_event.clear() - time.sleep(0.003) new_batch = self.req_queue.generate_new_batch(running_batch, limit_router_queue_length) return new_batch @@ -320,7 +320,7 @@ async def _step(self): # remove all reqs that have already finished # when there is currently no running batch if self.running_batch is None: - new_batch = await self.get_schedule_result(self.running_batch) + new_batch: Batch = await self.get_schedule_result(self.running_batch) if new_batch is not None: self.metric_client.histogram_observe("lightllm_batch_next_size", len(new_batch.reqs)) for req in new_batch.reqs: @@ -383,7 +383,6 @@ async def _prefill_batch(self, batch: Batch): start_time = time.time() self.metric_client.counter_inc("lightllm_batch_inference_count", "prefill") reqs = [r.to_router_rpc_obj() for r in batch.reqs] - self.overlap_event.set() await self.model_rpc_client.prefill(reqs) batch.filter_out_finished_req(self.shm_req_manager) self._send_detokenization_pack() @@ -397,7 +396,6 @@ async def _decode_batch(self, batch: Batch): start_time = time.time() self.metric_client.counter_inc("lightllm_batch_inference_count", "decode") - self.overlap_event.set() await self.model_rpc_client.decode() # when self.is_multinode_and_multidp is True, the batch object passed in may be None. if batch is not None: diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 329dc9f3b..dd1ea45fe 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -1,18 +1,13 @@ import os -import asyncio import numpy as np -import rpyc import torch -import socket -from datetime import timedelta -from typing import Dict, List, Tuple, Callable, Optional +from typing import List, Tuple, Callable, Optional from transformers.configuration_utils import PretrainedConfig from lightllm.utils.infer_utils import set_random_seed -from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end from lightllm.utils.log_utils import init_logger from lightllm.models import get_model from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache -from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams +from lightllm.server.router.model_infer.infer_batch import InferReq from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock from lightllm.common.basemodel.basemodel import TpPartBaseModel @@ -26,12 +21,18 @@ from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size from lightllm.utils.dist_utils import get_dp_rank_in_node from lightllm.distributed import dist_group_manager +from .chuncked_prefill_state import ChunkedPrefillState import torch.distributed as dist class ModeBackend: def __init__(self) -> None: self.shm_req_manager = ShmReqManager() + + # When a subclass runs in a chunked-prefill-related mode, this state object is used for some chunked prefill + # inference control; see the ChunkedPrefillBackend class for how it is used. In modes unrelated to + # chunked prefill, this state variable has no effect. + self.chunked_prefill_state = ChunkedPrefillState() pass def init_model(self, kvargs): @@ -56,7 +57,6 @@ def init_model(self, kvargs): self.eos_id: List[int] = kvargs.get("eos_id", [2]) self.disable_cudagraph = self.args.disable_cudagraph - self.cache = {} self.logger = init_logger(__name__) self.weight_dir = kvargs["weight_dir"] @@ -140,6 +140,14 @@ def init_model(self, kvargs): vocab_size=self.model.vocab_size, ) + # initialize the communication tensors used in dp mode; they are not used in non-dp modes + if self.dp_size > 1: + self.dp_reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) + self.dp_gather_item_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) + self.dp_all_gather_tensor = torch.tensor( + [0 for _ in range(self.global_world_size)], dtype=torch.int32, device="cuda", requires_grad=False + ) + self.init_custom() return @@ -446,6 +454,27 @@ def _update_reqs_mtp_gen_token_ids(self, reqs: List[InferReq], mtp_draft_next_to req.mtp_gen_token_ids.append(token_id) return + def _dp_all_gather_prefill_req_num(self, prefill_reqs: List[InferReq]) -> Tuple[np.ndarray, int]: + """ + Gather the number of prefill requests across all DP ranks. + """ + current_dp_prefill_num = len(prefill_reqs) + self.dp_gather_item_tensor.fill_(current_dp_prefill_num) + dist.all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False) + dp_prefill_req_nums = self.dp_all_gather_tensor.cpu().numpy() + max_prefill_num = np.max(dp_prefill_req_nums) + return dp_prefill_req_nums, max_prefill_num + + def _dp_all_reduce_decode_req_num(self, decode_reqs: List[InferReq]) -> int: + """ + Reduce the number of decode requests across all DP ranks. 
+ """ + current_dp_decode_num = len(decode_reqs) + self.dp_reduce_tensor.fill_(current_dp_decode_num) + dist.all_reduce(self.dp_reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) + max_decode_num = self.dp_reduce_tensor.item() + return max_decode_num + def preload_prompt_cache_kv_buffer(self, model_cfg): self.logger.info("Preload prompt cache kv buffer.") cur_rank = dist.get_rank() diff --git a/lightllm/server/router/model_infer/mode_backend/chuncked_prefill_state.py b/lightllm/server/router/model_infer/mode_backend/chuncked_prefill_state.py new file mode 100644 index 000000000..7a1ca7b0c --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/chuncked_prefill_state.py @@ -0,0 +1,89 @@ +import dataclasses +import numpy as np +from typing import List +from lightllm.utils.envs_utils import get_env_start_args +from ..infer_batch import InferReq +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class ChunkedPrefillState: + """ + 用于保存和控制 chuncked prefill 推理调度的控制状态,因为不同的场景,对于首字和包间的 + 诉求不是特别一致,所以需要一个状态来控制。主要通过 args.router_max_wait_tokens 参数 + 可以调节 prefill 的激进程度和方式,来协调首字和包间的平衡。 + """ + + prefill_wait_step: int = 0 + need_prefill_count: int = 0 + current_wait_step: int = 0 + + # dp chuncked prefill 的等待步数参数 + dp_prefill_wait_step: int = 0 + dp_current_wait_step: int = 0 + + # world_size + _global_world_size: int = 0 + + def __post_init__(self): + args = get_env_start_args() + self.prefill_wait_step = args.router_max_wait_tokens + self.dp_prefill_wait_step = args.dp_prefill_wait_step + self._global_world_size = args.tp + return + + def need_prefill(self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]) -> bool: + no_decode_reqs = len(decode_reqs) == 0 + step_ok = self.current_wait_step >= self.prefill_wait_step + need_prefill = self.need_prefill_count > 0 + + if no_decode_reqs or step_ok or need_prefill: + if need_prefill: + self.need_prefill_count -= 1 + + self.current_wait_step = 0 + if prefill_reqs: + return True + else: + return False + else: + if prefill_reqs: + self.current_wait_step += 1 + return False + + def dp_need_prefill( + self, + prefill_reqs: List[InferReq], + decode_reqs: List[InferReq], + dp_prefill_req_nums: np.ndarray, + dp_max_prefill_num: int, + ) -> bool: + """ + dp_need_prefill 接口用于控制 DP 模式下进行chuncked prefill时,需要考虑各个DP的真实运行请求数量: + 考虑 8 个 dp 的场景,如果每个 dp 执行 prefill 的请求的数量分别为: [1, 1, 0, 0, 0, 0, 0, 0], 则在运行 + 的过程中,请求数量为0的dp会pad一个fake req来参与计算,但是这会导致这些dp因为一些通信同步的原因,造成大量 + 算力浪费,实际有效率很低。 + 解决方法: + 在判断是否可以进行 prefill 的时候,需要先考虑所有dp的请求数量是否均衡,浪费率是否在可以接受的范围,如果无法 + 接受这么高的浪费率,则可以延迟 prefill 的执行时机,直到所有dp的浪费率较低时再进行prefill, 不过延迟执行的极限 + 等待时间,受到 dp_prefill_wait_step 参数的控制。 + """ + assert dp_prefill_req_nums.shape[0] == self._global_world_size + + use_ratio = np.count_nonzero(dp_prefill_req_nums) / dp_prefill_req_nums.shape[0] + step_ok = self.dp_current_wait_step >= self.dp_prefill_wait_step + + if dp_max_prefill_num > 0 and (use_ratio > 0.6 or step_ok): + if use_ratio < 0.2: + self.dp_current_wait_step = 0 + logger.info(f"dp chuncked prefill effective GPU Utilization Rate {use_ratio}") + + return True + else: + if dp_max_prefill_num > 0: + self.dp_current_wait_step += 1 + else: + self.dp_current_wait_step = 0 + return False diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 00528fec7..d5d621089 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ 
b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -1,15 +1,8 @@ -import torch from typing import List, Tuple from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend -from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend from lightllm.utils.log_utils import init_logger from lightllm.server.router.model_infer.infer_batch import g_infer_context -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample -from lightllm.server.router.model_infer.mode_backend.pre import ( - prepare_prefill_inputs, - prepare_decode_inputs, -) logger = init_logger(__name__) @@ -18,14 +11,10 @@ class ChunkedPrefillBackend(ModeBackend): def __init__(self) -> None: super().__init__() - self.forward_step = 0 - args = get_env_start_args() - self.max_wait_step = args.router_max_wait_tokens - self.need_prefill_count = 0 def prefill(self, reqs: List[Tuple]): self._init_reqs(reqs, init_req_obj=False) - self.need_prefill_count += 1 + self.chunked_prefill_state.need_prefill_count += 1 return def decode(self): @@ -38,38 +27,15 @@ def decode(self): # 先 decode if decode_reqs: - model_input, run_reqs = prepare_decode_inputs(decode_reqs) - model_output = self.model.forward(model_input) - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True + ContinuesBatchBackend.normal_decode( + self, decode_reqs=decode_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs ) - next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False - ) - del model_output # 再 prefill - if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): - if prefill_reqs: - self.need_prefill_count -= 1 - model_input, run_reqs = prepare_prefill_inputs( - prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal - ) - model_output = self.model.forward(model_input) - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True - ) - next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False - ) - del model_output + if self.chunked_prefill_state.need_prefill(prefill_reqs=prefill_reqs, decode_reqs=decode_reqs): + ContinuesBatchBackend.normal_prefill_reqs( + self, prefill_reqs=prefill_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs + ) self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - self.forward_step += 1 return diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py index b083a7263..4eb7b55c7 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py @@ -1,14 +1,9 @@ import os -import shutil import torch from .impl import ChunkedPrefillBackend -from typing import List, Tuple +from typing import List from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.server.router.model_infer.mode_backend.pre import ( - prepare_prefill_inputs, - prepare_decode_inputs, -) -from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -41,46 +36,25 @@ def decode(self): # 先 decode if decode_reqs: - model_input, run_reqs = prepare_decode_inputs(decode_reqs) - model_output = self.model.forward(model_input) - logits = model_output.logits - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True + ContinuesBatchBackend.normal_decode( + self, + decode_reqs=decode_reqs, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, + mask_func=self._mask_first_gen_token_logits, ) - self._mask_first_gen_token_logits(run_reqs, logits) - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False - ) - del model_output - del logits # 再 prefill - if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): - if prefill_reqs: - self.need_prefill_count -= 1 - model_input, run_reqs = prepare_prefill_inputs( - prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal - ) - model_output = self.model.forward(model_input) - logits = model_output.logits - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True - ) - self._mask_first_gen_token_logits(run_reqs, logits) - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False - ) - del model_output - del logits + if self.chunked_prefill_state.need_prefill(prefill_reqs=prefill_reqs, decode_reqs=decode_reqs): + ContinuesBatchBackend.normal_prefill_reqs( + self, + prefill_reqs=prefill_reqs, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, + mask_func=self._mask_first_gen_token_logits, + ) self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - self.forward_step += 1 return def _mask_first_gen_token_logits(self, run_reqs: List[InferReq], logits: torch.Tensor): diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py index 40401b895..541f2384e 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py @@ -6,11 +6,7 @@ from .impl import ChunkedPrefillBackend from lightllm.server.core.objs import FinishStatus from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.server.router.model_infer.mode_backend.pre import ( - prepare_prefill_inputs, - prepare_decode_inputs, -) -from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend from lightllm.server.tokenizer import get_tokenizer from typing import List, Tuple from lightllm.utils.log_utils import init_logger @@ -66,71 +62,47 @@ def decode(self): # 先 decode if decode_reqs: - model_input, run_reqs = prepare_decode_inputs(decode_reqs) - model_output = self.model.forward(model_input) - logits = model_output.logits - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True - ) - - self._init_guide_infos(run_reqs) - all_has_no_constraint = all([not e.sampling_param.has_constraint_setting() for e in run_reqs]) - if not all_has_no_constraint: - mask = torch.ones_like(logits, dtype=torch.bool) - for i, run_obj in enumerate(run_reqs): - self._mask_req_out_token(i, run_obj, mask) - logits[mask] = -1000000.0 - - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, - next_token_ids, - next_token_logprobs, - is_chuncked_mode=False, - do_filter_finished_reqs=False, + ContinuesBatchBackend.normal_decode( + self, + decode_reqs=decode_reqs, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, + mask_func=self._decode_mask_callback, extra_post_req_handle_func=self._update_state_fsm, ) - del model_output - del logits # 再 prefill - if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): - if prefill_reqs: - self.need_prefill_count -= 1 - model_input, run_reqs = prepare_prefill_inputs( - prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal - ) - model_output = self.model.forward(model_input) - logits = model_output.logits - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True - ) - # 对于不能满足前缀匹配的logic位置,将其logics设置为一个较大负值,将其概率掩盖为 0 - self._init_guide_infos(run_reqs) - mask = torch.ones_like(logits, dtype=torch.bool) - for i, run_obj in enumerate(run_reqs): - self._mask_req_out_token(i, run_obj, mask) - - logits[mask] = -1000000.0 - - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, - next_token_ids, - next_token_logprobs, - is_chuncked_mode=True, - do_filter_finished_reqs=False, - extra_post_req_handle_func=self._update_state_fsm, - ) - del model_output - del logits + if self.chunked_prefill_state.need_prefill(prefill_reqs=prefill_reqs, decode_reqs=decode_reqs): + ContinuesBatchBackend.normal_prefill_reqs( + self, + prefill_reqs=prefill_reqs, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, + 
mask_func=self._prefill_mask_callback, + extra_post_req_handle_func=self._update_state_fsm, + ) self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - self.forward_step += 1 + return + + def _decode_mask_callback(self, run_reqs: List[InferReq], logits: torch.Tensor): + self._init_guide_infos(run_reqs) + all_has_no_constraint = all([not e.sampling_param.has_constraint_setting() for e in run_reqs]) + if not all_has_no_constraint: + mask = torch.ones_like(logits, dtype=torch.bool) + for i, run_obj in enumerate(run_reqs): + self._mask_req_out_token(i, run_obj, mask) + logits[mask] = -1000000.0 + return + + def _prefill_mask_callback(self, run_reqs: List[InferReq], logits: torch.Tensor): + # 对于不能满足前缀匹配的logic位置,将其logics设置为一个较大负值,将其概率掩盖为 0 + self._init_guide_infos(run_reqs) + mask = torch.ones_like(logits, dtype=torch.bool) + for i, run_obj in enumerate(run_reqs): + self._mask_req_out_token(i, run_obj, mask) + + logits[mask] = -1000000.0 return def _update_state_fsm(self, req_obj: InferReq, next_token_id, next_token_logprob): diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py index 701333116..0770bc485 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py @@ -1,12 +1,8 @@ import torch from .impl import ChunkedPrefillBackend -from typing import List, Tuple +from typing import List +from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.server.router.model_infer.mode_backend.pre import ( - prepare_prefill_inputs, - prepare_decode_inputs, -) -from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.server.tokenizer import get_tokenizer from lightllm.utils.log_utils import init_logger @@ -50,77 +46,48 @@ def decode(self): # 先 decode if decode_reqs: - model_input, run_reqs = prepare_decode_inputs(decode_reqs) - model_output = self.model.forward(model_input) - logits = model_output.logits - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True - ) - - self._init_prefix_infos(run_reqs=run_reqs) - - all_no_prefix = all([len(e.prefix_str) == 0 for e in run_reqs]) - if not all_no_prefix: - mask = torch.ones_like(logits, dtype=torch.bool) - for i, run_obj in enumerate(run_reqs): - self._mask_decode_not_prefix_token(i, run_obj, mask) - - logits[mask] = -1000000.0 - - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, - next_token_ids, - next_token_logprobs, - is_chuncked_mode=False, - do_filter_finished_reqs=False, + ContinuesBatchBackend.normal_decode( + self, + decode_reqs=decode_reqs, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, + mask_func=self._decode_mask_callback, extra_post_req_handle_func=self._update_tokenhealing_req_prefix_str, ) - del model_output - del logits # 再 prefill - if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): - if prefill_reqs: - 
self.need_prefill_count -= 1 - model_input, run_reqs = prepare_prefill_inputs( - prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal - ) - model_output = self.model.forward(model_input) - logits = model_output.logits - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True - ) - - # 对于不能满足前缀匹配的logic位置,将其logics设置为一个较大负值,将其概率掩盖为 0 - self._init_prefix_infos(run_reqs=run_reqs) - mask = torch.ones_like(logits, dtype=torch.bool) - for i, run_obj in enumerate(run_reqs): - self._mask_not_prefix_token(i, run_obj, mask) - logits[mask] = -1000000.0 - - # 有prefix - self._topk_repair(run_reqs) - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - self._topk_recover(run_reqs) - - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, - next_token_ids, - next_token_logprobs, - is_chuncked_mode=True, - do_filter_finished_reqs=False, - extra_post_req_handle_func=self._update_tokenhealing_req_prefix_str, - ) - del model_output - del logits + if self.chunked_prefill_state.need_prefill(prefill_reqs=prefill_reqs, decode_reqs=decode_reqs): + ContinuesBatchBackend.normal_prefill_reqs( + self, + prefill_reqs=prefill_reqs, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, + mask_func=self._prefill_mask_callback, + extra_post_req_handle_func=self._update_tokenhealing_req_prefix_str, + ) self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - self.forward_step += 1 + return + + def _decode_mask_callback(self, run_reqs: List[InferReq], logits: torch.Tensor): + self._init_prefix_infos(run_reqs=run_reqs) + + all_no_prefix = all([len(e.prefix_str) == 0 for e in run_reqs]) + if not all_no_prefix: + mask = torch.ones_like(logits, dtype=torch.bool) + for i, run_obj in enumerate(run_reqs): + self._mask_decode_not_prefix_token(i, run_obj, mask) + + logits[mask] = -1000000.0 + return + + def _prefill_mask_callback(self, run_reqs: List[InferReq], logits: torch.Tensor): + # 对于不能满足前缀匹配的logic位置,将其logics设置为一个较大负值,将其概率掩盖为 0 + self._init_prefix_infos(run_reqs=run_reqs) + mask = torch.ones_like(logits, dtype=torch.bool) + for i, run_obj in enumerate(run_reqs): + self._mask_not_prefix_token(i, run_obj, mask) + logits[mask] = -1000000.0 return def _update_tokenhealing_req_prefix_str(self, req_obj: InferReq, next_token_id, next_token_logprob): @@ -174,20 +141,6 @@ def _mask_decode_not_prefix_token(self, i, run_obj: InferReq, mask): mask[i, :] = False return - def _topk_repair(self, run_reqs: list[InferReq]): - for req_obj in run_reqs: - if len(req_obj.prefix_str) != 0: - req_obj.origin_topk = req_obj.sampling_param.shm_param.top_k - req_obj.sampling_param.shm_param.top_k = 1 - else: - req_obj.origin_topk = req_obj.sampling_param.shm_param.top_k - return - - def _topk_recover(self, run_reqs: list[InferReq]): - for req_obj in run_reqs: - req_obj.sampling_param.shm_param.top_k = req_obj.origin_topk - return - def _init_prefix_infos(self, run_reqs: List[InferReq]): for i, run_obj in enumerate(run_reqs): if not hasattr(run_obj, "prefix_str"): diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py index 72172431f..8039c452a 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py +++ 
b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py @@ -1,17 +1,11 @@ -import copy import functools import torch -from typing import List, Tuple - +from typing import List from .impl import ChunkedPrefillBackend -from lightllm.server.router.model_infer.mode_backend.pre import ( - prepare_prefill_inputs, - prepare_decode_inputs, -) from lightllm.utils.infer_utils import calculate_time -from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.server.core.objs import FinishStatus from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend from lightllm.server.tokenizer import get_tokenizer from lightllm.utils.log_utils import init_logger @@ -33,10 +27,6 @@ def init_custom(self): self.xgrammar_compiler = xgr.GrammarCompiler(self.tokenizer_info, max_threads=8) self.xgrammar_token_bitmask = xgr.allocate_token_bitmask(1, self.tokenizer_info.vocab_size) - eos_token_ids = [] - eos_token_ids.append(self.tokenizer.eos_token_id) - eos_token_ids.extend(self.args.eos_id) - @functools.lru_cache(maxsize=200) def get_cached_grammar(type: str, grammar: str): logger.info(f"grammar cache miss for {type}: '{grammar}'") @@ -66,75 +56,50 @@ def decode(self): # 先 decode if decode_reqs: - model_input, run_reqs = prepare_decode_inputs(decode_reqs) - model_output = self.model.forward(model_input) - logits = model_output.logits - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True - ) - - self._init_req_xgrammer_matcher_infos(run_reqs=run_reqs) - all_has_no_constraint = all([not e.sampling_param.has_constraint_setting() for e in run_reqs]) - if not all_has_no_constraint: - for i, run_obj in enumerate(run_reqs): - self._mask_req_out_token(i, run_obj, logits[i]) - - logits[logits == float("-inf")] = -1000000.0 - # mask out the padding token logits - logits[:, self.tokenizer_info.vocab_size :] = -1000000.0 - - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, - next_token_ids, - next_token_logprobs, - is_chuncked_mode=False, - do_filter_finished_reqs=False, + ContinuesBatchBackend.normal_decode( + self, + decode_reqs=decode_reqs, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, + mask_func=self._decode_mask_callback, extra_post_req_handle_func=self._update_xgrammer_fsm, ) - del model_output - del logits # 再 prefill - if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): - if prefill_reqs: - self.need_prefill_count -= 1 - model_input, run_reqs = prepare_prefill_inputs( - prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal - ) - model_output = self.model.forward(model_input) - logits = model_output.logits - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True - ) - - self._init_req_xgrammer_matcher_infos(run_reqs=run_reqs) - for i, run_obj in enumerate(run_reqs): - self._mask_req_out_token(i, run_obj, logits[i]) - - # fix the logics with -inf to a large negative value - logits[logits == float("-inf")] = -1000000.0 - # mask out the padding token logits - logits[:, self.tokenizer_info.vocab_size :] = -1000000.0 - - 
next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, - next_token_ids, - next_token_logprobs, - is_chuncked_mode=True, - do_filter_finished_reqs=False, - extra_post_req_handle_func=self._update_xgrammer_fsm, - ) - del model_output - del logits + if self.chunked_prefill_state.need_prefill(prefill_reqs=prefill_reqs, decode_reqs=decode_reqs): + ContinuesBatchBackend.normal_prefill_reqs( + self, + prefill_reqs=prefill_reqs, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, + mask_func=self._prefill_mask_callback, + extra_post_req_handle_func=self._update_xgrammer_fsm, + ) self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - self.forward_step += 1 + return + + def _decode_mask_callback(self, run_reqs: List[InferReq], logits: torch.Tensor): + self._init_req_xgrammer_matcher_infos(run_reqs=run_reqs) + all_has_no_constraint = all([not e.sampling_param.has_constraint_setting() for e in run_reqs]) + if not all_has_no_constraint: + for i, run_obj in enumerate(run_reqs): + self._mask_req_out_token(i, run_obj, logits[i]) + + logits[logits == float("-inf")] = -1000000.0 + # mask out the padding token logits + logits[:, self.tokenizer_info.vocab_size :] = -1000000.0 + return + + def _prefill_mask_callback(self, run_reqs: List[InferReq], logits: torch.Tensor): + self._init_req_xgrammer_matcher_infos(run_reqs=run_reqs) + for i, run_obj in enumerate(run_reqs): + self._mask_req_out_token(i, run_obj, logits[i]) + + # fix the logics with -inf to a large negative value + logits[logits == float("-inf")] = -1000000.0 + # mask out the padding token logits + logits[:, self.tokenizer_info.vocab_size :] = -1000000.0 return def _update_xgrammer_fsm(self, req_obj: InferReq, next_token_id, next_token_logprob): @@ -149,7 +114,7 @@ def _update_xgrammer_fsm(self, req_obj: InferReq, next_token_id, next_token_logp req_obj.finish_status.set_status(FinishStatus.FINISHED_STOP) return - def _mask_req_out_token(self, i, run_obj: InferReq, logits): + def _mask_req_out_token(self, i, run_obj: InferReq, logits: torch.Tensor): import xgrammar as xgr if run_obj.get_chuncked_input_token_len() == run_obj.get_cur_total_len(): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py index b0eb2b58f..184fc7a1c 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py @@ -1,14 +1,13 @@ import torch -from typing import List, Tuple +from typing import List, Tuple, Callable, Optional from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend -from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end -from lightllm.utils.log_utils import init_logger -from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq from lightllm.server.router.model_infer.mode_backend.pre import ( prepare_prefill_inputs, prepare_decode_inputs, ) from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -30,40 +29,76 @@ def decode(self): 
g_infer_context.filter_reqs(aborted_reqs) if prefill_reqs: - model_input, run_reqs = prepare_prefill_inputs( - prefill_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal + self.normal_prefill_reqs( + prefill_reqs=prefill_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs ) - model_output = self.model.forward(model_input) - logits = model_output.logits - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True - ) + if decode_reqs: + self.normal_decode(decode_reqs=decode_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs) - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return - self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False - ) + def normal_prefill_reqs( + self, + prefill_reqs: List[InferReq], + uninit_reqs: List[InferReq], + ok_finished_reqs: List[InferReq], + mask_func: Optional[Callable[[List[InferReq], torch.Tensor], None]] = None, + extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, + ): + model_input, run_reqs = prepare_prefill_inputs( + prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal + ) + model_output = self.model.forward(model_input) + logits = model_output.logits - if decode_reqs: - model_input, run_reqs = prepare_decode_inputs(decode_reqs) - model_output = self.model.forward(model_input) - logits = model_output.logits + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True - ) + if mask_func is not None: + mask_func(run_reqs, logits) - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False - ) + self._post_handle( + run_reqs, + next_token_ids, + next_token_logprobs, + is_chuncked_mode=not self.disable_chunked_prefill, + do_filter_finished_reqs=False, + extra_post_req_handle_func=extra_post_req_handle_func, + ) + return + + def normal_decode( + self, + decode_reqs: List[InferReq], + uninit_reqs: List[InferReq], + ok_finished_reqs: List[InferReq], + mask_func: Optional[Callable[[List[InferReq], torch.Tensor], None]] = None, + extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, + ): + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + + if mask_func is not None: + mask_func(run_reqs, logits) + + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = 
next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + + self._post_handle( + run_reqs, + next_token_ids, + next_token_logprobs, + is_chuncked_mode=False, + do_filter_finished_reqs=False, + extra_post_req_handle_func=extra_post_req_handle_func, + ) return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py index 81d3de6a2..6fc4bb206 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py @@ -4,18 +4,13 @@ import torch.distributed as dist import threading from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend from typing import List, Tuple -from lightllm.utils.infer_utils import set_random_seed -from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, InferSamplingParams +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq from lightllm.server.core.objs import FinishStatus -from lightllm.server.pd_io_struct import UpKVStatus from lightllm.utils.log_utils import init_logger -from lightllm.server.router.model_infer.mode_backend.pre import prepare_decode_inputs -from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample -from .up_status import UpStatusManager from rpyc.utils.server import ThreadedServer -from lightllm.common.basemodel.infer_lock import g_infer_state_lock, g_router_lock +from lightllm.common.basemodel.infer_lock import g_router_lock from .decode_task_cache import g_success_kv_move_task_cache, KVMoveTask from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.envs_utils import get_unique_server_name @@ -68,21 +63,8 @@ def decode(self): self._filter_reqs(aborted_reqs) if decode_reqs: - - model_input, run_reqs = prepare_decode_inputs(decode_reqs) - model_output = self.model.forward(model_input) - logits = model_output.logits - - self._overlap_req_init_and_filter( - uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True - ) - - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - - self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False + ContinuesBatchBackend.normal_decode( + self, decode_reqs=decode_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs ) self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py index dd6afb034..43b57abf3 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py +++ 
b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py @@ -18,11 +18,6 @@ def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap return - def init_custom(self): - super().init_custom() - self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) - return - def prefill(self, reqs: List[Tuple]): self._init_reqs(reqs, init_req_obj=False) return @@ -35,9 +30,7 @@ def decode(self): self._filter_reqs(aborted_reqs) - self.reduce_tensor.fill_(len(decode_reqs)) - dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) - max_decode_num = self.reduce_tensor.item() + max_decode_num = self._dp_all_reduce_decode_req_num(decode_reqs=decode_reqs) if max_decode_num != 0: if not self.enable_decode_microbatch_overlap: DPChunkedPrefillBackend.normal_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_mtp_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_mtp_for_dp.py index 61b7e07f6..44ab26321 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_mtp_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_mtp_for_dp.py @@ -26,9 +26,7 @@ def decode(self): self._filter_reqs(aborted_reqs) - self.reduce_tensor.fill_(len(decode_reqs)) - dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) - max_decode_num = self.reduce_tensor.item() + max_decode_num = self._dp_all_reduce_decode_req_num(decode_reqs=decode_reqs) if max_decode_num != 0: if not self.enable_decode_microbatch_overlap: DPChunkedPrefillWithMTPBackend.normal_mtp_decode( diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py index a8084f2a1..6f1693479 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py @@ -6,15 +6,11 @@ import torch.distributed as dist from typing import List, Tuple from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend -from lightllm.utils.infer_utils import set_random_seed -from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end -from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams, g_infer_context -from lightllm.server.core.objs import FinishStatus +from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.server.pd_io_struct import KVMoveTask, DecodeNodeInfo from lightllm.utils.log_utils import init_logger -from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_prefill_inputs -from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.infer_lock import g_router_lock, g_infer_state_lock +from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend from 
rpyc.utils.server import ThreadedServer from .prefill_task_cache import g_kv_move_task_cache from lightllm.utils.device_utils import kv_trans_use_p2p @@ -70,20 +66,11 @@ def decode(self): if ok_finished_reqs: self.prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(ok_finished_reqs) self._filter_reqs(ok_finished_reqs) + ok_finished_reqs.clear() if prefill_reqs: - model_input, run_reqs = prepare_prefill_inputs( - prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal - ) - - model_output = self.model.forward(model_input) - logits = model_output.logits - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - - self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False + ContinuesBatchBackend.normal_prefill_reqs( + self, prefill_reqs=prefill_reqs, uninit_reqs=uinit_reqs, ok_finished_reqs=ok_finished_reqs ) return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp_chuncked.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp_chuncked.py index 84777711d..df324b317 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp_chuncked.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp_chuncked.py @@ -1,10 +1,8 @@ -import torch import torch.multiprocessing as mp -import torch.distributed as dist from typing import List, Tuple -from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams, g_infer_context +from lightllm.server.router.model_infer.infer_batch import g_infer_context from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args from .prefill_impl import ChunckedPrefillForPrefillNode from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend @@ -16,11 +14,6 @@ def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: super().__init__(info_queue=info_queue, mem_queue=mem_queue) self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap - def init_custom(self): - super().init_custom() - self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) - return - def prefill(self, reqs: List[Tuple]): self._init_reqs(reqs) return @@ -41,11 +34,8 @@ def decode(self): ok_finished_reqs.clear() # 进行 chuncked prefill - current_dp_prefill_num = len(prefill_reqs) - self.reduce_tensor.fill_(current_dp_prefill_num) - dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) - max_prefill_num = self.reduce_tensor.item() - if max_prefill_num != 0: + dp_prefill_req_nums, max_prefill_num = self._dp_all_gather_prefill_req_num(prefill_reqs=prefill_reqs) + if self.chunked_prefill_state.dp_need_prefill(prefill_reqs, decode_reqs, dp_prefill_req_nums, max_prefill_num): if not self.enable_prefill_microbatch_overlap: DPChunkedPrefillBackend.normal_prefill_reqs( self, prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs diff --git 
a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_mtp_for_dp_chuncked.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_mtp_for_dp_chuncked.py index 3f53070f6..2cdb385c9 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_mtp_for_dp_chuncked.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_mtp_for_dp_chuncked.py @@ -1,5 +1,4 @@ import torch.multiprocessing as mp -import torch.distributed as dist from lightllm.server.router.model_infer.infer_batch import g_infer_context from lightllm.utils.log_utils import init_logger from .prefill_impl_for_dp_chuncked import DPChunkedForPrefillNode @@ -34,12 +33,8 @@ def decode(self): ok_finished_reqs.clear() # 进行 chuncked prefill - current_dp_prefill_num = len(prefill_reqs) - self.reduce_tensor.fill_(current_dp_prefill_num) - dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) - max_prefill_num = self.reduce_tensor.item() - - if max_prefill_num != 0: + dp_prefill_req_nums, max_prefill_num = self._dp_all_gather_prefill_req_num(prefill_reqs=prefill_reqs) + if self.chunked_prefill_state.dp_need_prefill(prefill_reqs, decode_reqs, dp_prefill_req_nums, max_prefill_num): if not self.enable_prefill_microbatch_overlap: DPChunkedPrefillWithMTPBackend.normal_mtp_prefill_reqs( self, diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index 376ed1501..e575cab14 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -1,14 +1,9 @@ import torch -import torch.distributed as dist -import numpy as np from typing import List, Tuple from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend from lightllm.common.basemodel.batch_objs import ModelOutput -from lightllm.utils.infer_utils import set_random_seed -from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, InferSamplingParams -from lightllm.server.core.objs import FinishStatus -from lightllm.utils.log_utils import init_logger + +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.router.model_infer.mode_backend.pre import padded_prepare_prefill_inputs @@ -28,10 +23,6 @@ def __init__(self) -> None: self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap pass - def init_custom(self): - self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) - return - def prefill(self, reqs: List[Tuple]): self._init_reqs(reqs, init_req_obj=False) return @@ -44,19 +35,14 @@ def decode(self): if aborted_reqs: g_infer_context.filter_reqs(aborted_reqs) - current_dp_prefill_num = len(prefill_reqs) - self.reduce_tensor.fill_(current_dp_prefill_num) - dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) - max_prefill_num = self.reduce_tensor.item() - if max_prefill_num != 0: + dp_prefill_req_nums, max_prefill_num = 
self._dp_all_gather_prefill_req_num(prefill_reqs=prefill_reqs) + if self.chunked_prefill_state.dp_need_prefill(prefill_reqs, decode_reqs, dp_prefill_req_nums, max_prefill_num): if not self.enable_prefill_microbatch_overlap: self.normal_prefill_reqs(prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs) else: self.overlap_prefill_reqs(prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs) - self.reduce_tensor.fill_(len(decode_reqs)) - dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) - max_decode_num = self.reduce_tensor.item() + max_decode_num = self._dp_all_reduce_decode_req_num(decode_reqs=decode_reqs) if max_decode_num != 0: if not self.enable_decode_microbatch_overlap: self.normal_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl_mtp.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl_mtp.py index 0b8294f46..a67ba5b4e 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl_mtp.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl_mtp.py @@ -1,5 +1,4 @@ import torch -import torch.distributed as dist import numpy as np from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq @@ -24,10 +23,6 @@ def __init__(self) -> None: self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap pass - def init_custom(self): - self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) - return - def prefill(self, reqs: List[Tuple]): self._init_reqs(reqs, init_req_obj=False) return @@ -40,19 +35,15 @@ def decode(self): if aborted_reqs: g_infer_context.filter_reqs(aborted_reqs) - current_dp_prefill_num = len(prefill_reqs) - self.reduce_tensor.fill_(current_dp_prefill_num) - dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) - max_prefill_num = self.reduce_tensor.item() - if max_prefill_num != 0: + dp_prefill_req_nums, max_prefill_num = self._dp_all_gather_prefill_req_num(prefill_reqs=prefill_reqs) + + if self.chunked_prefill_state.dp_need_prefill(prefill_reqs, decode_reqs, dp_prefill_req_nums, max_prefill_num): if not self.enable_prefill_microbatch_overlap: self.normal_mtp_prefill_reqs(prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs) else: self.overlap_mtp_prefill_reqs(prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs) - self.reduce_tensor.fill_(len(decode_reqs)) - dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) - max_decode_num = self.reduce_tensor.item() + max_decode_num = self._dp_all_reduce_decode_req_num(decode_reqs=decode_reqs) if max_decode_num != 0: if not self.enable_decode_microbatch_overlap: self.normal_mtp_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
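The sketches below illustrate the behavioural changes carried by this patch; they are simplified stand-ins written for this review, not code from the diff. First, the `HttpServerManager.generate` change replaces the fixed 0.1 s poll for a free shared-memory request slot with a sleep that grows by 10% per retry and is capped at 1 s. A minimal asyncio sketch of the same pacing, where `try_alloc` is a hypothetical stand-in for `shm_req_manager.async_alloc_req_index()`:

```python
import asyncio
import random

async def try_alloc():
    # Hypothetical stand-in for ShmReqManager.async_alloc_req_index():
    # returns a slot index, or None while no slot is free.
    return random.choice([None, None, 42])

async def alloc_with_backoff():
    index = await try_alloc()
    sleep_time = 0.1
    while index is None:
        # Same pacing as the patch: start at 0.1 s, grow by 10% per retry,
        # never sleep longer than 1 second between attempts.
        await asyncio.sleep(sleep_time)
        sleep_time = min(1.0, sleep_time * 1.1)
        index = await try_alloc()
    return index

if __name__ == "__main__":
    print(asyncio.run(alloc_with_backoff()))
```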
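The `handle_loop` rewrite keeps the old one-token-per-pass behaviour, but drains up to `LIGHTLLM_OUT_TOKEN_QUEUE_SIZE` entries whenever the per-request circular queue reports full, so a fast producer is not throttled by the detokenization loop. A small sketch of that drain policy, using `collections.deque` as a stand-in for the shared-memory `out_tokens_queue` (an assumption; the real queue is a ctypes shared-memory structure):

```python
from collections import deque

QUEUE_SIZE = 8  # mirrors LIGHTLLM_OUT_TOKEN_QUEUE_SIZE after the patch

def drain_step(queue: deque) -> list:
    """One consumer pass: take a single token normally, but drain the whole
    queue when it is full so the producer is never left blocked."""
    read_token_count = QUEUE_SIZE if len(queue) >= QUEUE_SIZE else 1
    out = []
    for _ in range(read_token_count):
        if not queue:
            break
        out.append(queue.popleft())
    return out

if __name__ == "__main__":
    q = deque(["tok%d" % i for i in range(8)], maxlen=QUEUE_SIZE)
    print(drain_step(q))   # full queue -> drains all 8 entries
    q.append("tok8")
    print(drain_step(q))   # not full -> takes just one
```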
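`ChunkedPrefillState.need_prefill` paces prefill against decode: with decode work queued, a waiting chunked prefill is only issued every `prefill_wait_step` (i.e. `router_max_wait_tokens`) iterations, unless decode is idle or a `prefill()` call raised `need_prefill_count`. A toy re-implementation of that pacing (`PrefillPacer` is an illustrative name, not part of the patch):

```python
class PrefillPacer:
    """Toy version of the ChunkedPrefillState.need_prefill pacing logic."""

    def __init__(self, prefill_wait_step: int):
        self.prefill_wait_step = prefill_wait_step
        self.current_wait_step = 0
        self.need_prefill_count = 0

    def need_prefill(self, num_prefill: int, num_decode: int) -> bool:
        no_decode = num_decode == 0
        step_ok = self.current_wait_step >= self.prefill_wait_step
        forced = self.need_prefill_count > 0
        if no_decode or step_ok or forced:
            if forced:
                self.need_prefill_count -= 1
            self.current_wait_step = 0
            return num_prefill > 0
        if num_prefill > 0:
            # a prefill is waiting: count this decode-only step toward the budget
            self.current_wait_step += 1
        return False

if __name__ == "__main__":
    pacer = PrefillPacer(prefill_wait_step=3)
    # one waiting prefill req and a busy decode batch: prefill fires every 4th step
    print([pacer.need_prefill(num_prefill=1, num_decode=8) for _ in range(8)])
    # -> [False, False, False, True, False, False, False, True]
```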
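The new `ModeBackend._dp_all_gather_prefill_req_num` replaces the per-backend `all_reduce(MAX)` with an all-gather of `len(prefill_reqs)` into a pre-allocated CUDA int32 tensor, returning both the per-rank array and its maximum. A mock of the return contract with the collective faked out (no process group is initialized here; the real code uses `dist.all_gather_into_tensor`):

```python
from typing import List, Tuple
import numpy as np

def mock_dp_all_gather_prefill_req_num(per_rank_prefill_reqs: List[list]) -> Tuple[np.ndarray, int]:
    """Mock of _dp_all_gather_prefill_req_num: each rank contributes len(prefill_reqs);
    here the collective is faked by simply lining the counts up in one array."""
    dp_prefill_req_nums = np.array([len(reqs) for reqs in per_rank_prefill_reqs], dtype=np.int32)
    max_prefill_num = int(np.max(dp_prefill_req_nums))
    return dp_prefill_req_nums, max_prefill_num

if __name__ == "__main__":
    # 8 DP ranks, only two of them currently hold prefill requests
    ranks = [["r0"], ["r1"], [], [], [], [], [], []]
    nums, max_num = mock_dp_all_gather_prefill_req_num(ranks)
    print(nums, max_num)  # [1 1 0 0 0 0 0 0] 1
    # nums feeds ChunkedPrefillState.dp_need_prefill; max_num replaces the old
    # all_reduce(MAX) result that the dp backends previously computed themselves.
```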
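`dp_need_prefill` then gates the DP-wide chunked prefill on that per-rank array: it waits while most ranks would only pad a fake request, and forces the prefill once `dp_prefill_wait_step` is exhausted. A simplified standalone re-implementation of the ratio/step gate (the 0.6 utilization threshold comes from the new file; the wait-step reset below is simplified relative to the patch):

```python
import numpy as np

class DpPrefillGate:
    """Simplified version of ChunkedPrefillState.dp_need_prefill."""

    def __init__(self, dp_prefill_wait_step: int):
        self.dp_prefill_wait_step = dp_prefill_wait_step
        self.dp_current_wait_step = 0

    def dp_need_prefill(self, dp_prefill_req_nums: np.ndarray) -> bool:
        dp_max_prefill_num = int(dp_prefill_req_nums.max())
        # fraction of DP ranks that would do useful prefill work (the rest pad a fake req)
        use_ratio = np.count_nonzero(dp_prefill_req_nums) / dp_prefill_req_nums.shape[0]
        step_ok = self.dp_current_wait_step >= self.dp_prefill_wait_step
        if dp_max_prefill_num > 0 and (use_ratio > 0.6 or step_ok):
            self.dp_current_wait_step = 0
            return True
        if dp_max_prefill_num > 0:
            self.dp_current_wait_step += 1   # keep waiting for more ranks to have work
        else:
            self.dp_current_wait_step = 0    # nothing to prefill anywhere
        return False

if __name__ == "__main__":
    gate = DpPrefillGate(dp_prefill_wait_step=4)
    skewed = np.array([1, 1, 0, 0, 0, 0, 0, 0])    # 2/8 ranks busy -> wait
    print([gate.dp_need_prefill(skewed) for _ in range(6)])  # forced on the 5th call
    balanced = np.array([2, 1, 3, 1, 2, 1, 1, 2])  # 8/8 ranks busy -> prefill at once
    print(gate.dp_need_prefill(balanced))
```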
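Finally, the refactor collapses the copy-pasted decode/prefill bodies of the constraint backends into `ContinuesBatchBackend.normal_decode` / `normal_prefill_reqs`, which take an optional `mask_func` applied to the logits before sampling and an `extra_post_req_handle_func` applied per request afterwards. A minimal torch sketch of that hook shape, with generic names standing in for the real `InferReq` objects and `sample()` call (greedy argmax is used here purely for illustration):

```python
from typing import Callable, List, Optional
import torch

def shared_decode_step(
    logits: torch.Tensor,
    run_reqs: List[dict],
    mask_func: Optional[Callable[[List[dict], torch.Tensor], None]] = None,
    extra_post_req_handle_func: Optional[Callable[[dict, int, float], None]] = None,
):
    """Shared sampling step: backend-specific behaviour is injected via hooks."""
    if mask_func is not None:
        # e.g. grammar / outlines / token-healing backends push banned tokens to a
        # large negative value here, just as their callbacks do on the logits tensor
        mask_func(run_reqs, logits)
    probs = torch.softmax(logits, dim=-1)
    next_ids = torch.argmax(probs, dim=-1)  # greedy stand-in for sample()
    logprobs = torch.log(probs.gather(1, next_ids.unsqueeze(1))).squeeze(1)
    if extra_post_req_handle_func is not None:
        for req, tid, lp in zip(run_reqs, next_ids.tolist(), logprobs.tolist()):
            extra_post_req_handle_func(req, tid, lp)
    return next_ids

def ban_token_zero(run_reqs, logits):
    # toy mask callback: never sample token id 0
    logits[:, 0] = -1000000.0

if __name__ == "__main__":
    torch.manual_seed(0)
    fake_logits = torch.randn(2, 5)
    fake_reqs = [{"id": 0}, {"id": 1}]
    print(shared_decode_step(fake_logits, fake_reqs, mask_func=ban_token_zero,
                             extra_post_req_handle_func=lambda r, t, lp: r.update(last=t)))
    print(fake_reqs)
```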