diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 678ac8b8f..c481f1ff6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -11,4 +11,4 @@ repos:
   hooks:
   - id: flake8
     additional_dependencies: [flake8-typing-imports==1.9.0]
-    args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606']
\ No newline at end of file
+    args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231']
\ No newline at end of file
diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py
index 15d04b208..14a987f49 100644
--- a/lightllm/server/router/batch.py
+++ b/lightllm/server/router/batch.py
@@ -67,12 +67,6 @@ def is_clear(self):
         return len(self.reqs) == 0

     def merge(self, mini_batch: "Batch"):
-        for _req in mini_batch.reqs:
-            self.reqs.append(_req)
-        self.id_to_reqs = {req.request_id: req for req in self.reqs}
-        return
-
-    def dp_merge(self, mini_batch: "Batch"):
         if mini_batch is None:
             return

@@ -81,6 +75,18 @@ def dp_merge(self, mini_batch: "Batch"):
         self.id_to_reqs = {req.request_id: req for req in self.reqs}
         return

+    @staticmethod
+    def merge_two_batch(batch1: "Batch", batch2: "Batch") -> "Batch":
+        if batch1 is None and batch2 is None:
+            return None
+
+        not_none_batch = batch1 if batch1 is not None else batch2
+
+        merge_batch = Batch(-1, [], not_none_batch.dp_size_in_node)
+        merge_batch.merge(batch1)
+        merge_batch.merge(batch2)
+        return merge_batch
+
     def __repr__(self):
         return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, "
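Note on the Batch.merge_two_batch change above: merge() now tolerates a None argument (absorbing the old dp_merge), and merge_two_batch treats None as an empty batch, so the router can combine the running batch with a batch that was scheduled but not yet prefilled. The sketch below is illustrative only; MiniBatch is a simplified stand-in for lightllm's Batch and is not part of the diff.

from typing import List, Optional


class MiniBatch:
    def __init__(self, batch_id: int, reqs: List[str], dp_size_in_node: int = 1):
        self.batch_id = batch_id
        self.reqs = reqs
        self.dp_size_in_node = dp_size_in_node

    def merge(self, mini_batch: Optional["MiniBatch"]) -> None:
        # merge() now tolerates None, so callers no longer need a separate dp_merge().
        if mini_batch is None:
            return
        self.reqs.extend(mini_batch.reqs)

    @staticmethod
    def merge_two_batch(b1: Optional["MiniBatch"], b2: Optional["MiniBatch"]) -> Optional["MiniBatch"]:
        # Mirrors the semantics introduced in batch.py: None + None stays None,
        # otherwise a fresh batch (id -1) absorbs whatever is not None.
        if b1 is None and b2 is None:
            return None
        not_none = b1 if b1 is not None else b2
        merged = MiniBatch(-1, [], not_none.dp_size_in_node)
        merged.merge(b1)
        merged.merge(b2)
        return merged


running = MiniBatch(0, ["req-1", "req-2"])
pending = MiniBatch(1, ["req-3"])
assert MiniBatch.merge_two_batch(None, None) is None
assert MiniBatch.merge_two_batch(running, None).reqs == ["req-1", "req-2"]
assert MiniBatch.merge_two_batch(running, pending).reqs == ["req-1", "req-2", "req-3"]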
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index b452b604c..c10847e3f 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -1,26 +1,20 @@
-import copy
 import time
-import uuid
 import uvloop
 import asyncio
 import torch
-import rpyc
 import pickle
-import threading
 import inspect

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-import concurrent.futures
 import zmq
 import zmq.asyncio
 import torch.multiprocessing as mp
 import torch.distributed as dist
 import multiprocessing
 from typing import Dict, List, Optional
-from .batch import Batch
+from .batch import Batch, Req
 from .model_infer.model_rpc import start_model_process, ModelRpcClient
 from .req_queue import build_req_queue
-from lightllm.utils.infer_utils import calculate_time
 from lightllm.server.core.objs.io_objs import GroupReqIndexes
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
@@ -79,7 +73,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
         self.eos_id = args.eos_id
         self.has_wait_tokens = 0
         self.max_wait_tokens = args.router_max_wait_tokens
-        context = zmq.asyncio.Context(2)
+        context = zmq.Context(2)
         self.recv_from_httpserver = context.socket(zmq.PULL)
         self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{router_port}")

@@ -106,13 +100,13 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
         # Mainly to prevent scheduling mistakes from causing OOM and other errors
         self.router_lock = mp.Lock()
         g_router_lock.obj = self.router_lock
-
-        # Thread pool used to overlap scheduling with inference
-        self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-        self.schedule_task = None
         return

     async def wait_to_model_ready(self):
+        # Objects used by the scheduler
+        self.schedule_new_batch: Batch = None
+        self.schedule_event = asyncio.Event()
+
         # Initialize the model
         self.model_rpc_servers = []
         # Used for exchanging task info between the kv move manager process and the inference processes.
@@ -140,8 +134,6 @@ async def wait_to_model_ready(self):
             self.model_rpc_servers.append(rpc_model)

         self.model_rpc_client = ModelRpcClient(
-            model_infer_servers=self.model_rpc_servers,
-            world_size=self.world_size,
             rpc_event=self.rpc_event,
             rpc_finished_event=self.rpc_finished_event,
         )
@@ -223,7 +215,6 @@ def add_req(self, group_req_indexes: GroupReqIndexes):
             logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s")
         self.req_queue.extend(req_group)
         self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
-        return

     async def loop_for_fwd(
@@ -285,81 +276,39 @@ async def loop_for_fwd(
             if self.running_batch is None:
                 await asyncio.sleep(0.01)  # 10ms

-    async def get_schedule_result(self, running_batch: Batch):
-        if self.schedule_task is None:
-            _start_time = time.time()
-
-            def get_new_batch():
-                if time.time() - _start_time < 0.001:
-                    time.sleep(0.003)
-
-                limit_router_queue_length = None
-                if self.is_multinode_tp:
-                    # Use all_reduce to get the minimum value
-                    limit_router_queue_length = len(self.req_queue.waiting_req_list)
-                    limit_router_queue_length_tensor = torch.tensor(
-                        limit_router_queue_length, dtype=torch.int32, device="cpu"
-                    )
-                    dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
-                    limit_router_queue_length = limit_router_queue_length_tensor.item()
-
-                new_batch = self.req_queue.generate_new_batch(running_batch, limit_router_queue_length)
-                return new_batch
-
-            self.schedule_task = asyncio.get_running_loop().run_in_executor(self.overlap_thread_pool, get_new_batch)
-            return None
-        else:
-            result = await self.schedule_task
-            self.schedule_task = None
-            return result
+    def generate_new_batch(self):
+        limit_router_queue_length = None
+        if self.is_multinode_tp:
+            # Use all_reduce to get the minimum value
+            limit_router_queue_length = len(self.req_queue.waiting_req_list)
+            limit_router_queue_length_tensor = torch.tensor(limit_router_queue_length, dtype=torch.int32, device="cpu")
+            dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            limit_router_queue_length = limit_router_queue_length_tensor.item()
+
+        # Scheduling must take into account the currently running batch and the requests that were
+        # scheduled but have not started inference yet.
+        new_batch = self.req_queue.generate_new_batch(
+            Batch.merge_two_batch(self.running_batch, self.schedule_new_batch), limit_router_queue_length
+        )
+        self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
+        return

     async def _step(self):
         """
         Event handling loop
         """
-        # Remove all reqs that have already finished
-        # When there is no running request
-        if self.running_batch is None:
-            new_batch: Batch = await self.get_schedule_result(self.running_batch)
-            if new_batch is not None:
-                self.metric_client.histogram_observe("lightllm_batch_next_size", len(new_batch.reqs))
-                for req in new_batch.reqs:
-                    self.metric_client.histogram_observe(
-                        "lightllm_request_queue_duration_bucket", time.time() - req.start_time
-                    )
-                self.stats_tool.count_prompt_tokens(new_batch)
-                self.running_batch = new_batch
-                await self._prefill_batch(self.running_batch)
-                self._filter_runing_batch()
-
-                # Aggressive scheduling control
-                if not self.args.disable_aggressive_schedule:
-                    self.has_wait_tokens = self.max_wait_tokens
-
-            elif self.is_multinode_and_multidp:
-                # In multi-node multi-dp mode, even when the current running_batch is None, decode still has to
-                # be called continuously, because dp ranks on other nodes may have running requests; the
-                # inference backend will pad some fake requests so that inference can complete normally. This is
-                # mainly for large models such as deepseekv3, whose ep parallel mode requires all nodes to cooperate.
-                await self._decode_batch(self.running_batch)
-
-            return
-
-        # A batch is running: act when consecutive decode steps reach a threshold, or a previously pre-scheduled result exists.
-        if self.has_wait_tokens >= self.max_wait_tokens or self.schedule_task is not None:
-            new_mini_batch = await self.get_schedule_result(self.running_batch)
+        # Decide whether new requests should join inference.
+        # With aggressive scheduling, any newly scheduled batch is added immediately.
+        # Otherwise, once the delayed-step count reaches the threshold, the new batch is added as well.
+        if (self.schedule_new_batch is not None) and (
+            (not self.args.disable_aggressive_schedule) or (self.has_wait_tokens >= self.max_wait_tokens)
+        ):
+            new_batch = self.schedule_new_batch
+            self.schedule_new_batch = None
+            self._add_new_batch_to_running_batch(new_batch=new_batch)
+            await self._prefill_batch(new_batch)
+            self.stats_tool.count_prompt_tokens(new_batch)
+            self._filter_reqs_from_running_batch()
             self.has_wait_tokens = 0
-            if new_mini_batch is not None:
-
-                # Aggressive scheduling control
-                if not self.args.disable_aggressive_schedule:
-                    self.has_wait_tokens = self.max_wait_tokens
-
-                self.stats_tool.count_prompt_tokens(new_mini_batch)
-                await self._prefill_batch(new_mini_batch)
-                if not new_mini_batch.is_clear():
-                    self.running_batch.merge(new_mini_batch)
-                return

         # Check if need pause some requests for decode.
         for dp_index in range(self.dp_size_in_node):
@@ -374,51 +323,46 @@ async def _step(self):

         # Decode
         self.stats_tool.count_output_tokens(self.running_batch)
-        await self._decode_batch(self.running_batch)
-        self._filter_runing_batch()
+        await self._decode_batch()
+        self._filter_reqs_from_running_batch()
         self.has_wait_tokens += 1
         return

     async def _prefill_batch(self, batch: Batch):
-        start_time = time.time()
-        self.metric_client.counter_inc("lightllm_batch_inference_count", "prefill")
+        # Add the new requests
         reqs = [r.to_router_rpc_obj() for r in batch.reqs]
         await self.model_rpc_client.prefill(reqs)
-        batch.filter_out_finished_req(self.shm_req_manager)
         self._send_detokenization_pack()
-        logger.debug(f"Prefill Batch: {batch.simple_log()} \n")
-        self.metric_client.histogram_observe(
-            "lightllm_batch_inference_duration_bucket", time.time() - start_time, "prefill"
-        )
         return

-    async def _decode_batch(self, batch: Batch):
-        start_time = time.time()
-        self.metric_client.counter_inc("lightllm_batch_inference_count", "decode")
+    async def _decode_batch(self):
+        self.schedule_event.set()
         await self.model_rpc_client.decode()
-        # When self.is_multinode_and_multidp is True, the batch passed in may be None.
-        if batch is not None:
-            batch.filter_out_finished_req(self.shm_req_manager)
-            self._send_detokenization_pack()
-        self.metric_client.histogram_observe(
-            "lightllm_batch_inference_duration_bucket", time.time() - start_time, "decode"
-        )
         return

-    async def _pause_reqs(self, pasue_reqs):
+    async def _pause_reqs(self, pasue_reqs: List[Req]):
         pasue_req_ids = [r.request_id for r in pasue_reqs]
         await self.model_rpc_client.pause_reqs(pasue_req_ids)
         return

-    def _filter_runing_batch(self):
-        if self.running_batch is not None and self.running_batch.is_clear():
-            self.running_batch = None
-        return
+    def _add_new_batch_to_running_batch(self, new_batch: Batch):
+        if self.running_batch is None:
+            self.running_batch = new_batch
+        else:
+            self.running_batch.merge(new_batch)
+        return
+
+    def _filter_reqs_from_running_batch(self):
+        if self.running_batch is not None:
+            self.running_batch.filter_out_finished_req(self.shm_req_manager)
+            if self.running_batch.is_clear():
+                self.running_batch = None
+        return

     def _can_decode(self, batch: Batch, dp_index: int):
-        if self.is_pd_run_mode or self.is_safe_schedule:
+        if self.is_pd_run_mode or self.is_safe_schedule or batch is None:
             return True
         return (
             batch.get_batch_decode_need_tokens()[dp_index] + self.get_used_tokens(dp_index) <= self.max_total_token_num
@@ -443,12 +387,35 @@ def get_used_tokens(self, dp_index):
         return self.max_total_token_num - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)

     async def loop_for_netio_req(self):
+        recv_max_count = 64
+
         while True:
-            recv_req: GroupReqIndexes = await self.recv_from_httpserver.recv_pyobj()
-            if isinstance(recv_req, GroupReqIndexes):
-                self.add_req(recv_req)
-            else:
-                assert False, f"Error Req Inf {recv_req}"
+            try:
+                # Take at most recv_max_count requests from zmq at a time, so a long zmq backlog cannot block the main loop.
+                for _ in range(recv_max_count):
+                    recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
+                    if isinstance(recv_req, GroupReqIndexes):
+                        self.add_req(recv_req)
+                    else:
+                        assert False, f"Error Req Inf {recv_req}"
+
+                # When many requests are queued, raise the number accepted per round
+                recv_max_count = min(int(recv_max_count * 1.3), 256)
+
+            except zmq.ZMQError:
+                # Once the queue starts to drain, lower the number accepted per round
+                recv_max_count = 64
+
+            try:
+                await asyncio.wait_for(self.schedule_event.wait(), timeout=0.02)
+            except asyncio.TimeoutError:
+                pass
+
+            if self.schedule_event.is_set():
+                self.generate_new_batch()
+                self.schedule_event.clear()
+
+        return

     def clean_up(self):
         return
@@ -459,6 +426,13 @@ def start_router_process(args, router_port, detokenization_port, metric_port, pi
     graceful_registry(inspect.currentframe().f_code.co_name)
     start_parent_check_thread()

+    def handle_exception(loop, context):
+        logger.exception(f"Router Caught exception: {str(context)}")
+
+    loop = asyncio.new_event_loop()
+    loop.set_exception_handler(handle_exception)
+    asyncio.set_event_loop(loop)
+
     try:
         router = RouterManager(
             args,
@@ -467,7 +441,7 @@ def start_router_process(args, router_port, detokenization_port, metric_port, pi
             metric_port=metric_port,
         )

-        asyncio.run(router.wait_to_model_ready())
+        loop.run_until_complete(router.wait_to_model_ready())
     except:
         import traceback
         import sys
@@ -480,13 +454,6 @@ def start_router_process(args, router_port, detokenization_port, metric_port, pi
         raise

     pipe_writer.send("init ok")
-
-    def handle_exception(loop, context):
-        logger.exception(f"Router Caught exception: {str(context)}")
-
-    loop = asyncio.new_event_loop()
-    loop.set_exception_handler(handle_exception)
-    asyncio.set_event_loop(loop)
     loop.create_task(router.loop_for_fwd())
     loop.run_until_complete(router.loop_for_netio_req())
     return
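Note on the loop_for_netio_req rework above: the router now drains the ZMQ socket in bounded non-blocking bursts (starting at 64, growing by 1.3x up to 256, and falling back to 64 when the queue empties) and then waits up to 20 ms on schedule_event before calling generate_new_batch, so scheduling overlaps decode without the old thread pool. The sketch below mirrors only that control flow under simplified assumptions: netio_loop, inbox, handle_req, and schedule are illustrative names, and an asyncio.Queue stands in for the ZMQ PULL socket.

import asyncio
from typing import Any, Callable


async def netio_loop(
    inbox: asyncio.Queue,
    schedule_event: asyncio.Event,
    handle_req: Callable[[Any], None],
    schedule: Callable[[], None],
) -> None:
    recv_max_count = 64
    while True:
        try:
            # Drain at most recv_max_count requests per iteration so a long
            # backlog cannot starve the scheduling step below.
            for _ in range(recv_max_count):
                handle_req(inbox.get_nowait())
            # A full burst completed and the queue may hold more: raise the burst size.
            recv_max_count = min(int(recv_max_count * 1.3), 256)
        except asyncio.QueueEmpty:
            # The queue drained mid-burst: fall back to the base burst size.
            recv_max_count = 64

        # Wake up when the decode path asks for a new schedule pass, or after 20 ms.
        try:
            await asyncio.wait_for(schedule_event.wait(), timeout=0.02)
        except asyncio.TimeoutError:
            pass

        if schedule_event.is_set():
            schedule()  # build the next batch while the GPU is busy decoding
            schedule_event.clear()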
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index c1d89beb6..311c2725f 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -69,10 +69,8 @@ def __init__(
         self.rank_in_node = rank_in_node
         logger.info(f"Initialized RPC server for rank {self.rank}.")

-        # Only the multi-GPU case is cross-process
-        if self.args.tp != 1:
-            self.loop_thread = threading.Thread(target=self.rpc_loop)
-            self.loop_thread.start()
+        self.loop_thread = threading.Thread(target=self.rpc_loop)
+        self.loop_thread.start()
         return

     def rpc_loop(self):
@@ -225,17 +223,7 @@ def get_max_total_token_num(self):


 class ModelRpcClient:
-    def __init__(self, model_infer_servers: List[ModelRpcServer], world_size, rpc_event, rpc_finished_event):
-        # model_infer_servers holds the inference server objects passed in, but after the refactor
-        # it only contains real objects when a single GPU runs without rpc communication; once
-        # multiple GPUs use rpc, model_infer_servers is passed in as an array of None
-        if world_size == 1:
-            self.model_infer_server: ModelRpcServer = model_infer_servers[0]
-        else:
-            self.model_infer_server: ModelRpcServer = None
-
-        self.world_size = world_size
-        self.use_rpc = self.world_size != 1
+    def __init__(self, rpc_event, rpc_finished_event):
         self.rpc_shm_params = RpcShmParams()
         self.rpc_shm_params.create_or_link_shm()
         self.rpc_shm_results = RpcShmResults()
@@ -246,65 +234,46 @@ def __init__(self, model_infer_servers: List[ModelRpcServer], world_size, rpc_ev
         return

     async def init_model(self, kvargs):
-        if self.use_rpc:
-            self.rpc_shm_params.write_func_params("init_model", (kvargs,))
-            self.rpc_event.set()
+        self.rpc_shm_params.write_func_params("init_model", (kvargs,))
+        self.rpc_event.set()

-            self.rpc_finished_event.wait()
-            self.rpc_finished_event.clear()
-            return
-        else:
-            self.model_infer_server.init_model(kvargs)
-            return
+        self.rpc_finished_event.wait()
+        self.rpc_finished_event.clear()
+        return

     async def prefill(self, reqs):
-        if self.use_rpc:
-            self.rpc_shm_params.write_func_params("prefill", (reqs,))
-            self.rpc_event.set()
+        self.rpc_shm_params.write_func_params("prefill", (reqs,))
+        self.rpc_event.set()

-            await asyncio.to_thread(self.rpc_finished_event.wait)
-            self.rpc_finished_event.clear()
-            return
-        else:
-            self.model_infer_server.prefill(reqs)
-            return
+        await asyncio.to_thread(self.rpc_finished_event.wait)
+        self.rpc_finished_event.clear()
+        return

     async def decode(self):
-        if self.use_rpc:
-            self.rpc_shm_params.write_func_params("decode", ())
-            self.rpc_event.set()
+        self.rpc_shm_params.write_func_params("decode", ())
+        self.rpc_event.set()

-            await asyncio.to_thread(self.rpc_finished_event.wait)
-            self.rpc_finished_event.clear()
-            return
-        else:
-            self.model_infer_server.decode()
-            return
+        await asyncio.to_thread(self.rpc_finished_event.wait)
+        self.rpc_finished_event.clear()
+        return

     async def pause_reqs(self, req_ids):
-        if self.use_rpc:
-            self.rpc_shm_params.write_func_params("pause_reqs", (req_ids,))
-            self.rpc_event.set()
+        self.rpc_shm_params.write_func_params("pause_reqs", (req_ids,))
+        self.rpc_event.set()

-            self.rpc_finished_event.wait()
-            self.rpc_finished_event.clear()
-            return
-        else:
-            self.model_infer_server.pause_reqs(req_ids)
-            return
+        self.rpc_finished_event.wait()
+        self.rpc_finished_event.clear()
+        return

     async def get_max_total_token_num(self):
-        if self.use_rpc:
-            self.rpc_shm_params.write_func_params("get_max_total_token_num", ())
-            self.rpc_event.set()
-
-            self.rpc_finished_event.wait()
-            self.rpc_finished_event.clear()
-            func_name, ret = self.rpc_shm_results.read_func_result()
-            assert func_name == "get_max_total_token_num"
-            return ret
-        else:
-            return self.model_infer_server.get_max_total_token_num()
+        self.rpc_shm_params.write_func_params("get_max_total_token_num", ())
+        self.rpc_event.set()
+
+        self.rpc_finished_event.wait()
+        self.rpc_finished_event.clear()
+        func_name, ret = self.rpc_shm_results.read_func_result()
+        assert func_name == "get_max_total_token_num"
+        return ret


 def _init_env(
@@ -352,19 +321,6 @@ async def start_model_process(
 ):
     import lightllm.utils.rpyc_fix_utils as _

-    # A single GPU on a single node does not use rpc
-    if node_world_size == 1 and args.nnodes == 1:
-        return ModelRpcServer(
-            args,
-            rank,
-            rank_in_node,
-            node_world_size,
-            rpc_event,
-            rpc_finished_event,
-            info_queue,
-            mem_queue,
-        )
-
     success_event = mp.Event()
     proc = mp.Process(
         target=_init_env,
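Note on the ModelRpcClient change above: the single-GPU fast path is gone, so every call now goes through the same handshake of writing the function name and arguments to shared memory, setting rpc_event, and waiting on rpc_finished_event (via asyncio.to_thread on the hot prefill/decode paths so the router's event loop keeps running). The sketch below reproduces only that handshake; the shm dict, server_loop, and call are illustrative stand-ins for RpcShmParams/RpcShmResults and the ModelRpcServer loop, and threading.Event replaces the cross-process events.

import asyncio
import threading
from typing import Any, Optional

shm: dict = {"call": None, "result": None}  # stand-in for the shared-memory buffers
rpc_event = threading.Event()
rpc_finished_event = threading.Event()


def server_loop() -> None:
    # Stand-in for ModelRpcServer.rpc_loop: wait for a request, handle it, signal completion.
    while True:
        rpc_event.wait()
        rpc_event.clear()
        func_name, args = shm["call"]
        if func_name == "stop":
            rpc_finished_event.set()
            return
        shm["result"] = (func_name, f"handled {func_name}{args}")
        rpc_finished_event.set()


async def call(func_name: str, args: tuple = ()) -> Optional[Any]:
    shm["call"] = (func_name, args)  # like write_func_params(...)
    rpc_event.set()                  # tell the server a request is ready
    # Block in a worker thread so the event loop keeps running, mirroring
    # `await asyncio.to_thread(self.rpc_finished_event.wait)`.
    await asyncio.to_thread(rpc_finished_event.wait)
    rpc_finished_event.clear()
    return shm["result"]


async def main() -> None:
    threading.Thread(target=server_loop, daemon=True).start()
    print(await call("decode"))
    print(await call("get_max_total_token_num"))
    await call("stop")


asyncio.run(main())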
diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
index 804d654d5..2685096de 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
@@ -76,6 +76,9 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new

     # @calculate_time(show=True, min_cost_ms=10)
     def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None):
+        if len(self.waiting_req_list) == 0:
+            return None
+
         # If the number of already scheduled requests exceeds the limit, do not schedule any new requests.
         exist_req_num = self.get_batch_dp_req_size(current_batch) + len(self.pause_req_dict)
         req_is_full = exist_req_num >= self.running_max_req_size
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py
index 8d90da5b3..0222aa749 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -57,6 +57,8 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens

     # @calculate_time(show=True, min_cost_ms=10)
     def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None):
+        if len(self.waiting_req_list) == 0:
+            return None
         # If the number of already scheduled requests exceeds the limit, do not schedule any new requests.
         exist_req_num = self.get_batch_dp_req_size(current_batch) + len(self.pause_req_dict)
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_prefill.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_prefill.py
index ad1af268d..241ab9148 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_prefill.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_prefill.py
@@ -59,6 +59,8 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens

     # @calculate_time(show=True, min_cost_ms=10)
     def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None):
+        if len(self.waiting_req_list) == 0:
+            return None
         # If the number of already scheduled requests exceeds the limit, do not schedule any new requests.
         exist_req_num = self.get_batch_dp_req_size(current_batch) + len(self.pause_req_dict)
diff --git a/lightllm/server/router/req_queue/continues_batch/impl.py b/lightllm/server/router/req_queue/continues_batch/impl.py
index 413481853..ce7811f4b 100644
--- a/lightllm/server/router/req_queue/continues_batch/impl.py
+++ b/lightllm/server/router/req_queue/continues_batch/impl.py
@@ -60,6 +60,9 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens

     # @calculate_time(show=True, min_cost_ms=10)
     def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None):
+        if len(self.waiting_req_list) == 0:
+            return None
+
         # If the number of already scheduled requests exceeds the limit, do not schedule any new requests.
         exist_req_num = self.get_batch_dp_req_size(current_batch) + len(self.pause_req_dict)
         req_is_full = exist_req_num >= self.running_max_req_size
diff --git a/lightllm/server/router/req_queue/continues_batch/impl_for_pd_decode.py b/lightllm/server/router/req_queue/continues_batch/impl_for_pd_decode.py
index ed54e2af3..0c1943e35 100644
--- a/lightllm/server/router/req_queue/continues_batch/impl_for_pd_decode.py
+++ b/lightllm/server/router/req_queue/continues_batch/impl_for_pd_decode.py
@@ -25,6 +25,9 @@ def _init_cache_list(self, current_batch: Batch, is_busy):

     # @calculate_time(show=True, min_cost_ms=10)
     def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None):
+        if len(self.waiting_req_list) == 0:
+            return None
+
         # If the number of already scheduled requests exceeds the limit, do not schedule any new requests.
         exist_req_num = self.get_batch_dp_req_size(current_batch) + len(self.pause_req_dict)
         req_is_full = exist_req_num >= self.running_max_req_size
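Note on the generate_new_batch guard added to each queue implementation above: because the router now triggers scheduling on every schedule_event wakeup, each implementation first returns None when waiting_req_list is empty, skipping the request-count and token-budget accounting entirely. A hedged sketch of the pattern follows; QueueSketch and its scheduling rule are simplified illustrations, not the real lightllm queue logic.

from typing import List, Optional


class QueueSketch:
    """Illustrative stand-in for the req_queue implementations touched above."""

    def __init__(self, running_max_req_size: int = 8):
        self.waiting_req_list: List[str] = []
        self.pause_req_dict: dict = {}
        self.running_max_req_size = running_max_req_size

    def generate_new_batch(self, current_batch_size: int) -> Optional[List[str]]:
        # New early exit: nothing is waiting, so skip all scheduling work.
        if len(self.waiting_req_list) == 0:
            return None

        # Existing cap: do not schedule past the configured request limit.
        exist_req_num = current_batch_size + len(self.pause_req_dict)
        if exist_req_num >= self.running_max_req_size:
            return None

        budget = self.running_max_req_size - exist_req_num
        scheduled, self.waiting_req_list = self.waiting_req_list[:budget], self.waiting_req_list[budget:]
        return scheduled


q = QueueSketch()
assert q.generate_new_batch(current_batch_size=0) is None  # empty queue -> None
q.waiting_req_list = ["r1", "r2", "r3"]
assert q.generate_new_batch(current_batch_size=6) == ["r1", "r2"]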
diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py
index b8df6e3e9..dc07ff2b8 100644
--- a/lightllm/server/router/req_queue/dp_base_queue.py
+++ b/lightllm/server/router/req_queue/dp_base_queue.py
@@ -43,7 +43,7 @@ def _merge_batch(self, dp_batches: List[Batch]):
         merged_batch: Batch = None
         for iter_batch in dp_batches:
             if merged_batch is not None:
-                merged_batch.dp_merge(iter_batch)
+                merged_batch.merge(iter_batch)
             else:
                 merged_batch = iter_batch
         return merged_batch
diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py
index 7d64a371e..d50c4e7ca 100644
--- a/lightllm/server/router/stats.py
+++ b/lightllm/server/router/stats.py
@@ -1,10 +1,11 @@
 import time
 from lightllm.utils.log_utils import init_logger
+from .batch import Batch

 logger = init_logger(__name__)


-class Stats:
+class Stats:
     def __init__(self, log_status, log_stats_interval) -> None:
         self.log_stats = log_status
         self.log_stats_interval = log_stats_interval
@@ -13,16 +14,16 @@ def __init__(self, log_status, log_stats_interval) -> None:
         self.output_tokens = 0
         self.prompt_tokens = 0
         return
-
-    def count_prompt_tokens(self, run_batch):
-        if self.log_stats:
+
+    def count_prompt_tokens(self, run_batch: Batch):
+        if self.log_stats and run_batch is not None:
             tokens = run_batch.input_tokens()
             self.prompt_tokens += tokens
             self.all_tokens += tokens
         return
-
-    def count_output_tokens(self, run_batch):
-        if self.log_stats:
+
+    def count_output_tokens(self, run_batch: Batch):
+        if self.log_stats and run_batch is not None:
             tokens = len(run_batch.reqs)
             self.output_tokens += tokens
             self.all_tokens += tokens
@@ -34,13 +35,13 @@ def print_stats(self):

         now = time.time()
         if now - self.last_log_time > self.log_stats_interval:
-            logger.debug(f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
-                         f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
-                         f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s")
+            logger.debug(
+                f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
+                f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
+                f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s"
+            )
             self.all_tokens = 0
             self.output_tokens = 0
             self.prompt_tokens = 0
             self.last_log_time = now
         return
-
-
\ No newline at end of file
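Note on the Stats changes above: both counters now ignore a None batch (decode can run with no local batch in multi-node dp mode), and the throughput log divides the accumulated token counts by the elapsed logging interval. The sketch below is a self-contained approximation; StatsSketch and its arguments are illustrative, and it takes plain token counts where the real class takes a Batch.

import time


class StatsSketch:
    # Simplified stand-in for lightllm's Stats; only the None guard and the
    # interval-based throughput arithmetic are reproduced here.
    def __init__(self, log_stats_interval: float = 10.0):
        self.log_stats_interval = log_stats_interval
        self.last_log_time = time.time()
        self.all_tokens = 0
        self.prompt_tokens = 0
        self.output_tokens = 0

    def count_prompt_tokens(self, batch_input_tokens):
        if batch_input_tokens is None:  # decode-only step on this node
            return
        self.prompt_tokens += batch_input_tokens
        self.all_tokens += batch_input_tokens

    def count_output_tokens(self, batch_size):
        if batch_size is None:
            return
        self.output_tokens += batch_size  # one generated token per running request
        self.all_tokens += batch_size

    def print_stats(self):
        now = time.time()
        if now - self.last_log_time > self.log_stats_interval:
            elapsed = now - self.last_log_time
            print(f"Avg tokens(prompt+generate) throughput: {self.all_tokens / elapsed:8.3f} tokens/s")
            self.all_tokens = self.prompt_tokens = self.output_tokens = 0
            self.last_log_time = now


stats = StatsSketch(log_stats_interval=0.0)
stats.count_prompt_tokens(512)
stats.count_output_tokens(None)  # ignored, mirrors the new None guard
stats.count_output_tokens(4)
time.sleep(0.1)  # let a little time pass so the reported rate is readable
stats.print_stats()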