5 changes: 5 additions & 0 deletions libs/community/langchain_community/callbacks/__init__.py
@@ -11,81 +11,84 @@
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from langchain_community.callbacks.aim_callback import (
AimCallbackHandler,
)
from langchain_community.callbacks.argilla_callback import (
ArgillaCallbackHandler,
)
from langchain_community.callbacks.arize_callback import (
ArizeCallbackHandler,
)
from langchain_community.callbacks.arthur_callback import (
ArthurCallbackHandler,
)
from langchain_community.callbacks.clearml_callback import (
ClearMLCallbackHandler,
)
from langchain_community.callbacks.comet_ml_callback import (
CometCallbackHandler,
)
from langchain_community.callbacks.context_callback import (
ContextCallbackHandler,
)
from langchain_community.callbacks.fiddler_callback import (
FiddlerCallbackHandler,
)
from langchain_community.callbacks.flyte_callback import (
FlyteCallbackHandler,
)
from langchain_community.callbacks.human import (
HumanApprovalCallbackHandler,
)
from langchain_community.callbacks.infino_callback import (
InfinoCallbackHandler,
)
from langchain_community.callbacks.labelstudio_callback import (
LabelStudioCallbackHandler,
)
from langchain_community.callbacks.llmonitor_callback import (
LLMonitorCallbackHandler,
)
from langchain_community.callbacks.manager import (
get_openai_callback,
wandb_tracing_enabled,
)
from langchain_community.callbacks.mlflow_callback import (
MlflowCallbackHandler,
)
from langchain_community.callbacks.gemini_info import (
GeminiCallbackHandler,
)
from langchain_community.callbacks.openai_info import (
OpenAICallbackHandler,
)
from langchain_community.callbacks.promptlayer_callback import (
PromptLayerCallbackHandler,
)
from langchain_community.callbacks.sagemaker_callback import (
SageMakerCallbackHandler,
)
from langchain_community.callbacks.streamlit import (
LLMThoughtLabeler,
StreamlitCallbackHandler,
)
from langchain_community.callbacks.trubrics_callback import (
TrubricsCallbackHandler,
)
from langchain_community.callbacks.upstash_ratelimit_callback import (
UpstashRatelimitError,
UpstashRatelimitHandler, # noqa: F401
)
from langchain_community.callbacks.uptrain_callback import (
UpTrainCallbackHandler,
)
from langchain_community.callbacks.wandb_callback import (
WandbCallbackHandler,
)
from langchain_community.callbacks.whylabs_callback import (
WhyLabsCallbackHandler,
)

Check failure on line 91 in libs/community/langchain_community/callbacks/__init__.py (GitHub Actions / cd libs/community / Python 3.11): Ruff (I001) Import block is un-sorted or un-formatted (langchain_community/callbacks/__init__.py:14:5)

_module_lookup = {
@@ -103,6 +106,7 @@
"LLMThoughtLabeler": "langchain_community.callbacks.streamlit",
"LLMonitorCallbackHandler": "langchain_community.callbacks.llmonitor_callback",
"LabelStudioCallbackHandler": "langchain_community.callbacks.labelstudio_callback",
"GeminiCallbackHandler": "langchain_community.callbacks.gemini_info",
"MlflowCallbackHandler": "langchain_community.callbacks.mlflow_callback",
"OpenAICallbackHandler": "langchain_community.callbacks.openai_info",
"PromptLayerCallbackHandler": "langchain_community.callbacks.promptlayer_callback",
@@ -136,6 +140,7 @@
"ContextCallbackHandler",
"FiddlerCallbackHandler",
"FlyteCallbackHandler",
"GeminiCallbackHandler",
"HumanApprovalCallbackHandler",
"InfinoCallbackHandler",
"LLMThoughtLabeler",
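
With the _module_lookup entry and the __all__ addition above, the new handler becomes importable from the package root and is resolved lazily on first access. A minimal sketch of the resulting import path (the usage below is illustrative):

from langchain_community.callbacks import GeminiCallbackHandler

handler = GeminiCallbackHandler()
print(handler.total_cost)  # 0.0 until a Gemini response has been recorded
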
222 changes: 222 additions & 0 deletions libs/community/langchain_community/callbacks/gemini_info.py
@@ -0,0 +1,222 @@
import threading
from enum import Enum, auto
from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult

MODEL_COST_PER_1M_TOKENS = {
"gemini-2.5-pro": 1.25,
"gemini-2.5-pro-completion": 10,
"gemini-2.5-flash": 0.3,
"gemini-2.5-flash-completion": 2.5,
"gemini-2.5-flash-lite": 0.1,
"gemini-2.5-flash-lite-completion": 0.4,
"gemini-2.0-flash": 0.1,
"gemini-2.0-flash-completion": 0.4,
"gemini-2.0-flash-lite": 0.075,
"gemini-2.0-flash-lite-completion": 0.3,
"gemini-1.5-pro": 1.25,
"gemini-1.5-pro-completion": 5,
"gemini-1.5-flash": 0.075,
"gemini-1.5-flash-completion": 0.3,
}
MODEL_COST_PER_1K_TOKENS = {k: v / 1000 for k, v in MODEL_COST_PER_1M_TOKENS.items()}


class TokenType(Enum):
"""Token type enum."""

PROMPT = auto()
PROMPT_CACHED = auto()
COMPLETION = auto()



def standardize_model_name(
model_name: str,
token_type: TokenType = TokenType.PROMPT,
) -> str:
"""Standardize the model name to a format that can be used in the Gemini API.

Args:
model_name: The name of the model to standardize.
token_type: The type of token, defaults to PROMPT.
"""
model_name = model_name.lower()
if token_type == TokenType.COMPLETION:
return model_name + "-completion"
else:
return model_name


def get_gemini_token_cost_for_model(
model_name: str,
num_tokens: int,
is_completion: bool = False,
*,
token_type: TokenType = TokenType.PROMPT,
) -> float:
"""Get the cost in USD for a given model and number of tokens.

Args:
model_name: The name of the Gemini model to calculate cost for.
num_tokens: The number of tokens to calculate cost for.
is_completion: Whether the tokens are completion tokens.
If True, token_type will be set to TokenType.COMPLETION.
token_type: The type of token (prompt or completion).
Defaults to TokenType.PROMPT.

Returns:
The cost in USD for the specified number of tokens.

Raises:
ValueError: If the model name is not recognized as a valid Gemini model.
"""
if is_completion:
token_type = TokenType.COMPLETION
model_name = standardize_model_name(model_name, token_type=token_type)
if model_name not in MODEL_COST_PER_1K_TOKENS:
raise ValueError(
f"Unknown model: {model_name}. Please provide a valid Gemini model name. Known models are: "

Check failure on line 84 in libs/community/langchain_community/callbacks/gemini_info.py

View workflow job for this annotation

GitHub Actions / cd libs/community / Python 3.11

Ruff (E501)

langchain_community/callbacks/gemini_info.py:84:89: E501 Line too long (104 > 88)
+ ", ".join(MODEL_COST_PER_1K_TOKENS.keys())
)
return MODEL_COST_PER_1K_TOKENS[model_name] * (num_tokens / 1000)


class GeminiCallbackHandler(BaseCallbackHandler):
"""Callback Handler that tracks Gemini info."""

def __init__(self) -> None:
super().__init__()
self._lock = threading.Lock()
self.total_tokens = 0
self.prompt_tokens = 0
self.prompt_tokens_cached = 0
self.completion_tokens = 0
self.reasoning_tokens = 0
self.successful_requests = 0
self.total_cost = 0.0

def __repr__(self) -> str:
return f"""Tokens Used: {self.total_tokens}
\tPrompt Tokens: {self.prompt_tokens}
\tPrompt Cached Tokens: {self.prompt_tokens_cached}
\tCompletion Tokens: {self.completion_tokens}
\tReasoning Tokens: {self.reasoning_tokens}
Successful Requests: {self.successful_requests}
Total Cost (USD): ${self.total_cost}"""

@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
pass

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Print out the token."""
pass

def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
# Check for usage_metadata (langchain-core >= 0.2.2)
try:
generation = response.generations[0][0]
except IndexError:
generation = None
if isinstance(generation, ChatGeneration):
try:
message = generation.message
if isinstance(message, AIMessage):
usage_metadata = message.usage_metadata
response_metadata = message.response_metadata
else:
usage_metadata = None
response_metadata = None
except AttributeError:
usage_metadata = None
response_metadata = None
else:
usage_metadata = None
response_metadata = None

prompt_tokens_cached = 0
reasoning_tokens = 0

if usage_metadata:
token_usage = {"total_tokens": usage_metadata["total_tokens"]}
completion_tokens = usage_metadata["output_tokens"]
prompt_tokens = usage_metadata["input_tokens"]
if response_model_name := (response_metadata or {}).get("model_name"):
model_name = standardize_model_name(response_model_name)
elif response.llm_output is None:
model_name = ""
else:
model_name = standardize_model_name(
response.llm_output.get("model_name", "")
)
if "cache_read" in usage_metadata.get("input_token_details", {}):
prompt_tokens_cached = usage_metadata["input_token_details"][
"cache_read"
]
if "reasoning" in usage_metadata.get("output_token_details", {}):
reasoning_tokens = usage_metadata["output_token_details"]["reasoning"]
else:
if response.llm_output is None:
return None

if "token_usage" not in response.llm_output:
with self._lock:
self.successful_requests += 1
return None

# compute tokens and cost for this request
token_usage = response.llm_output["token_usage"]
completion_tokens = token_usage.get("completion_tokens", 0)
prompt_tokens = token_usage.get("prompt_tokens", 0)
model_name = standardize_model_name(
response.llm_output.get("model_name", "")
)

if model_name in MODEL_COST_PER_1K_TOKENS:
uncached_prompt_tokens = prompt_tokens - prompt_tokens_cached
uncached_prompt_cost = get_gemini_token_cost_for_model(
model_name, uncached_prompt_tokens, token_type=TokenType.PROMPT
)
cached_prompt_cost = get_gemini_token_cost_for_model(
model_name, prompt_tokens_cached, token_type=TokenType.PROMPT_CACHED
)
prompt_cost = uncached_prompt_cost + cached_prompt_cost
completion_cost = get_gemini_token_cost_for_model(
model_name, completion_tokens, token_type=TokenType.COMPLETION
)
else:
completion_cost = 0
prompt_cost = 0

# update shared state behind lock
with self._lock:
self.total_cost += prompt_cost + completion_cost
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.prompt_tokens_cached += prompt_tokens_cached
self.completion_tokens += completion_tokens
self.reasoning_tokens += reasoning_tokens
self.successful_requests += 1

def __copy__(self) -> "GeminiCallbackHandler":
"""Return a copy of the callback handler."""
return self

def __deepcopy__(self, memo: Any) -> "GeminiCallbackHandler":
"""Return a deep copy of the callback handler."""
return self
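
To show the accounting above end to end, here is a minimal sketch that feeds the handler a synthetic ChatGeneration; the token counts and model name are invented for illustration, not taken from a real Gemini call:

from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult

from langchain_community.callbacks.gemini_info import GeminiCallbackHandler

handler = GeminiCallbackHandler()
message = AIMessage(
    content="Hello!",
    usage_metadata={"input_tokens": 1200, "output_tokens": 300, "total_tokens": 1500},
    response_metadata={"model_name": "gemini-2.5-flash"},
)
# on_llm_end reads usage_metadata, looks up per-1K prices, and updates the totals.
handler.on_llm_end(LLMResult(generations=[[ChatGeneration(message=message)]]))
print(handler)
# Expected cost: 1200/1000 * 0.0003 + 300/1000 * 0.0025 ≈ $0.00111
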
23 changes: 23 additions & 0 deletions libs/community/langchain_community/callbacks/manager.py
@@ -13,6 +13,7 @@
from langchain_community.callbacks.bedrock_anthropic_callback import (
BedrockAnthropicTokenUsageCallbackHandler,
)
from langchain_community.callbacks.gemini_info import GeminiCallbackHandler
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain_community.callbacks.tracers.comet import CometTracer
from langchain_community.callbacks.tracers.wandb import WandbTracer
@@ -25,6 +26,9 @@
bedrock_anthropic_callback_var: (ContextVar)[
Optional[BedrockAnthropicTokenUsageCallbackHandler]
] = ContextVar("bedrock_anthropic_callback", default=None)
gemini_callback_var: ContextVar[Optional[GeminiCallbackHandler]] = ContextVar(
"gemini_callback", default=None
)
wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar(
"tracing_wandb_callback", default=None
)
@@ -34,6 +38,7 @@

register_configure_hook(openai_callback_var, True)
register_configure_hook(bedrock_anthropic_callback_var, True)
register_configure_hook(gemini_callback_var, True)
register_configure_hook(
wandb_tracing_callback_var, True, WandbTracer, "LANGCHAIN_WANDB_TRACING"
)
@@ -81,6 +86,24 @@ def get_bedrock_anthropic_callback() -> Generator[
bedrock_anthropic_callback_var.set(None)


@contextmanager
def get_gemini_callback() -> Generator[GeminiCallbackHandler, None, None]:
"""Get the Gemini callback handler in a context manager.
which conveniently exposes token and cost information.

Returns:
GeminiCallbackHandler: The Gemini callback handler.

Example:
>>> with get_gemini_callback() as cb:
... # Use the Gemini callback handler
"""
cb = GeminiCallbackHandler()
gemini_callback_var.set(cb)
yield cb
gemini_callback_var.set(None)


@contextmanager
def wandb_tracing_enabled(
session_name: str = "default",
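
For completeness, a sketch of how the new context manager might be used with a Gemini chat model. It assumes the optional langchain-google-genai package is installed, GOOGLE_API_KEY is set, and the model reports usage_metadata on its responses; the model name is illustrative:

from langchain_community.callbacks.manager import get_gemini_callback
from langchain_google_genai import ChatGoogleGenerativeAI  # optional dependency

llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash")

with get_gemini_callback() as cb:
    llm.invoke("Tell me a joke")
    llm.invoke("Tell me another joke")

# cb aggregates tokens and estimated USD cost across both calls.
print(cb)
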