@@ -110,6 +110,7 @@ def __init__(
110110 tool_call_mode : Literal ["auto" , "text" ] = "auto" ,
111111 mcp : bool = False ,
112112 additional_tools_instructions : str = None ,
113+ force_tool : bool = False ,
113114 ) -> None :
114115 super ().__init__ ()
115116 self .model_name = model_name
@@ -130,6 +131,7 @@ def __init__(
130131 self .tools_prompt = None
131132 self .mcp = mcp
132133 self .additional_tools_instructions = additional_tools_instructions if additional_tools_instructions else ""
134+ self .force_tool = force_tool
133135
134136 @property
135137 def chat (self ):
@@ -194,6 +196,188 @@ def find_rag_agent(self, mode: str) -> tuple[int, RagAgent]:
194196 return i , val
195197 return - 1 , None
196198
199+ def _extract_total_tokens (self , token_usage : dict | int | None ) -> int | None :
200+ """Extract total tokens from various token usage formats.
201+
202+ This method standardizes token counting across different providers:
203+ - OpenAI/Azure: {"prompt_tokens": X, "completion_tokens": Y, "total_tokens": Z}
204+ - Anthropic: {"input_tokens": X, "output_tokens": Y} -> calculate total
205+ - Gemini: {"total_tokens": Z} -> extract total
206+ - Ollama: integer (eval_count) -> return as is
207+ - LiteLLM: {"input_tokens": X, "output_tokens": Y, "total_tokens": Z}
208+ - Others: try to extract or calculate total
209+
210+ Args:
211+ ----
212+ token_usage: Token usage in various formats (dict, int, or None)
213+
214+ Returns:
215+ -------
216+ int | None: Total token count, or None if not available
217+
218+ """
219+ if token_usage is None :
220+ return None
221+
222+ # Handle integer token counts (Ollama, some others)
223+ if isinstance (token_usage , int ):
224+ return token_usage
225+
226+ # Handle dictionary token counts
227+ if isinstance (token_usage , dict ):
228+ # First try to get total_tokens directly
229+ if "total_tokens" in token_usage :
230+ return token_usage ["total_tokens" ]
231+
232+ # Calculate from input/output tokens (Anthropic style)
233+ if "input_tokens" in token_usage and "output_tokens" in token_usage :
234+ return token_usage ["input_tokens" ] + token_usage ["output_tokens" ]
235+
236+ # Calculate from prompt/completion tokens (OpenAI style fallback)
237+ if "prompt_tokens" in token_usage and "completion_tokens" in token_usage :
238+ return token_usage ["prompt_tokens" ] + token_usage ["completion_tokens" ]
239+
240+ # If only one type of token count is available, use it
241+ if "input_tokens" in token_usage :
242+ return token_usage ["input_tokens" ]
243+ if "output_tokens" in token_usage :
244+ return token_usage ["output_tokens" ]
245+ if "prompt_tokens" in token_usage :
246+ return token_usage ["prompt_tokens" ]
247+ if "completion_tokens" in token_usage :
248+ return token_usage ["completion_tokens" ]
249+
250+ # If we can't extract meaningful token count, return None
251+ return None
252+
253+ def _extract_input_tokens (self , token_usage : dict | int | None ) -> int | None :
254+ """Extract input tokens from various token usage formats.
255+
256+ This method standardizes input token counting across different providers:
257+ - OpenAI/Azure: {"prompt_tokens": X, "completion_tokens": Y, "total_tokens": Z}
258+ - Anthropic: {"input_tokens": X, "output_tokens": Y}
259+ - Gemini: {"prompt_tokens": X, "candidates_tokens": Y, "total_tokens": Z}
260+ - LiteLLM: {"input_tokens": X, "output_tokens": Y, "total_tokens": Z}
261+ - Others: try to extract input/prompt tokens
262+
263+ Args:
264+ ----
265+ token_usage: Token usage in various formats (dict, int, or None)
266+
267+ Returns:
268+ -------
269+ int | None: Input token count, or None if not available
270+
271+ """
272+ if token_usage is None :
273+ return None
274+
275+ # Handle integer token counts (cannot distinguish input vs output)
276+ if isinstance (token_usage , int ):
277+ return None
278+
279+ # Handle dictionary token counts
280+ if isinstance (token_usage , dict ):
281+ # First try to get input_tokens (Anthropic, LiteLLM style)
282+ if "input_tokens" in token_usage :
283+ return token_usage ["input_tokens" ]
284+
285+ # Try prompt_tokens (OpenAI style)
286+ if "prompt_tokens" in token_usage :
287+ return token_usage ["prompt_tokens" ]
288+
289+ # If we can't extract meaningful input token count, return None
290+ return None
291+
292+ def _extract_output_tokens (self , token_usage : dict | int | None ) -> int | None :
293+ """Extract output tokens from various token usage formats.
294+
295+ This method standardizes output token counting across different providers:
296+ - OpenAI/Azure: {"prompt_tokens": X, "completion_tokens": Y, "total_tokens": Z}
297+ - Anthropic: {"input_tokens": X, "output_tokens": Y}
298+ - Gemini: {"prompt_tokens": X, "candidates_tokens": Y, "total_tokens": Z}
299+ - LiteLLM: {"input_tokens": X, "output_tokens": Y, "total_tokens": Z}
300+ - Others: try to extract output/completion tokens
301+
302+ Args:
303+ ----
304+ token_usage: Token usage in various formats (dict, int, or None)
305+
306+ Returns:
307+ -------
308+ int | None: Output token count, or None if not available
309+
310+ """
311+ if token_usage is None :
312+ return None
313+
314+ # Handle integer token counts (cannot distinguish input vs output)
315+ if isinstance (token_usage , int ):
316+ return None
317+
318+ # Handle dictionary token counts
319+ if isinstance (token_usage , dict ):
320+ # First try to get output_tokens (Anthropic, LiteLLM style)
321+ if "output_tokens" in token_usage :
322+ return token_usage ["output_tokens" ]
323+
324+ # Try completion_tokens (OpenAI style)
325+ if "completion_tokens" in token_usage :
326+ return token_usage ["completion_tokens" ]
327+
328+ # Try candidates_tokens (Gemini style)
329+ if "candidates_tokens" in token_usage :
330+ return token_usage ["candidates_tokens" ]
331+
332+ # If we can't extract meaningful output token count, return None
333+ return None
334+
def compute_cumulative_token_usage(self) -> dict:
    """Compute running token totals over the AI messages of the conversation.

    Walks ``self.messages`` in order and, for every ``AIMessage``, reads its
    ``usage_metadata`` attribute (None when the attribute is absent) and
    extracts the total, input and output counts from it. Messages of other
    types are ignored and produce no entry. A count that cannot be extracted
    leaves the corresponding running total unchanged, but an entry is still
    recorded for that message.

    Returns
    -------
        dict: Token usage information with lists of running totals, one
        entry per AI message:
            - "total_tokens": list[int] - running total after each AI message
            - "input_tokens": list[int] - running input total after each AI message
            - "output_tokens": list[int] - running output total after each AI message

    """
    token_types = ("total_tokens", "input_tokens", "output_tokens")
    cumulative = {token_type: [] for token_type in token_types}
    running = dict.fromkeys(token_types, 0)

    for message in self.messages:
        # Only AI messages are considered as carriers of usage metadata.
        if not isinstance(message, AIMessage):
            continue

        usage = getattr(message, "usage_metadata", None)
        counts = (
            self._extract_total_tokens(usage),
            self._extract_input_tokens(usage),
            self._extract_output_tokens(usage),
        )

        for token_type, count in zip(token_types, counts):
            # Unknown counts (None) do not advance the running total.
            if count is not None:
                running[token_type] += count
            cumulative[token_type].append(running[token_type])

    return cumulative
380+
197381 @abstractmethod
198382 def set_api_key (self , api_key : str , user : str | None = None ) -> None :
199383 """Set the API key."""
@@ -253,19 +437,24 @@ def bind_tools(self, tools: list[Callable]) -> None:
253437 # If not, fail gracefully
254438 # raise ValueError(f"Model {self.model_name} does not support tool calling.")
255439
def append_ai_message(self, message: str | AIMessage) -> None:
    """Add a message from the AI to the conversation.

    Args:
    ----
        message (str | AIMessage): The message from the AI. An ``AIMessage``
            is appended to ``self.messages`` as is; a string is wrapped in a
            new ``AIMessage`` with the string as its content.

    Raises:
    ------
        ValueError: If ``message`` is neither an ``AIMessage`` nor a string.

    """
    if isinstance(message, AIMessage):
        # Already a message object: keep it untouched.
        ai_message = message
    elif isinstance(message, str):
        ai_message = AIMessage(content=message)
    else:
        raise ValueError(f"Invalid message type: { type (message )} ")

    self.messages.append(ai_message)
269458
270459 def append_system_message (self , message : str ) -> None :
271460 """Add a system message to the conversation.
@@ -473,9 +662,13 @@ def query(
473662 track_tool_calls = track_tool_calls ,
474663 )
475664
665+ # case of structured output
666+ if (token_usage == - 1 ) and structured_model :
667+ return (msg , 0 , None )
668+
476669 if not token_usage :
477670 # indicates error
478- return (msg , token_usage , None )
671+ return (msg , None , None )
479672
480673 if not self .correct :
481674 return (msg , token_usage , None )
@@ -712,7 +905,7 @@ def _process_tool_calls(
712905 additional_instructions = self .additional_instructions_tool_interpretation ,
713906 )
714907 )
715- self .append_ai_message (tool_result_interpretation . content )
908+ self .messages . append (tool_result_interpretation )
716909 msg += f"\n Tool results interpretation: { tool_result_interpretation .content } "
717910 else :
718911 # Single tool: explain individual result (maintain current behavior)
@@ -725,7 +918,7 @@ def _process_tool_calls(
725918 additional_instructions = self .additional_instructions_tool_interpretation ,
726919 )
727920 )
728- self .append_ai_message (tool_result_interpretation . content )
921+ self .messages . append (tool_result_interpretation )
729922 msg += f"\n Tool result interpretation: { tool_result_interpretation .content } "
730923
731924 return msg
0 commit comments