Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ async def prompt_step(context: StepContext) -> StepOutcome:
)
passed_settings.update(passed_settings.pop("settings", {}) or {})
passed_settings["user"] = str(context.execution_input.developer_id)
# AIDEV-NOTE: 1472:: Add execution context to passed_settings for usage tracking
passed_settings["execution_id"] = str(context.execution_input.execution.execution_id)
if (
hasattr(context.execution_input.execution, "session_id")
and context.execution_input.execution.session_id
):
passed_settings["session_id"] = str(context.execution_input.execution.session_id)
if (
hasattr(context, "cursor")
and hasattr(context.cursor, "transition_id")
and context.cursor.transition_id
):
passed_settings["transition_id"] = str(context.cursor.transition_id)

if not passed_settings.get("tools"):
passed_settings.pop("tool_choice", None)
Expand Down
5 changes: 5 additions & 0 deletions agents-api/agents_api/clients/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,18 @@ async def acompletion(
if user and isinstance(response, ModelResponse):
try:
model = response.model
# AIDEV-NOTE: 1472:: Extract context fields from kwargs/settings if available and pass to track_usage
await track_usage(
developer_id=UUID(user),
model=model,
messages=messages,
response=response,
custom_api_used=custom_api_key is not None,
metadata={"tags": kwargs.get("tags", [])},
session_id=kwargs.get("session_id"),
execution_id=kwargs.get("execution_id"),
transition_id=kwargs.get("transition_id"),
entry_id=kwargs.get("entry_id"),
)
except Exception as e:
# Log error but don't fail the request if usage tracking fails
Expand Down
75 changes: 75 additions & 0 deletions agents-api/agents_api/common/utils/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,54 @@
from ...queries.usage.create_usage_record import create_usage_record


@beartype
def extract_provider_from_model(model: str) -> str | None:
    """
    Extract the LLM provider from a model name.

    Args:
        model (str): The model name (e.g., "gpt-4", "claude-3-sonnet",
            "openai/gpt-4").

    Returns:
        str | None: The canonical provider name (e.g., "openai",
            "anthropic") or None if the provider cannot be determined.

    NOTE(review): provider detection from model strings is heuristic; with a
    self-hosted LiteLLM proxy the model string may not reflect the real
    provider. Deriving the provider from the configured API key in the
    litellm config may be more reliable — confirm before relying on this.
    """
    # Prefixed models (e.g., "openai/gpt-4"): the prefix is authoritative
    # when it is one we recognize; otherwise fall through to pattern matching.
    if "/" in model:
        provider_prefix = model.split("/")[0].lower()
        # Map known prefixes to canonical provider names.
        provider_mapping = {
            "openai": "openai",
            "anthropic": "anthropic",
            "google": "google",
            "meta-llama": "meta",
            "mistralai": "mistral",
            "openrouter": "openrouter",
        }
        if provider_prefix in provider_mapping:
            return provider_mapping[provider_prefix]

    # Fallback: detect the provider from well-known model-name patterns.
    model_lower = model.lower()

    if any(
        pattern in model_lower
        for pattern in ["gpt-", "o1-", "text-davinci", "text-ada", "text-babbage", "text-curie"]
    ):
        return "openai"
    if any(pattern in model_lower for pattern in ["claude-", "claude"]):
        return "anthropic"
    if any(pattern in model_lower for pattern in ["gemini-", "palm-", "bison", "gecko"]):
        return "google"
    if any(pattern in model_lower for pattern in ["llama-", "llama2", "llama3", "meta-llama"]):
        return "meta"
    if any(pattern in model_lower for pattern in ["mistral-", "mixtral-"]):
        return "mistral"
    if any(pattern in model_lower for pattern in ["qwen-", "qwen2"]):
        return "qwen"

    # Unknown model family.
    return None


@beartype
async def track_usage(
*,
Expand All @@ -21,6 +69,10 @@ async def track_usage(
custom_api_used: bool = False,
metadata: dict[str, Any] = {},
connection_pool: Any = None, # This is for testing purposes
session_id: UUID | None = None,
execution_id: UUID | None = None,
transition_id: UUID | None = None,
entry_id: UUID | None = None,
) -> None:
"""
Tracks token usage and costs for an LLM API call.
Expand All @@ -32,6 +84,10 @@ async def track_usage(
response (ModelResponse): The response from the LLM API call.
custom_api_used (bool): Whether a custom API key was used.
metadata (dict): Additional metadata about the usage.
session_id (UUID | None): The session that generated this usage.
execution_id (UUID | None): The task execution that generated this usage.
transition_id (UUID | None): The specific transition step that generated this usage.
entry_id (UUID | None): The chat entry that generated this usage.

Returns:
None
Expand Down Expand Up @@ -62,7 +118,11 @@ async def track_usage(
# Map the model name to the actual model name
actual_model = model

# Extract provider from model name
provider = extract_provider_from_model(actual_model)

# Create usage record
# AIDEV-NOTE: 1472:: Updated to pass new context fields and provider for better tracking
await create_usage_record(
developer_id=developer_id,
model=actual_model,
Expand All @@ -71,8 +131,23 @@ async def track_usage(
custom_api_used=custom_api_used,
metadata={
"request_id": response.id if hasattr(response, "id") else None,
"input_messages": messages,
"output_content": [
choice.message.content
for choice in response.choices
if hasattr(choice, "message")
and choice.message
and hasattr(choice.message, "content")
]
Comment on lines +134 to +141
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding all messages to the metadata adds a huge unnecessary memory overhead to the usage table.

if response.choices
else None,
**metadata,
},
session_id=session_id,
execution_id=execution_id,
transition_id=transition_id,
entry_id=entry_id,
provider=provider,
connection_pool=connection_pool,
)

Expand Down
31 changes: 29 additions & 2 deletions agents-api/agents_api/queries/usage/create_usage_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
sum(combined_output_costs) / len(combined_output_costs) if combined_output_costs else 0
)

# AIDEV-NOTE: 1472:: Updated query to include new reference fields and provider for better tracking
# Define the raw SQL query
usage_query = """
INSERT INTO usage (
Expand All @@ -99,7 +100,12 @@
cost,
estimated,
custom_api_used,
metadata
metadata,
session_id,
execution_id,
transition_id,
entry_id,
provider
)
VALUES (
$1, -- developer_id
Expand All @@ -109,7 +115,12 @@
$5, -- cost
$6, -- estimated
$7, -- custom_api_used
$8 -- metadata
$8, -- metadata
$9, -- session_id
$10, -- execution_id
$11, -- transition_id
$12, -- entry_id
$13 -- provider
)
RETURNING *;
"""
Expand All @@ -128,6 +139,11 @@ async def create_usage_record(
custom_api_used: bool = False,
estimated: bool = False,
metadata: dict[str, Any] | None = None,
session_id: UUID | None = None,
execution_id: UUID | None = None,
transition_id: UUID | None = None,
entry_id: UUID | None = None,
provider: str | None = None,
Comment on lines 139 to +146
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a lot of variables, better to create a model in typespec to clean up the parameters.

) -> tuple[str, list]:
"""
Creates a usage record to track token usage and costs.
Expand All @@ -140,6 +156,11 @@ async def create_usage_record(
custom_api_used (bool): Whether a custom API key was used.
estimated (bool): Whether the token count is estimated.
metadata (dict | None): Additional metadata about the usage.
session_id (UUID | None): The session that generated this usage.
execution_id (UUID | None): The task execution that generated this usage.
transition_id (UUID | None): The specific transition step that generated this usage.
entry_id (UUID | None): The chat entry that generated this usage.
provider (str | None): The actual LLM provider used (e.g., openai, anthropic, google).

Returns:
tuple[str, list]: SQL query and parameters for creating the usage record.
Expand Down Expand Up @@ -168,6 +189,7 @@ async def create_usage_record(
)
print(f"No fallback pricing found for model {model}, using avg costs: {total_cost}")

# AIDEV-NOTE: 1472:: Updated to include new reference fields and provider in params list
params = [
developer_id,
model,
Expand All @@ -177,6 +199,11 @@ async def create_usage_record(
estimated,
custom_api_used,
metadata or {},
session_id,
execution_id,
transition_id,
entry_id,
provider,
]

return (
Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ async def stream_chat_response(
)

# Track usage in database
# AIDEV-NOTE: 1472:: Updated to pass session_id for better usage tracking
await track_usage(
developer_id=developer_id,
model=model,
Expand All @@ -147,6 +148,7 @@ async def stream_chat_response(
"tags": developer_tags or [],
"streaming": True,
},
session_id=session_id,
connection_pool=connection_pool,
)

Expand Down Expand Up @@ -214,12 +216,14 @@ async def chat(
)

# Prepare parameters for LiteLLM
# AIDEV-NOTE: 1472:: session_id added to params for usage tracking analytics
params = {
"messages": messages,
"tools": formatted_tools or None,
"user": str(developer.id),
"tags": developer.tags,
"custom_api_key": x_custom_api_key,
"session_id": session_id,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

entry_id isn't being passed.

}

# Set streaming parameter based on chat_input.stream
Expand Down
21 changes: 21 additions & 0 deletions memory-store/migrations/000042_usage_enhancement.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
-- AIDEV-NOTE: 1472:: Rollback migration to remove usage enhancement columns
-- This migration removes the reference fields and provider column added in the up migration

BEGIN;

-- Drop indexes first; they depend on the columns removed below
DROP INDEX IF EXISTS idx_usage_session_created;
DROP INDEX IF EXISTS idx_usage_execution_created;
DROP INDEX IF EXISTS idx_usage_provider;

-- AIDEV-NOTE: 1472:: No foreign key constraints to drop (hypertables can't have FK constraints to other hypertables)

-- Drop columns (IF EXISTS keeps the rollback idempotent)
ALTER TABLE usage
DROP COLUMN IF EXISTS session_id,
DROP COLUMN IF EXISTS execution_id,
DROP COLUMN IF EXISTS transition_id,
DROP COLUMN IF EXISTS entry_id,
DROP COLUMN IF EXISTS provider;

COMMIT;
38 changes: 38 additions & 0 deletions memory-store/migrations/000042_usage_enhancement.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
-- AIDEV-NOTE: 1472:: Migration to add reference fields and provider to usage table
-- This migration adds context fields to track which session/execution/transition/entry
-- generated the usage, along with the actual provider used

BEGIN;

-- Add new columns to usage table (all nullable: existing rows have no context)
ALTER TABLE usage
ADD COLUMN session_id UUID NULL,
ADD COLUMN execution_id UUID NULL,
ADD COLUMN transition_id UUID NULL,
ADD COLUMN entry_id UUID NULL,
ADD COLUMN provider TEXT NULL;

-- AIDEV-NOTE: 1472:: TimescaleDB hypertables cannot have foreign key constraints to other hypertables
-- We rely on application-level referential integrity for these fields

-- Create partial indexes for efficient querying; the WHERE clauses keep the
-- indexes small by excluding rows that carry no context reference
CREATE INDEX idx_usage_session_created
ON usage(developer_id, session_id, created_at DESC)
WHERE session_id IS NOT NULL;

CREATE INDEX idx_usage_execution_created
ON usage(developer_id, execution_id, created_at DESC)
WHERE execution_id IS NOT NULL;

CREATE INDEX idx_usage_provider
ON usage(provider)
WHERE provider IS NOT NULL;

-- Add comments to explain the purpose of new fields
COMMENT ON COLUMN usage.session_id IS 'Reference to the session that generated this usage';
COMMENT ON COLUMN usage.execution_id IS 'Reference to the task execution that generated this usage';
COMMENT ON COLUMN usage.transition_id IS 'Reference to the specific transition step that generated this usage';
COMMENT ON COLUMN usage.entry_id IS 'Reference to the chat entry that generated this usage';
COMMENT ON COLUMN usage.provider IS 'The actual LLM provider used (e.g., openai, anthropic, google)';

COMMIT;
1 change: 1 addition & 0 deletions typespec/chat/models.tsp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ model InputChatSettings extends InputDefaultChatSettings {
}

/** Usage statistics for the completion request */
// AIDEV-NOTE: 1472:: This model is for API responses. Internal usage tracking with context fields is handled at the database level
model CompletionUsage {
/** Number of tokens in the generated completion */
@visibility("read")
Expand Down
Loading