diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index 2ceef2e2a9..aef919c864 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -92,6 +92,9 @@ class LLMConfig(YamlModel):
     top_logprobs: Optional[int] = None
     timeout: int = 600
     context_length: Optional[int] = None  # Max input tokens
+    # For rate limit control
+    rpm: Optional[int] = 0
+    tpm: Optional[int] = 0

     # For Amazon Bedrock
     region_name: str = None
diff --git a/metagpt/ext/aflow/scripts/optimizer.py b/metagpt/ext/aflow/scripts/optimizer.py
index 0ac4827e71..ac5d7c9d1a 100644
--- a/metagpt/ext/aflow/scripts/optimizer.py
+++ b/metagpt/ext/aflow/scripts/optimizer.py
@@ -5,7 +5,7 @@
 import asyncio
 import time
-from typing import List, Literal
+from typing import List, Literal, Optional

 from pydantic import BaseModel, Field

@@ -24,9 +24,9 @@


 class GraphOptimize(BaseModel):
-    modification: str = Field(default="", description="modification")
-    graph: str = Field(default="", description="graph")
-    prompt: str = Field(default="", description="prompt")
+    modification: Optional[str] = Field(default="", description="modification")
+    graph: Optional[str] = Field(default="", description="graph")
+    prompt: Optional[str] = Field(default="", description="prompt")


 class Optimizer:
@@ -90,7 +90,7 @@ def optimize(self, mode: OptimizerType = "Graph"):
                     break
                 except Exception as e:
                     retry_count += 1
-                    logger.info(f"Error occurred: {e}. Retrying... (Attempt {retry_count}/{max_retries})")
+                    logger.exception(f"Error occurred: {e}. Retrying... (Attempt {retry_count}/{max_retries})")
                     if retry_count == max_retries:
                         logger.info("Max retries reached. Moving to next round.")
                         score = None
diff --git a/metagpt/logs.py b/metagpt/logs.py
index ce2f12e4c0..bfb44d1227 100644
--- a/metagpt/logs.py
+++ b/metagpt/logs.py
@@ -149,5 +149,5 @@ def get_llm_stream_queue():


 def _llm_stream_log(msg):
-    if _print_level in ["INFO"]:
+    if _print_level in ["INFO", "DEBUG"]:
         print(msg, end="")
diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py
index 6560fc7ddd..2c82a8b398 100644
--- a/metagpt/provider/base_llm.py
+++ b/metagpt/provider/base_llm.py
@@ -30,6 +30,7 @@
 from metagpt.utils.common import log_and_reraise
 from metagpt.utils.cost_manager import CostManager, Costs
 from metagpt.utils.token_counter import TOKEN_MAX
+from metagpt.utils.rate_limitor import RateLimitor, rate_limitor_registry


 class BaseLLM(ABC):
@@ -44,6 +45,7 @@ class BaseLLM(ABC):
     cost_manager: Optional[CostManager] = None
     model: Optional[str] = None  # deprecated
     pricing_plan: Optional[str] = None
+    current_rate_limitor: Optional[RateLimitor] = None

     _reasoning_content: Optional[str] = None  # content from reasoning mode

@@ -134,6 +136,7 @@ def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_
                 prompt_tokens = int(usage.get("prompt_tokens", 0))
                 completion_tokens = int(usage.get("completion_tokens", 0))
                 self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
+                self.rate_limitor.cost_token(usage)
             except Exception as e:
                 logger.error(f"{self.__class__.__name__} updates costs failed! exp: {e}")
exp: {e}") @@ -197,11 +200,13 @@ async def aask( message.extend(msg) if stream is None: stream = self.config.stream - + async with self.rate_limitor: + await self.rate_limitor.acquire(message) + # the image data is replaced with placeholders to avoid long output masked_message = [self.mask_base64_data(m) for m in message] logger.debug(masked_message) - + compressed_message = self.compress_messages(message, compress_type=self.config.compress_type) rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout)) # rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout)) @@ -323,6 +328,12 @@ def with_model(self, model: str): """Set model and return self. For example, `with_model("gpt-3.5-turbo")`.""" self.config.model = model return self + + @property + def rate_limitor(self) -> RateLimitor: + if not self.current_rate_limitor: + self.current_rate_limitor = rate_limitor_registry.register(None, self.config) + return self.current_rate_limitor def get_timeout(self, timeout: int) -> int: return timeout or self.config.timeout or LLM_API_TIMEOUT @@ -407,4 +418,4 @@ def compress_messages( ) break - return compressed + return compressed \ No newline at end of file diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index cf2aa58bab..39e2aa4d80 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -31,6 +31,8 @@ from asyncio import iscoroutinefunction from datetime import datetime from functools import partial +import asyncio +import nest_asyncio from io import BytesIO from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -672,6 +674,15 @@ def format_trackback_info(limit: int = 2): return traceback.format_exc(limit=limit) +def asyncio_run(future): + nest_asyncio.apply() + try: + loop = asyncio.get_event_loop() + return loop.run_until_complete(future) + except RuntimeError: + return asyncio.run(future) + + def serialize_decorator(func): async def wrapper(self, *args, **kwargs): try: diff --git a/metagpt/utils/rate_limitor.py b/metagpt/utils/rate_limitor.py new file mode 100644 index 0000000000..6525fadb99 --- /dev/null +++ b/metagpt/utils/rate_limitor.py @@ -0,0 +1,169 @@ +import time +import asyncio +import math +import json + +from pydantic_core import to_jsonable_python +from metagpt.utils.token_counter import count_message_tokens +from metagpt.configs.llm_config import LLMConfig +from metagpt.logs import logger +from metagpt.configs.models_config import ModelsConfig +import metagpt.utils.common as common + +class RateLimitor: + def __init__(self, rpm: int, tpm: int): + self.rpm = rpm + self.tpm = tpm + self.tpm_bucket = TokenBucket(tpm) + self.rpm_bucket = TokenBucket(rpm) + self.lock = asyncio.Semaphore(rpm) + + async def acquire_rpm(self, tokens=1): + await self.rpm_bucket.acquire(tokens) + + + async def __enter__(self): + if self.rpm > 0 or self.tpm > 0: + await self.lock.acquire() + return self + + async def __exit__(self, exc_type, exc_val, exc_tb): + if self.rpm > 0 or self.tpm > 0: + self.lock.release() + return None + + async def __aenter__(self): + return await self.__enter__() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self.__exit__(exc_type, exc_val, exc_tb) + + def cost_token(self, usage: dict): + if not isinstance(usage, dict): + usage = dict(usage) + self.tpm_bucket._cost(usage.get("input_tokens", usage.get('prompt_tokens', 0))) + self.tpm_bucket._cost(usage.get("output_tokens", 
+
+    async def acquire(self, messages):
+        tokens = count_message_tokens(messages)
+        await self.tpm_bucket._wait(tokens)
+        await self.acquire_rpm(1)
+
+
+class TokenBucket:
+    def __init__(self, rpm):
+        """
+        Initialize the token bucket (thread-safe version)
+        :param rpm: the refill amount per minute (requests for the rpm bucket, tokens for the tpm bucket)
+        """
+        if rpm is None:
+            rpm = 0
+        self.capacity = rpm  # the capacity of the bucket
+        self.tokens = rpm  # the current number of tokens
+        self.rate = rpm / 60.0 if rpm else 0  # the number of tokens generated per second
+        self.last_refill = time.time()
+        self.lock = asyncio.Lock()  # lock that keeps refills safe under concurrent use
+
+    async def _refill(self, desc_tokens=0):
+        """Refill the bucket and try to consume `desc_tokens` (protected by the internal lock)."""
+        async with self.lock:
+            if self.capacity is None or self.capacity <= 0:
+                return
+            # assert self.capacity >= desc_tokens, f"Token bucket capacity [{self.capacity}] cannot cover the cost of this request: {desc_tokens}."
+            now = time.time()
+            elapsed = now - self.last_refill
+            new_tokens = elapsed * self.rate
+
+            if new_tokens + self.tokens >= desc_tokens or self.tokens >= self.capacity:
+                self.tokens = min(self.capacity, self.tokens + new_tokens) - desc_tokens
+                self.last_refill = now
+                return True  # enough tokens were replenished; the request can proceed
+            else:
+                self.tokens = min(self.capacity, self.tokens + new_tokens)
+                self.last_refill = now
+                return False
+
+    def _cost(self, tokens: int):
+        if self.capacity is None or self.capacity <= 0:
+            return
+        assert tokens >= 0
+        common.asyncio_run(self._refill())
+        self.tokens -= tokens
+
+    async def _wait(self, tokens: int):
+        if self.capacity is None or self.capacity <= 0:
+            # no limit configured, nothing to wait for
+            return True
+        while True:
+            if await self._refill(desc_tokens=tokens):
+                # enough tokens, return immediately
+                return True
+            deficit = tokens - self.tokens
+            wait_time = deficit / self.rate
+
+            logger.warning(f"current [{asyncio.current_task().get_name()}] with [{self.tokens:.5f}] tokens, wait_time for tpm: {wait_time:.3f}")
+            await asyncio.sleep(wait_time)
+
+    async def acquire(self, tokens=1):
+        """
+        Block until acquiring the specified number of tokens
+        :param tokens: the number of tokens needed (default is 1)
+        """
+        if self.capacity is None or self.capacity <= 0:
+            return
+
+        while True:
+            # if the tokens are enough, return immediately
+            if await self._refill(desc_tokens=tokens):
+                return
+
+            # calculate the time to wait
+            deficit = tokens - self.tokens
+            wait_time = deficit / self.rate
+
+            logger.warning(f"current [{asyncio.current_task().get_name()}] with [{self.tokens:.5f}] tokens, wait_time for rpm: {wait_time:.3f}")
+
+            # wait until the tokens are replenished (with timeout and notification)
+            await asyncio.sleep(wait_time)
+
+    @property
+    def available_tokens(self):
+        """Get the current number of available tokens (refreshed in real time)"""
+        if self.capacity is None or self.capacity <= 0:
+            return math.inf
+        common.asyncio_run(self._refill())
+        return self.tokens
+
+
+class RateLimitorRegistry:
+    def __init__(self):
+        self.rate_limitors = {}
+        self.config_items = {}
+
+    def init_rate_limitors(self):
+        for model_name, llm_config in ModelsConfig.default().items():
+            self.register(model_name, llm_config)
+
+    def _config_to_key(self, llm_config: LLMConfig):
+        return json.dumps(llm_config.model_dump(), default=to_jsonable_python)
+
+    def register(self, model_name: str, llm_config: LLMConfig) -> RateLimitor:
+        if not llm_config:
+            raise ValueError("llm_config is required")
+        if not model_name:
+            model_name = self._config_to_key(llm_config)
+        if model_name not in self.rate_limitors:
+            self.rate_limitors[model_name] = RateLimitor(llm_config.rpm, llm_config.tpm)
+            self.config_items[self._config_to_key(llm_config)] = model_name
+        return self.rate_limitors[model_name]
+
+    def get(self, model_name: str):
+        if not model_name:
+            model_name = "_default_llm"
+        return self.rate_limitors.get(model_name)
+
+    def get_by_config(self, llm_config: LLMConfig):
+        rate_limitor_key = self._config_to_key(llm_config)
+        return self.rate_limitors.get(rate_limitor_key, default_rate_limitor)
+
+
+rate_limitor_registry = RateLimitorRegistry()
+
+default_rate_limitor = RateLimitor(0, 0)
\ No newline at end of file
diff --git a/tests/data/code/python/1.py b/tests/data/code/python/1.py
index e9aeaeeeed..697084fc06 100644
--- a/tests/data/code/python/1.py
+++ b/tests/data/code/python/1.py
@@ -48,8 +48,9 @@
 ax2.set_xlabel("Degree")
 ax2.set_ylabel("# of Nodes")

-fig.tight_layout()
-plt.show()
+if __name__ == "__main__":
+    fig.tight_layout()
+    plt.show()


 class Game:
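
Usage note (illustrative, not part of the patch): a minimal sketch of how the new rpm/tpm limits could be exercised directly, outside of `BaseLLM.aask`. The model name, api_key, and the rpm/tpm values below are placeholder assumptions, not values taken from this PR; in normal use the limits come from the `rpm` and `tpm` fields added to `LLMConfig` and are applied automatically inside `aask`.

```python
# Illustrative sketch only: drives metagpt/utils/rate_limitor.py by hand.
# "gpt-4o-mini", "sk-placeholder", rpm=3 and tpm=1000 are made-up example values.
import asyncio

from metagpt.configs.llm_config import LLMConfig
from metagpt.utils.rate_limitor import rate_limitor_registry


async def main():
    config = LLMConfig(api_key="sk-placeholder", model="gpt-4o-mini", rpm=3, tpm=1000)
    limitor = rate_limitor_registry.register("gpt-4o-mini", config)

    messages = [{"role": "user", "content": "hello"}]
    async with limitor:
        # blocks until both the request (rpm) and token (tpm) buckets allow this call
        await limitor.acquire(messages)
    # after the provider responds, feed the real usage back so the tpm bucket stays accurate
    limitor.cost_token({"prompt_tokens": 12, "completion_tokens": 30})


if __name__ == "__main__":
    asyncio.run(main())
```

With `rpm` and `tpm` left at their default of 0, the buckets are effectively disabled, so existing configurations should behave as before.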