diff --git a/src/core/orchestrator.py b/src/core/orchestrator.py
index 70eb67a0..aeb408d4 100644
--- a/src/core/orchestrator.py
+++ b/src/core/orchestrator.py
@@ -386,6 +386,9 @@ async def run_sub_agent(
         # Start new sub-agent session
         self.task_log.start_sub_agent_session(sub_agent_name, task_description)
+
+        # Reset sub-agent usage stats for independent tracking
+        self.sub_agent_llm_client.reset_usage_stats()
 
         # Simplified initial user content (no file attachments)
         initial_user_content = [{"type": "text", "text": task_description}]
@@ -661,6 +664,14 @@
         ] = {"system_prompt": system_prompt, "message_history": message_history}  # type: ignore
         self.task_log.save()
+
+        # Record sub-agent cumulative usage
+        usage_log = self.sub_agent_llm_client.get_usage_log()
+        self.task_log.log_step(
+            "usage_calculation",
+            usage_log,
+            metadata={"session_id": self.task_log.current_sub_agent_session_id},
+        )
+
         self.task_log.end_sub_agent_session(sub_agent_name)
         self.task_log.log_step(
             "sub_agent_completed", f"Sub agent {sub_agent_name} completed", "info"
         )
@@ -682,6 +693,9 @@ async def run_main_agent(
         if task_file_name:
             logger.debug(f"Associated File: {task_file_name}")
+
+        # Reset main agent usage stats for independent tracking
+        self.llm_client.reset_usage_stats()
+
         # 1. Process input
         initial_user_content, task_description = process_input(
             task_description, task_file_name
         )
@@ -1089,6 +1103,14 @@
             "task_completed", f"Main agent task {task_id} completed successfully"
         )
+
+        # Record main agent cumulative usage
+        usage_log = self.llm_client.get_usage_log()
+        self.task_log.log_step(
+            "usage_calculation",
+            usage_log,
+            metadata={"session_id": "main_agent"},
+        )
+
         if "browsecomp-zh" in self.cfg.benchmark.name:
             return final_summary, final_summary
         else:
diff --git a/src/llm/provider_client_base.py b/src/llm/provider_client_base.py
index 5d07fcec..e7808cfb 100644
--- a/src/llm/provider_client_base.py
+++ b/src/llm/provider_client_base.py
@@ -34,6 +34,11 @@ class LLMProviderClientBase(ABC):
     # post_init
     client: Any = dataclasses.field(init=False)
+    # Usage tracking - cumulative for each agent session
+    total_input_tokens: int = dataclasses.field(init=False, default=0)
+    total_input_cached_tokens: int = dataclasses.field(init=False, default=0)
+    total_output_tokens: int = dataclasses.field(init=False, default=0)
+    total_output_reasoning_tokens: int = dataclasses.field(init=False, default=0)
 
     def __post_init__(self):
         # Explicitly assign from cfg object
@@ -195,6 +200,19 @@ async def create_message(
             tool_definitions,
             keep_tool_result=keep_tool_result,
         )
+
+        # Accumulate usage for agent session
+        if response:
+            try:
+                usage = self._extract_usage_from_response(response)
+                if usage:
+                    self.total_input_tokens += usage.get("input_tokens", 0)
+                    self.total_input_cached_tokens += usage.get("cached_tokens", 0)
+                    self.total_output_tokens += usage.get("output_tokens", 0)
+                    self.total_output_reasoning_tokens += usage.get("reasoning_tokens", 0)
+            except Exception as e:
+                logger.warning(f"Failed to accumulate usage: {e}")
+
         return response
 
     @staticmethod
@@ -314,3 +332,50 @@ def handle_max_turns_reached_summary_prompt(
         self, message_history: list[dict[str, Any]], summary_prompt: str
     ):
         raise NotImplementedError("must implement in subclass")
+
+    def _extract_usage_from_response(self, response):
+        """Extract usage - default implementation (OpenAI Chat Completions format)"""
+        if not hasattr(response, 'usage'):
+            return {
+                "input_tokens": 0,
+                "cached_tokens": 0,
+                "output_tokens": 0,
+                "reasoning_tokens": 0
+            }
+
+        usage = response.usage
+        prompt_tokens_details = getattr(usage, 'prompt_tokens_details', {}) or {}
+        if hasattr(prompt_tokens_details, "to_dict"):
+            prompt_tokens_details = prompt_tokens_details.to_dict()
+        completion_tokens_details = getattr(usage, 'completion_tokens_details', {}) or {}
+        if hasattr(completion_tokens_details, "to_dict"):
+            completion_tokens_details = completion_tokens_details.to_dict()
+
+        usage_dict = {
+            "input_tokens": getattr(usage, 'prompt_tokens', 0),
+            "cached_tokens": prompt_tokens_details.get('cached_tokens', 0),
+            "output_tokens": getattr(usage, 'completion_tokens', 0),
+            "reasoning_tokens": completion_tokens_details.get('reasoning_tokens', 0)
+        }
+
+        return usage_dict
+
+    def get_usage_log(self) -> str:
+        """Get cumulative usage for the current agent session as a formatted string"""
+        # Format: [Provider | Model], Total Input: X (Cached/Uncached), Total Output: Y (Reasoning/Response), Total Tokens: Z
+        provider_model = f"[{self.provider_class} | {self.model_name}]"
+        input_uncached = self.total_input_tokens - self.total_input_cached_tokens
+        output_response = self.total_output_tokens - self.total_output_reasoning_tokens
+        total_tokens = self.total_input_tokens + self.total_output_tokens
+
+        return (f"Usage log: {provider_model}, "
+                f"Total Input: {self.total_input_tokens} (Cached: {self.total_input_cached_tokens}, Uncached: {input_uncached}), "
+                f"Total Output: {self.total_output_tokens} (Reasoning: {self.total_output_reasoning_tokens}, Response: {output_response}), "
+                f"Total Tokens: {total_tokens}")
+
+    def reset_usage_stats(self):
+        """Reset usage stats for a new agent session"""
+        self.total_input_tokens = 0
+        self.total_input_cached_tokens = 0
+        self.total_output_tokens = 0
+        self.total_output_reasoning_tokens = 0
diff --git a/src/llm/providers/claude_anthropic_client.py b/src/llm/providers/claude_anthropic_client.py
index d701d492..9bf2ee55 100644
--- a/src/llm/providers/claude_anthropic_client.py
+++ b/src/llm/providers/claude_anthropic_client.py
@@ -29,7 +29,7 @@ def __post_init__(self):
 
     def _create_client(self, config: DictConfig):
         """Create Anthropic client"""
-        api_key = config.env.anthropic_api_key
+        api_key = self.cfg.llm.anthropic_api_key
 
         if self.async_client:
             return AsyncAnthropic(
@@ -183,6 +183,31 @@ def handle_max_turns_reached_summary_prompt(self, message_history, summary_promp
         else:
             return summary_prompt
 
+    def _extract_usage_from_response(self, response):
+        """Extract usage - Anthropic format"""
+        if not hasattr(response, 'usage'):
+            return {
+                "input_tokens": 0,
+                "cached_tokens": 0,
+                "output_tokens": 0,
+                "reasoning_tokens": 0
+            }
+
+        usage = response.usage
+        cache_creation_input_tokens = getattr(usage, 'cache_creation_input_tokens', 0)
+        cache_read_input_tokens = getattr(usage, 'cache_read_input_tokens', 0)
+        input_tokens = getattr(usage, 'input_tokens', 0)
+        output_tokens = getattr(usage, 'output_tokens', 0)
+
+        usage_dict = {
+            "input_tokens": cache_creation_input_tokens + cache_read_input_tokens + input_tokens,
+            "cached_tokens": cache_read_input_tokens,
+            "output_tokens": output_tokens,
+            "reasoning_tokens": 0
+        }
+
+        return usage_dict
+
     def _apply_cache_control(self, messages):
         """Apply cache control to the last user message and system message (if applicable)"""
         cached_messages = []
diff --git a/src/llm/providers/claude_newapi_client.py b/src/llm/providers/claude_newapi_client.py
index 72814465..65c0a4c3 100644
--- a/src/llm/providers/claude_newapi_client.py
+++ b/src/llm/providers/claude_newapi_client.py
@@ -35,13 +35,13 @@ class ClaudeNewAPIClient(LLMProviderClientBase):
     def _create_client(self, config: DictConfig):
         if self.async_client:
             return AsyncOpenAI(
-                api_key=config.env.newapi_api_key,
-                base_url=config.env.newapi_base_url,
+                api_key=self.cfg.llm.newapi_api_key,
+                base_url=self.cfg.llm.newapi_base_url,
             )
         else:
             return OpenAI(
-                api_key=config.env.newapi_api_key,
-                base_url=config.env.newapi_base_url,
+                api_key=self.cfg.llm.newapi_api_key,
+                base_url=self.cfg.llm.newapi_base_url,
             )
 
     # @retry(wait=wait_fixed(10), stop=stop_after_attempt(5))
diff --git a/src/llm/providers/claude_openrouter_client.py b/src/llm/providers/claude_openrouter_client.py
index 4acd114e..2b1c680b 100644
--- a/src/llm/providers/claude_openrouter_client.py
+++ b/src/llm/providers/claude_openrouter_client.py
@@ -133,6 +133,10 @@ async def _create_message(
         if self.repetition_penalty != 1.0:
             extra_body["repetition_penalty"] = self.repetition_penalty
 
+        extra_body["usage"] = {
+            "include": True
+        }
+
         params = {
             "model": self.model_name,
             "temperature": temperature,
diff --git a/src/llm/providers/deepseek_newapi_client.py b/src/llm/providers/deepseek_newapi_client.py
index b8bf916a..361e9310 100644
--- a/src/llm/providers/deepseek_newapi_client.py
+++ b/src/llm/providers/deepseek_newapi_client.py
@@ -32,13 +32,13 @@ def _create_client(self, config: DictConfig):
         """Create configured OpenAI client"""
         if self.async_client:
             return AsyncOpenAI(
-                api_key=config.env.newapi_api_key,
-                base_url=config.env.newapi_base_url,
+                api_key=self.cfg.llm.newapi_api_key,
+                base_url=self.cfg.llm.newapi_base_url,
             )
         else:
             return OpenAI(
-                api_key=config.env.newapi_api_key,
-                base_url=config.env.newapi_base_url,
+                api_key=self.cfg.llm.newapi_api_key,
+                base_url=self.cfg.llm.newapi_base_url,
             )
 
     # @retry(wait=wait_fixed(10), stop=stop_after_attempt(5))
diff --git a/src/llm/providers/gpt_openai_client.py b/src/llm/providers/gpt_openai_client.py
index b404e5f2..99f9c190 100644
--- a/src/llm/providers/gpt_openai_client.py
+++ b/src/llm/providers/gpt_openai_client.py
@@ -29,13 +29,13 @@ def _create_client(self, config: DictConfig):
         """Create configured OpenAI client"""
         if self.async_client:
             return AsyncOpenAI(
-                api_key=config.env.openai_api_key,
-                base_url=config.env.openai_base_url,
+                api_key=self.cfg.llm.openai_api_key,
+                base_url=self.cfg.llm.openai_base_url,
             )
         else:
             return OpenAI(
-                api_key=config.env.openai_api_key,
-                base_url=config.env.openai_base_url,
+                api_key=self.cfg.llm.openai_api_key,
+                base_url=self.cfg.llm.openai_base_url,
             )
 
     @retry(wait=wait_fixed(10), stop=stop_after_attempt(5))
diff --git a/src/llm/providers/gpt_openai_response_client.py b/src/llm/providers/gpt_openai_response_client.py
index 987e136f..3f7c16fc 100644
--- a/src/llm/providers/gpt_openai_response_client.py
+++ b/src/llm/providers/gpt_openai_response_client.py
@@ -29,12 +29,12 @@ def _create_client(self, config: DictConfig):
         """Create configured OpenAI client"""
         if self.async_client:
             return AsyncOpenAI(
-                api_key=os.environ.get("OPENAI_API_KEY"),
+                api_key=self.cfg.llm.openai_api_key,
                 base_url=self.cfg.llm.openai_base_url,
             )
         else:
             return OpenAI(
-                api_key=os.environ.get("OPENAI_API_KEY"),
+                api_key=self.cfg.llm.openai_api_key,
                 base_url=self.cfg.llm.openai_base_url,
             )
 
@@ -93,7 +93,7 @@ async def _create_message(
                 response = self._convert_response_to_serializable(response)
 
                 # Update token count
-                self._update_token_usage(response.get("usage", None))
+                # self._update_token_usage(response.get("usage", None))
                 logger.debug(
                     f"LLM Response API call status: {response.get('error', 'N/A')}"
                 )
@@ -269,3 +269,30 @@ def _convert_response_to_serializable(self, response):
         }
 
         return serializable_response
+
+    def _extract_usage_from_response(self, response):
+        """Extract usage - OpenAI Responses API format"""
+        if not response or not response.get('usage'):
+            return {
+                "input_tokens": 0,
+                "cached_tokens": 0,
+                "output_tokens": 0,
+                "reasoning_tokens": 0
+            }
+
+        usage = response.get('usage', {}) or {}
+        input_tokens_details = usage.get('input_tokens_details', {}) or {}
+        if hasattr(input_tokens_details, "to_dict"):
+            input_tokens_details = input_tokens_details.to_dict()
+        output_tokens_details = usage.get('output_tokens_details', {}) or {}
+        if hasattr(output_tokens_details, "to_dict"):
+            output_tokens_details = output_tokens_details.to_dict()
+
+        usage_dict = {
+            "input_tokens": usage.get('input_tokens', 0),
+            "cached_tokens": input_tokens_details.get('cached_tokens', 0),
+            "output_tokens": usage.get('output_tokens', 0),
+            "reasoning_tokens": output_tokens_details.get('reasoning_tokens', 0)
+        }
+
+        return usage_dict
diff --git a/src/llm/providers/qwen_sglang_client.py b/src/llm/providers/qwen_sglang_client.py
index bd365238..b560d07c 100644
--- a/src/llm/providers/qwen_sglang_client.py
+++ b/src/llm/providers/qwen_sglang_client.py
@@ -25,13 +25,13 @@ def _create_client(self, config: DictConfig):
         """Create configured OpenAI client"""
         if self.async_client:
             return AsyncOpenAI(
-                api_key=config.env.qwen_api_key,
-                base_url=config.env.qwen_base_url,
+                api_key=self.cfg.llm.qwen_api_key,
+                base_url=self.cfg.llm.qwen_base_url,
             )
         else:
             return OpenAI(
-                api_key=config.env.qwen_api_key,
-                base_url=config.env.qwen_base_url,
+                api_key=self.cfg.llm.qwen_api_key,
+                base_url=self.cfg.llm.qwen_base_url,
             )
 
     @retry(wait=wait_fixed(10), stop=stop_after_attempt(5))
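
The patch wires a single lifecycle through the code above: the orchestrator calls reset_usage_stats() when a main- or sub-agent session starts, create_message() accumulates per-call usage through the provider-specific _extract_usage_from_response() overrides, and get_usage_log() renders the session totals that are written to the task log as a "usage_calculation" step. The sketch below replays that lifecycle outside the codebase; the fake response object and all token numbers are invented for illustration, the SDK-object to_dict() normalization is omitted, and only the field mapping and arithmetic mirror the diff.

# Illustration only - not part of the patch. Mimics the base-class default
# extractor (OpenAI Chat Completions shape) and the accumulation logic added
# to create_message(); the response object and numbers are made up.
from types import SimpleNamespace

fake_response = SimpleNamespace(
    usage=SimpleNamespace(
        prompt_tokens=1200,
        prompt_tokens_details={"cached_tokens": 800},
        completion_tokens=300,
        completion_tokens_details={"reasoning_tokens": 120},
    )
)

def extract_usage(response):
    """Same field mapping as the base-class _extract_usage_from_response."""
    if not hasattr(response, "usage"):
        return {"input_tokens": 0, "cached_tokens": 0,
                "output_tokens": 0, "reasoning_tokens": 0}
    usage = response.usage
    in_details = getattr(usage, "prompt_tokens_details", {}) or {}
    out_details = getattr(usage, "completion_tokens_details", {}) or {}
    return {
        "input_tokens": getattr(usage, "prompt_tokens", 0),
        "cached_tokens": in_details.get("cached_tokens", 0),
        "output_tokens": getattr(usage, "completion_tokens", 0),
        "reasoning_tokens": out_details.get("reasoning_tokens", 0),
    }

# reset_usage_stats(): every session starts from zeroed counters.
totals = {"input": 0, "cached": 0, "output": 0, "reasoning": 0}

# create_message(): accumulate once per LLM call.
u = extract_usage(fake_response)
totals["input"] += u["input_tokens"]
totals["cached"] += u["cached_tokens"]
totals["output"] += u["output_tokens"]
totals["reasoning"] += u["reasoning_tokens"]

# get_usage_log(): the derived fields are plain arithmetic over the counters.
uncached = totals["input"] - totals["cached"]          # 1200 - 800 = 400
response_out = totals["output"] - totals["reasoning"]  # 300 - 120 = 180
grand_total = totals["input"] + totals["output"]       # 1200 + 300 = 1500
print(
    f"Total Input: {totals['input']} (Cached: {totals['cached']}, Uncached: {uncached}), "
    f"Total Output: {totals['output']} (Reasoning: {totals['reasoning']}, Response: {response_out}), "
    f"Total Tokens: {grand_total}"
)

One provider quirk worth noting: on Anthropic responses, usage.input_tokens excludes cache reads and writes, which is why the Anthropic override sums input_tokens + cache_creation_input_tokens + cache_read_input_tokens to keep Total Input comparable across providers, and it reports reasoning_tokens as 0 because the Messages API usage block has no reasoning-token breakdown.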