Skip to content

Commit c42726a

Browse files
authored
feat: Introduce OpenRouterConversation with Force Tool Call and Token Counting Overhaul
* Add open router conversation * Add possibility to force tool usage behavior * Add full AIMessage when possible to collect all of the metadata * Harmonize tests with newly introduced code * Add tests for OpenRouterConversation * Simplify token counting strategy * Update behavior of token count for structured outputs * Update tests to new changes * Harmonize handling of response messages * Update tests to new changes * Harmonize token count reporting across different conversations * Add a suite of tests to verify consistent token reporting behavior across conversations * Introduce auxiliary methods to extract input and output tokens when possible * Update tests to cover new functions * Add a function to compute cumulative token usage * Fix the docs * Add author entry to pyproject.toml (hooray!) * Add warning messages to conversation classes that don't use kwargs * Add tests to check whether warnings are appearing correctly
1 parent 8648061 commit c42726a

20 files changed

+2463
-63
lines changed

biochatter/llm_connect/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from biochatter.llm_connect.misc import BloomConversation, WasmConversation
1010
from biochatter.llm_connect.ollama import OllamaConversation
1111
from biochatter.llm_connect.openai import GptConversation
12+
from biochatter.llm_connect.openrouter import OpenRouterConversation
1213
from biochatter.llm_connect.xinference import XinferenceConversation
1314

1415
__all__ = [
@@ -21,6 +22,7 @@
2122
"LangChainConversation",
2223
"LiteLLMConversation",
2324
"OllamaConversation",
25+
"OpenRouterConversation",
2426
"WasmConversation",
2527
"XinferenceConversation",
2628
]

biochatter/llm_connect/anthropic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import anthropic
24
from langchain_anthropic import ChatAnthropic
35
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
@@ -104,6 +106,9 @@ def _primary_query(self, **kwargs) -> tuple:
104106
the token usage.
105107
106108
"""
109+
if kwargs:
110+
warnings.warn(f"Warning: {kwargs} are not used by this class", UserWarning)
111+
107112
try:
108113
history = self._create_history()
109114
response = self.chat.generate([history])
@@ -126,7 +131,8 @@ def _primary_query(self, **kwargs) -> tuple:
126131
return str(e), None
127132

128133
msg = response.generations[0][0].text
129-
token_usage = response.llm_output.get("token_usage")
134+
token_usage_raw = response.llm_output.get("token_usage")
135+
token_usage = self._extract_total_tokens(token_usage_raw)
130136

131137
self.append_ai_message(msg)
132138

biochatter/llm_connect/conversation.py

Lines changed: 202 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(
110110
tool_call_mode: Literal["auto", "text"] = "auto",
111111
mcp: bool = False,
112112
additional_tools_instructions: str = None,
113+
force_tool: bool = False,
113114
) -> None:
114115
super().__init__()
115116
self.model_name = model_name
@@ -130,6 +131,7 @@ def __init__(
130131
self.tools_prompt = None
131132
self.mcp = mcp
132133
self.additional_tools_instructions = additional_tools_instructions if additional_tools_instructions else ""
134+
self.force_tool = force_tool
133135

134136
@property
135137
def chat(self):
@@ -194,6 +196,188 @@ def find_rag_agent(self, mode: str) -> tuple[int, RagAgent]:
194196
return i, val
195197
return -1, None
196198

199+
def _extract_total_tokens(self, token_usage: dict | int | None) -> int | None:
200+
"""Extract total tokens from various token usage formats.
201+
202+
This method standardizes token counting across different providers:
203+
- OpenAI/Azure: {"prompt_tokens": X, "completion_tokens": Y, "total_tokens": Z}
204+
- Anthropic: {"input_tokens": X, "output_tokens": Y} -> calculate total
205+
- Gemini: {"total_tokens": Z} -> extract total
206+
- Ollama: integer (eval_count) -> return as is
207+
- LiteLLM: {"input_tokens": X, "output_tokens": Y, "total_tokens": Z}
208+
- Others: try to extract or calculate total
209+
210+
Args:
211+
----
212+
token_usage: Token usage in various formats (dict, int, or None)
213+
214+
Returns:
215+
-------
216+
int | None: Total token count, or None if not available
217+
218+
"""
219+
if token_usage is None:
220+
return None
221+
222+
# Handle integer token counts (Ollama, some others)
223+
if isinstance(token_usage, int):
224+
return token_usage
225+
226+
# Handle dictionary token counts
227+
if isinstance(token_usage, dict):
228+
# First try to get total_tokens directly
229+
if "total_tokens" in token_usage:
230+
return token_usage["total_tokens"]
231+
232+
# Calculate from input/output tokens (Anthropic style)
233+
if "input_tokens" in token_usage and "output_tokens" in token_usage:
234+
return token_usage["input_tokens"] + token_usage["output_tokens"]
235+
236+
# Calculate from prompt/completion tokens (OpenAI style fallback)
237+
if "prompt_tokens" in token_usage and "completion_tokens" in token_usage:
238+
return token_usage["prompt_tokens"] + token_usage["completion_tokens"]
239+
240+
# If only one type of token count is available, use it
241+
if "input_tokens" in token_usage:
242+
return token_usage["input_tokens"]
243+
if "output_tokens" in token_usage:
244+
return token_usage["output_tokens"]
245+
if "prompt_tokens" in token_usage:
246+
return token_usage["prompt_tokens"]
247+
if "completion_tokens" in token_usage:
248+
return token_usage["completion_tokens"]
249+
250+
# If we can't extract meaningful token count, return None
251+
return None
252+
253+
def _extract_input_tokens(self, token_usage: dict | int | None) -> int | None:
254+
"""Extract input tokens from various token usage formats.
255+
256+
This method standardizes input token counting across different providers:
257+
- OpenAI/Azure: {"prompt_tokens": X, "completion_tokens": Y, "total_tokens": Z}
258+
- Anthropic: {"input_tokens": X, "output_tokens": Y}
259+
- Gemini: {"prompt_tokens": X, "candidates_tokens": Y, "total_tokens": Z}
260+
- LiteLLM: {"input_tokens": X, "output_tokens": Y, "total_tokens": Z}
261+
- Others: try to extract input/prompt tokens
262+
263+
Args:
264+
----
265+
token_usage: Token usage in various formats (dict, int, or None)
266+
267+
Returns:
268+
-------
269+
int | None: Input token count, or None if not available
270+
271+
"""
272+
if token_usage is None:
273+
return None
274+
275+
# Handle integer token counts (cannot distinguish input vs output)
276+
if isinstance(token_usage, int):
277+
return None
278+
279+
# Handle dictionary token counts
280+
if isinstance(token_usage, dict):
281+
# First try to get input_tokens (Anthropic, LiteLLM style)
282+
if "input_tokens" in token_usage:
283+
return token_usage["input_tokens"]
284+
285+
# Try prompt_tokens (OpenAI style)
286+
if "prompt_tokens" in token_usage:
287+
return token_usage["prompt_tokens"]
288+
289+
# If we can't extract meaningful input token count, return None
290+
return None
291+
292+
def _extract_output_tokens(self, token_usage: dict | int | None) -> int | None:
293+
"""Extract output tokens from various token usage formats.
294+
295+
This method standardizes output token counting across different providers:
296+
- OpenAI/Azure: {"prompt_tokens": X, "completion_tokens": Y, "total_tokens": Z}
297+
- Anthropic: {"input_tokens": X, "output_tokens": Y}
298+
- Gemini: {"prompt_tokens": X, "candidates_tokens": Y, "total_tokens": Z}
299+
- LiteLLM: {"input_tokens": X, "output_tokens": Y, "total_tokens": Z}
300+
- Others: try to extract output/completion tokens
301+
302+
Args:
303+
----
304+
token_usage: Token usage in various formats (dict, int, or None)
305+
306+
Returns:
307+
-------
308+
int | None: Output token count, or None if not available
309+
310+
"""
311+
if token_usage is None:
312+
return None
313+
314+
# Handle integer token counts (cannot distinguish input vs output)
315+
if isinstance(token_usage, int):
316+
return None
317+
318+
# Handle dictionary token counts
319+
if isinstance(token_usage, dict):
320+
# First try to get output_tokens (Anthropic, LiteLLM style)
321+
if "output_tokens" in token_usage:
322+
return token_usage["output_tokens"]
323+
324+
# Try completion_tokens (OpenAI style)
325+
if "completion_tokens" in token_usage:
326+
return token_usage["completion_tokens"]
327+
328+
# Try candidates_tokens (Gemini style)
329+
if "candidates_tokens" in token_usage:
330+
return token_usage["candidates_tokens"]
331+
332+
# If we can't extract meaningful output token count, return None
333+
return None
334+
335+
def compute_cumulative_token_usage(self) -> dict:
336+
"""Compute the token usage by looping over the messages.
337+
338+
Extracts token usage information from each message's usage_metadata and
339+
computes running cumulative totals throughout the conversation.
340+
Handles various token usage formats from different LLM providers.
341+
342+
Returns
343+
-------
344+
dict: Token usage information with lists of running totals:
345+
- "total_tokens": list[int] - running total at each message
346+
- "input_tokens": list[int] - running input total at each message
347+
- "output_tokens": list[int] - running output total at each message
348+
349+
"""
350+
# Initialize data structures
351+
individual_usage = {
352+
"total_tokens": [],
353+
"input_tokens": [],
354+
"output_tokens": [],
355+
}
356+
357+
# Extract individual token counts for each AI message
358+
for message in self.messages:
359+
if isinstance(message, AIMessage):
360+
usage_metadata = getattr(message, "usage_metadata", None)
361+
individual_usage["total_tokens"].append(self._extract_total_tokens(usage_metadata))
362+
individual_usage["input_tokens"].append(self._extract_input_tokens(usage_metadata))
363+
individual_usage["output_tokens"].append(self._extract_output_tokens(usage_metadata))
364+
365+
# Compute running cumulative totals for each message
366+
per_message_cumulative = {
367+
"total_tokens": [],
368+
"input_tokens": [],
369+
"output_tokens": [],
370+
}
371+
372+
for token_type in ["total_tokens", "input_tokens", "output_tokens"]:
373+
running_total = 0
374+
for count in individual_usage[token_type]:
375+
if count is not None:
376+
running_total += count
377+
per_message_cumulative[token_type].append(running_total)
378+
379+
return per_message_cumulative
380+
197381
@abstractmethod
198382
def set_api_key(self, api_key: str, user: str | None = None) -> None:
199383
"""Set the API key."""
@@ -253,19 +437,24 @@ def bind_tools(self, tools: list[Callable]) -> None:
253437
# If not, fail gracefully
254438
# raise ValueError(f"Model {self.model_name} does not support tool calling.")
255439

256-
def append_ai_message(self, message: str | AIMessage) -> None:
    """Add a message from the AI to the conversation.

    A plain string is wrapped in a new ``AIMessage``; a ready-made
    ``AIMessage`` is appended as-is so that all of its metadata
    (e.g. ``usage_metadata``, tool calls) is preserved.

    Args:
    ----
        message (str | AIMessage): The message from the AI.

    Raises:
    ------
        ValueError: If the message is neither a str nor an AIMessage.

    """
    if isinstance(message, AIMessage):
        self.messages.append(message)
    elif isinstance(message, str):
        self.messages.append(
            AIMessage(
                content=message,
            ),
        )
    else:
        raise ValueError(f"Invalid message type: {type(message)}")
269458

270459
def append_system_message(self, message: str) -> None:
271460
"""Add a system message to the conversation.
@@ -473,9 +662,13 @@ def query(
473662
track_tool_calls=track_tool_calls,
474663
)
475664

665+
# case of structured output
666+
if (token_usage == -1) and structured_model:
667+
return (msg, 0, None)
668+
476669
if not token_usage:
477670
# indicates error
478-
return (msg, token_usage, None)
671+
return (msg, None, None)
479672

480673
if not self.correct:
481674
return (msg, token_usage, None)
@@ -712,7 +905,7 @@ def _process_tool_calls(
712905
additional_instructions=self.additional_instructions_tool_interpretation,
713906
)
714907
)
715-
self.append_ai_message(tool_result_interpretation.content)
908+
self.messages.append(tool_result_interpretation)
716909
msg += f"\nTool results interpretation: {tool_result_interpretation.content}"
717910
else:
718911
# Single tool: explain individual result (maintain current behavior)
@@ -725,7 +918,7 @@ def _process_tool_calls(
725918
additional_instructions=self.additional_instructions_tool_interpretation,
726919
)
727920
)
728-
self.append_ai_message(tool_result_interpretation.content)
921+
self.messages.append(tool_result_interpretation)
729922
msg += f"\nTool result interpretation: {tool_result_interpretation.content}"
730923

731924
return msg

biochatter/llm_connect/gemini.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Callable
23
from typing import Literal
34

@@ -119,6 +120,10 @@ def _primary_query(self, tools: list[Callable] | None = None, **kwargs) -> tuple
119120
the token usage.
120121
121122
"""
123+
if kwargs:
124+
kwargs.pop("tools", None)
125+
warnings.warn(f"Warning: {kwargs} are not used by this class", UserWarning)
126+
122127
# bind tools to the chat if provided in the query
123128
chat = self.chat.bind_tools(tools) if (tools and self.model_name in TOOL_CALLING_MODELS) else self.chat
124129

@@ -134,7 +139,8 @@ def _primary_query(self, tools: list[Callable] | None = None, **kwargs) -> tuple
134139
msg = response.content
135140
self.append_ai_message(msg)
136141

137-
token_usage = response.usage_metadata["total_tokens"]
142+
token_usage_raw = response.usage_metadata
143+
token_usage = self._extract_total_tokens(token_usage_raw)
138144

139145
return msg, token_usage
140146

@@ -171,6 +177,7 @@ def _correct_response(self, msg: str) -> str:
171177
response = self.ca_chat.invoke(ca_messages)
172178

173179
correction = response.content
174-
token_usage = response.usage_metadata["total_tokens"]
180+
token_usage_raw = response.usage_metadata
181+
token_usage = self._extract_total_tokens(token_usage_raw)
175182

176183
return correction

0 commit comments

Comments
 (0)