diff --git a/src/agents-api/agents_api/clients/litellm.py b/src/agents-api/agents_api/clients/litellm.py index c9bc89042..de2a8d339 100644 --- a/src/agents-api/agents_api/clients/litellm.py +++ b/src/agents-api/agents_api/clients/litellm.py @@ -15,7 +15,8 @@ ) from ..common.utils.llm_providers import get_api_key_env_var_name from ..common.utils.secrets import get_secret_by_name -from ..common.utils.usage import track_embedding_usage, track_usage +from ..common.utils.usage import track_embedding_usage +from ..common.utils.usage_tracker import track_completion_usage from ..env import ( embedding_dimensions, embedding_model_id, @@ -104,12 +105,12 @@ async def acompletion( response = patch_litellm_response(model_response) - # Track usage in database if we have a user ID (which should be the developer ID) + # Track usage if we have a user ID (which should be the developer ID) user = settings.get("user") if user and isinstance(response, ModelResponse): try: model = response.model - await track_usage( + await track_completion_usage( developer_id=UUID(user), model=model, messages=messages, diff --git a/src/agents-api/agents_api/common/utils/usage_tracker.py b/src/agents-api/agents_api/common/utils/usage_tracker.py new file mode 100644 index 000000000..afd804008 --- /dev/null +++ b/src/agents-api/agents_api/common/utils/usage_tracker.py @@ -0,0 +1,122 @@ +""" +Centralized usage tracking utilities for LLM API calls. +Handles both Prometheus metrics and database tracking. 
    metadata: dict[str, Any] | None = None,
+ + Args: + developer_id: The developer ID for usage tracking + model: The model name used for the response + messages: The original messages sent to the model + usage_data: Usage data from the streaming response + collected_output: The complete collected output from streaming + response_id: The response ID + custom_api_used: Whether a custom API key was used + metadata: Additional metadata for tracking + connection_pool: Connection pool for testing purposes + """ + # Track Prometheus metrics if usage data is available + if usage_data and usage_data.get("total_tokens", 0) > 0: + total_tokens_per_user.labels(str(developer_id)).inc( + amount=usage_data.get("total_tokens", 0) + ) + + # Only track usage in database if we have collected output + if not collected_output: + return + + # Track usage in database + await track_usage( + developer_id=developer_id, + model=model, + messages=messages, + response=ModelResponse( + id=response_id, + choices=[ + Choices( + message=Message( + content=choice.get("content", ""), + tool_calls=choice.get("tool_calls"), + ), + ) + for choice in collected_output + ], + usage=usage_data, + ), + custom_api_used=custom_api_used, + metadata=metadata, + connection_pool=connection_pool, + ) \ No newline at end of file diff --git a/src/agents-api/agents_api/routers/responses/create_response.py b/src/agents-api/agents_api/routers/responses/create_response.py index a195aeec2..64327aa61 100644 --- a/src/agents-api/agents_api/routers/responses/create_response.py +++ b/src/agents-api/agents_api/routers/responses/create_response.py @@ -24,7 +24,7 @@ convert_chat_response_to_response, convert_create_response, ) -from ..sessions.metrics import total_tokens_per_user + from ..sessions.render import render_chat_input from .router import router @@ -315,9 +315,7 @@ async def create_response( choices=[choice.model_dump() for choice in model_response.choices], ) - total_tokens_per_user.labels(str(developer.id)).inc( - amount=chat_response.usage.total_tokens if 
chat_response.usage is not None else 0, - ) + # End chat function return convert_chat_response_to_response( diff --git a/src/agents-api/agents_api/routers/sessions/chat.py b/src/agents-api/agents_api/routers/sessions/chat.py index ba80b3aa2..866982e70 100644 --- a/src/agents-api/agents_api/routers/sessions/chat.py +++ b/src/agents-api/agents_api/routers/sessions/chat.py @@ -4,7 +4,7 @@ from fastapi import BackgroundTasks, Depends, Header from fastapi.responses import StreamingResponse -from litellm.utils import Choices, Message, ModelResponse +from litellm.utils import ModelResponse from starlette.status import HTTP_201_CREATED from uuid_extensions import uuid7 @@ -18,10 +18,9 @@ from ...clients import litellm from ...common.protocol.developers import Developer from ...common.utils.datetime import utcnow -from ...common.utils.usage import track_usage +from ...common.utils.usage_tracker import track_streaming_usage from ...dependencies.developer_id import get_developer_data from ...queries.entries.create_entries import create_entries -from .metrics import total_tokens_per_user from .render import render_chat_input from .router import router @@ -118,30 +117,14 @@ async def stream_chat_response( # Forward the chunk as a proper ChunkChatResponse yield f"data: {chunk_response.model_dump_json()}\n\n" - # Track token usage with Prometheus metrics if available - if usage_data and usage_data.get("total_tokens", 0) > 0: - total_tokens_per_user.labels(str(developer_id)).inc( - amount=usage_data.get("total_tokens", 0) - ) - - # Track usage in database - await track_usage( + # Track usage using centralized tracker + await track_streaming_usage( developer_id=developer_id, model=model, messages=messages or [], - response=ModelResponse( - id=str(response_id), - choices=[ - Choices( - message=Message( - content=choice.get("content", ""), - tool_calls=choice.get("tool_calls"), - ), - ) - for choice in collected_output - ], - usage=usage_data, - ), + usage_data=usage_data, + 
collected_output=collected_output, + response_id=str(response_id), custom_api_used=custom_api_key_used, metadata={ "tags": developer_tags or [], @@ -300,8 +283,4 @@ async def chat( choices=[choice.model_dump() for choice in model_response.choices], ) - total_tokens_per_user.labels(str(developer.id)).inc( - amount=chat_response.usage.total_tokens if chat_response.usage is not None else 0, - ) - return chat_response diff --git a/src/agents-api/agents_api/routers/sessions/metrics.py b/src/agents-api/agents_api/routers/sessions/metrics.py deleted file mode 100644 index 5c432e4e7..000000000 --- a/src/agents-api/agents_api/routers/sessions/metrics.py +++ /dev/null @@ -1,7 +0,0 @@ -from prometheus_client import Counter - -total_tokens_per_user = Counter( - "total_tokens_per_user", - "Total token count per user", - labelnames=("developer_id",), -)