diff --git a/libs/community/langchain_community/callbacks/__init__.py b/libs/community/langchain_community/callbacks/__init__.py
index 5d36b91f4..e6c104d6a 100644
--- a/libs/community/langchain_community/callbacks/__init__.py
+++ b/libs/community/langchain_community/callbacks/__init__.py
@@ -57,6 +57,9 @@
     from langchain_community.callbacks.mlflow_callback import (
         MlflowCallbackHandler,
     )
+    from langchain_community.callbacks.gemini_info import (
+        GeminiCallbackHandler,
+    )
     from langchain_community.callbacks.openai_info import (
         OpenAICallbackHandler,
     )
@@ -103,6 +106,7 @@
     "LLMThoughtLabeler": "langchain_community.callbacks.streamlit",
     "LLMonitorCallbackHandler": "langchain_community.callbacks.llmonitor_callback",
     "LabelStudioCallbackHandler": "langchain_community.callbacks.labelstudio_callback",
+    "GeminiCallbackHandler": "langchain_community.callbacks.gemini_info",
     "MlflowCallbackHandler": "langchain_community.callbacks.mlflow_callback",
     "OpenAICallbackHandler": "langchain_community.callbacks.openai_info",
     "PromptLayerCallbackHandler": "langchain_community.callbacks.promptlayer_callback",
@@ -136,6 +140,7 @@ def __getattr__(name: str) -> Any:
     "ContextCallbackHandler",
     "FiddlerCallbackHandler",
     "FlyteCallbackHandler",
+    "GeminiCallbackHandler",
     "HumanApprovalCallbackHandler",
     "InfinoCallbackHandler",
     "LLMThoughtLabeler",
diff --git a/libs/community/langchain_community/callbacks/gemini_info.py b/libs/community/langchain_community/callbacks/gemini_info.py
new file mode 100644
index 000000000..a669c49d5
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/gemini_info.py
@@ -0,0 +1,222 @@
+import threading
+from enum import Enum, auto
+from typing import Any, Dict, List
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration, LLMResult
+
+MODEL_COST_PER_1M_TOKENS = {
+    "gemini-2.5-pro": 1.25,
+    "gemini-2.5-pro-completion": 10,
+    "gemini-2.5-flash": 0.3,
+    "gemini-2.5-flash-completion": 2.5,
+    "gemini-2.5-flash-lite": 0.1,
+    "gemini-2.5-flash-lite-completion": 0.4,
+    "gemini-2.0-flash": 0.1,
+    "gemini-2.0-flash-completion": 0.4,
+    "gemini-2.0-flash-lite": 0.075,
+    "gemini-2.0-flash-lite-completion": 0.3,
+    "gemini-1.5-pro": 1.25,
+    "gemini-1.5-pro-completion": 5,
+    "gemini-1.5-flash": 0.075,
+    "gemini-1.5-flash-completion": 0.3,
+}
+MODEL_COST_PER_1K_TOKENS = {k: v / 1000 for k, v in MODEL_COST_PER_1M_TOKENS.items()}
+
+
+class TokenType(Enum):
+    """Token type enum."""
+
+    PROMPT = auto()
+    PROMPT_CACHED = auto()
+    COMPLETION = auto()
+
+
+def standardize_model_name(
+    model_name: str,
+    token_type: TokenType = TokenType.PROMPT,
+) -> str:
+    """Standardize the model name to the key format used in the cost tables.
+
+    Args:
+        model_name: The name of the model to standardize.
+        token_type: The type of token, defaults to PROMPT.
+    """
+    model_name = model_name.lower()
+    if token_type == TokenType.COMPLETION:
+        return model_name + "-completion"
+    else:
+        return model_name
+
+
+def get_gemini_token_cost_for_model(
+    model_name: str,
+    num_tokens: int,
+    is_completion: bool = False,
+    *,
+    token_type: TokenType = TokenType.PROMPT,
+) -> float:
+    """Get the cost in USD for a given model and number of tokens.
+
+    Args:
+        model_name: The name of the Gemini model to calculate cost for.
+        num_tokens: The number of tokens to calculate cost for.
+        is_completion: Whether the tokens are completion tokens.
+            If True, token_type will be set to TokenType.COMPLETION.
+        token_type: The type of token (prompt or completion).
+            Defaults to TokenType.PROMPT.
+
+    Returns:
+        The cost in USD for the specified number of tokens.
+
+    Raises:
+        ValueError: If the model name is not recognized as a valid Gemini model.
+    """
+    if is_completion:
+        token_type = TokenType.COMPLETION
+    model_name = standardize_model_name(model_name, token_type=token_type)
+    if model_name not in MODEL_COST_PER_1K_TOKENS:
+        raise ValueError(
+            f"Unknown model: {model_name}. Please provide a valid Gemini model "
+            "name. Known models are: " + ", ".join(MODEL_COST_PER_1K_TOKENS.keys())
+        )
+    return MODEL_COST_PER_1K_TOKENS[model_name] * (num_tokens / 1000)
+
+
+class GeminiCallbackHandler(BaseCallbackHandler):
+    """Callback Handler that tracks Gemini info."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._lock = threading.Lock()
+        self.total_tokens = 0
+        self.prompt_tokens = 0
+        self.prompt_tokens_cached = 0
+        self.completion_tokens = 0
+        self.reasoning_tokens = 0
+        self.successful_requests = 0
+        self.total_cost = 0.0
+
+    def __repr__(self) -> str:
+        return f"""Tokens Used: {self.total_tokens}
+\tPrompt Tokens: {self.prompt_tokens}
+\tPrompt Cached Tokens: {self.prompt_tokens_cached}
+\tCompletion Tokens: {self.completion_tokens}
+\tReasoning Tokens: {self.reasoning_tokens}
+Successful Requests: {self.successful_requests}
+Total Cost (USD): ${self.total_cost}"""
+
+    @property
+    def always_verbose(self) -> bool:
+        """Whether to call verbose callbacks even if verbose is False."""
+        return True
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Print out the prompts."""
+        pass
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Print out the token."""
+        pass
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Collect token usage."""
+        # Check for usage_metadata (langchain-core >= 0.2.2)
+        try:
+            generation = response.generations[0][0]
+        except IndexError:
+            generation = None
+        if isinstance(generation, ChatGeneration):
+            try:
+                message = generation.message
+                if isinstance(message, AIMessage):
+                    usage_metadata = message.usage_metadata
+                    response_metadata = message.response_metadata
+                else:
+                    usage_metadata = None
+                    response_metadata = None
+            except AttributeError:
+                usage_metadata = None
+                response_metadata = None
+        else:
+            usage_metadata = None
+            response_metadata = None
+
+        prompt_tokens_cached = 0
+        reasoning_tokens = 0
+
+        if usage_metadata:
+            token_usage = {"total_tokens": usage_metadata["total_tokens"]}
+            completion_tokens = usage_metadata["output_tokens"]
+            prompt_tokens = usage_metadata["input_tokens"]
+            if response_model_name := (response_metadata or {}).get("model_name"):
+                model_name = standardize_model_name(response_model_name)
+            elif response.llm_output is None:
+                model_name = ""
+            else:
+                model_name = standardize_model_name(
+                    response.llm_output.get("model_name", "")
+                )
+            if "cache_read" in usage_metadata.get("input_token_details", {}):
+                prompt_tokens_cached = usage_metadata["input_token_details"][
+                    "cache_read"
+                ]
+            if "reasoning" in usage_metadata.get("output_token_details", {}):
+                reasoning_tokens = usage_metadata["output_token_details"]["reasoning"]
+        else:
+            if response.llm_output is None:
+                return None
+
+            if "token_usage" not in response.llm_output:
+                with self._lock:
+                    self.successful_requests += 1
+                return None
+
+            # compute tokens and cost for this request
+            token_usage = response.llm_output["token_usage"]
+            completion_tokens = token_usage.get("completion_tokens", 0)
token_usage.get("completion_tokens", 0) + prompt_tokens = token_usage.get("prompt_tokens", 0) + model_name = standardize_model_name( + response.llm_output.get("model_name", "") + ) + + if model_name in MODEL_COST_PER_1K_TOKENS: + uncached_prompt_tokens = prompt_tokens - prompt_tokens_cached + uncached_prompt_cost = get_gemini_token_cost_for_model( + model_name, uncached_prompt_tokens, token_type=TokenType.PROMPT + ) + cached_prompt_cost = get_gemini_token_cost_for_model( + model_name, prompt_tokens_cached, token_type=TokenType.PROMPT_CACHED + ) + prompt_cost = uncached_prompt_cost + cached_prompt_cost + completion_cost = get_gemini_token_cost_for_model( + model_name, completion_tokens, token_type=TokenType.COMPLETION + ) + else: + completion_cost = 0 + prompt_cost = 0 + + # update shared state behind lock + with self._lock: + self.total_cost += prompt_cost + completion_cost + self.total_tokens += token_usage.get("total_tokens", 0) + self.prompt_tokens += prompt_tokens + self.prompt_tokens_cached += prompt_tokens_cached + self.completion_tokens += completion_tokens + self.reasoning_tokens += reasoning_tokens + self.successful_requests += 1 + + def __copy__(self) -> "GeminiCallbackHandler": + """Return a copy of the callback handler.""" + return self + + def __deepcopy__(self, memo: Any) -> "GeminiCallbackHandler": + """Return a deep copy of the callback handler.""" + return self diff --git a/libs/community/langchain_community/callbacks/manager.py b/libs/community/langchain_community/callbacks/manager.py index 8e8d05256..f20140003 100644 --- a/libs/community/langchain_community/callbacks/manager.py +++ b/libs/community/langchain_community/callbacks/manager.py @@ -13,6 +13,7 @@ from langchain_community.callbacks.bedrock_anthropic_callback import ( BedrockAnthropicTokenUsageCallbackHandler, ) +from langchain_community.callbacks.gemini_info import GeminiCallbackHandler from langchain_community.callbacks.openai_info import OpenAICallbackHandler from langchain_community.callbacks.tracers.comet import CometTracer from langchain_community.callbacks.tracers.wandb import WandbTracer @@ -25,6 +26,9 @@ bedrock_anthropic_callback_var: (ContextVar)[ Optional[BedrockAnthropicTokenUsageCallbackHandler] ] = ContextVar("bedrock_anthropic_callback", default=None) +gemini_callback_var: ContextVar[Optional[GeminiCallbackHandler]] = ContextVar( + "gemini_callback", default=None +) wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar( "tracing_wandb_callback", default=None ) @@ -34,6 +38,7 @@ register_configure_hook(openai_callback_var, True) register_configure_hook(bedrock_anthropic_callback_var, True) +register_configure_hook(gemini_callback_var, True) register_configure_hook( wandb_tracing_callback_var, True, WandbTracer, "LANGCHAIN_WANDB_TRACING" ) @@ -81,6 +86,24 @@ def get_bedrock_anthropic_callback() -> Generator[ bedrock_anthropic_callback_var.set(None) +@contextmanager +def get_gemini_callback() -> Generator[GeminiCallbackHandler, None, None]: + """Get the Gemini callback handler in a context manager. + which conveniently exposes token and cost information. + + Returns: + GeminiCallbackHandler: The Gemini callback handler. + + Example: + >>> with get_gemini_callback() as cb: + ... 
+        ...     # Use the Gemini callback handler
+    """
+    cb = GeminiCallbackHandler()
+    gemini_callback_var.set(cb)
+    yield cb
+    gemini_callback_var.set(None)
+
+
 @contextmanager
 def wandb_tracing_enabled(
     session_name: str = "default",
diff --git a/libs/community/tests/unit_tests/callbacks/test_gemini_info.py b/libs/community/tests/unit_tests/callbacks/test_gemini_info.py
new file mode 100644
index 000000000..578ad48dc
--- /dev/null
+++ b/libs/community/tests/unit_tests/callbacks/test_gemini_info.py
@@ -0,0 +1,229 @@
+from unittest.mock import MagicMock
+from uuid import uuid4
+
+import numpy as np
+import pytest
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration, LLMResult
+
+from langchain_community.callbacks.gemini_info import GeminiCallbackHandler
+
+
+@pytest.fixture
+def handler() -> GeminiCallbackHandler:
+    return GeminiCallbackHandler()
+
+
+def test_on_llm_end(handler: GeminiCallbackHandler) -> None:
+    response = LLMResult(
+        generations=[],
+        llm_output={
+            "token_usage": {
+                "prompt_tokens": 2,
+                "completion_tokens": 1,
+                "total_tokens": 3,
+            },
+            "model_name": "gemini-2.5-pro",
+        },
+    )
+    handler.on_llm_end(response)
+    assert handler.successful_requests == 1
+    assert handler.total_tokens == 3
+    assert handler.prompt_tokens == 2
+    assert handler.completion_tokens == 1
+    assert handler.total_cost > 0
+
+
+def test_on_llm_end_with_chat_generation(handler: GeminiCallbackHandler) -> None:
+    """Test handling of ChatGeneration with usage_metadata in AIMessage.
+
+    The handler reads token counts from the AIMessage usage_metadata when it
+    is present, so this response does not need a llm_output["token_usage"]
+    entry for tokens and cost to be recorded.
+    """
+    response = LLMResult(
+        generations=[
+            [
+                ChatGeneration(
+                    text="Hello, world!",
+                    message=AIMessage(
+                        content="Hello, world!",
+                        usage_metadata={
+                            "input_tokens": 2,
+                            "output_tokens": 2,
+                            "total_tokens": 4,
+                        },
+                    ),
+                )
+            ]
+        ],
+        llm_output={
+            "model_name": "gemini-2.5-pro",
+        },
+    )
+    handler.on_llm_end(response)
+    assert handler.successful_requests == 1
+    # Token counts and cost come from the AIMessage usage_metadata
+    assert handler.total_tokens == 4
+    assert handler.prompt_tokens == 2
+    assert handler.completion_tokens == 2
+    assert handler.total_cost > 0
+
+
+def test_on_llm_end_custom_model(handler: GeminiCallbackHandler) -> None:
+    response = LLMResult(
+        generations=[],
+        llm_output={
+            "token_usage": {
+                "prompt_tokens": 2,
+                "completion_tokens": 1,
+                "total_tokens": 3,
+            },
+            "model_name": "foo-bar",
+        },
+    )
+    handler.on_llm_end(response)
+    assert handler.total_cost == 0
+
+
+@pytest.mark.parametrize(
+    "model_name, expected_cost",
+    [
+        ("gemini-2.5-pro", 0.01125),
+        ("gemini-1.5-pro", 0.00625),
+        ("gemini-2.5-flash-lite", 0.0005),
+    ],
+)
+def test_on_llm_end_gemini_model(
+    handler: GeminiCallbackHandler, model_name: str, expected_cost: float
+) -> None:
+    response = LLMResult(
+        generations=[],
+        llm_output={
+            "token_usage": {
+                "prompt_tokens": 1000,
+                "completion_tokens": 1000,
+                "total_tokens": 2000,
+            },
+            "model_name": model_name,
+        },
+    )
+    handler.on_llm_end(response)
+    assert np.isclose(handler.total_cost, expected_cost)
+
+
+@pytest.mark.parametrize("model_name", ["unknown-model", "gpt-4", "claude-3"])
+def test_on_llm_end_no_cost_invalid_model(
+    handler: GeminiCallbackHandler, model_name: str
+) -> None:
+    """Test that unknown models result in zero cost."""
+    response = LLMResult(
+        generations=[],
+        llm_output={
+            "token_usage": {
+                "prompt_tokens": 1000,
"completion_tokens": 1000, + "total_tokens": 2000, + }, + "model_name": model_name, + }, + ) + handler.on_llm_end(response) + assert handler.total_cost == 0 + + +def test_on_llm_end_no_llm_output(handler: GeminiCallbackHandler) -> None: + """Test behavior when llm_output is None.""" + response = LLMResult( + generations=[], + llm_output=None, + ) + handler.on_llm_end(response) + # When llm_output is None, the handler returns early and doesn't increment + assert handler.successful_requests == 0 + assert handler.total_tokens == 0 + assert handler.prompt_tokens == 0 + assert handler.completion_tokens == 0 + assert handler.total_cost == 0 + + +def test_on_llm_end_no_token_usage(handler: GeminiCallbackHandler) -> None: + """Test behavior when token_usage is missing from llm_output.""" + response = LLMResult( + generations=[], + llm_output={ + "model_name": "gemini-2.5-pro", + }, + ) + handler.on_llm_end(response) + assert handler.successful_requests == 1 + assert handler.total_tokens == 0 + assert handler.prompt_tokens == 0 + assert handler.completion_tokens == 0 + assert handler.total_cost == 0 + + +def test_multiple_requests_accumulation(handler: GeminiCallbackHandler) -> None: + """Test that multiple requests accumulate correctly.""" + # First request + response1 = LLMResult( + generations=[], + llm_output={ + "token_usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + }, + "model_name": "gemini-2.5-pro", + }, + ) + handler.on_llm_end(response1) + + # Second request + response2 = LLMResult( + generations=[], + llm_output={ + "token_usage": { + "prompt_tokens": 200, + "completion_tokens": 100, + "total_tokens": 300, + }, + "model_name": "gemini-1.5-pro", + }, + ) + handler.on_llm_end(response2) + + assert handler.successful_requests == 2 + assert handler.total_tokens == 450 + assert handler.prompt_tokens == 300 + assert handler.completion_tokens == 150 + assert handler.total_cost > 0 + + +def test_on_llm_start_no_op(handler: GeminiCallbackHandler) -> None: + """Test that on_llm_start does nothing (no-op).""" + # This should not raise any exceptions + handler.on_llm_start({}, ["test prompt"]) + + +def test_on_llm_new_token_no_op(handler: GeminiCallbackHandler) -> None: + """Test that on_llm_new_token does nothing (no-op).""" + # This should not raise any exceptions + handler.on_llm_new_token("test") + + +def test_handler_copy(handler: GeminiCallbackHandler) -> None: + """Test handler copy methods.""" + import copy + + # Test shallow copy + handler_copy = copy.copy(handler) + assert handler_copy is handler # Should return the same instance + + # Test deep copy + handler_deepcopy = copy.deepcopy(handler) + assert handler_deepcopy is handler # Should return the same instance + + +def test_on_retry_works(handler: GeminiCallbackHandler) -> None: + handler.on_retry(MagicMock(), run_id=uuid4())