Draft
Changes from 5 commits
71 changes: 69 additions & 2 deletions src/agents/extensions/models/litellm_model.py
@@ -124,6 +124,15 @@ async def get_response(

if hasattr(response, "usage"):
response_usage = response.usage

# Extract cost from LiteLLM's hidden params if cost tracking is enabled.
cost = None
if model_settings.track_cost:
if hasattr(response, "_hidden_params") and isinstance(
response._hidden_params, dict
):
cost = response._hidden_params.get("response_cost")

usage = (
Usage(
requests=1,
@@ -142,6 +151,7 @@ async def get_response(
)
or 0
),
cost=cost,
)
if response.usage
else Usage()
@@ -201,10 +211,67 @@ async def stream_response(

final_response: Response | None = None
async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
yield chunk

# Intercept the response.completed event to calculate and attach cost.
if chunk.type == "response.completed":
final_response = chunk.response
# Calculate cost using LiteLLM's completion_cost function if enabled.
# Streaming responses don't include cost in _hidden_params, so we
# calculate it from the final token counts.
if model_settings.track_cost and final_response.usage:
try:
# Create a mock ModelResponse for cost calculation.
# Include token details (cached, reasoning) for accurate pricing.
from litellm.types.utils import (
Choices as LitellmChoices,
CompletionTokensDetailsWrapper,
Message as LitellmMessage,
ModelResponse as LitellmModelResponse,
PromptTokensDetailsWrapper,
Usage as LitellmUsage,
)

# Extract token details for accurate cost calculation.
cached_tokens = (
final_response.usage.input_tokens_details.cached_tokens
if final_response.usage.input_tokens_details
else 0
)
reasoning_tokens = (
final_response.usage.output_tokens_details.reasoning_tokens
if final_response.usage.output_tokens_details
else 0
)

mock_response = LitellmModelResponse(
choices=[
LitellmChoices(
index=0,
message=LitellmMessage(role="assistant", content=""),
)
],
usage=LitellmUsage(
prompt_tokens=final_response.usage.input_tokens,
completion_tokens=final_response.usage.output_tokens,
total_tokens=final_response.usage.total_tokens,
prompt_tokens_details=PromptTokensDetailsWrapper(
cached_tokens=cached_tokens
),
completion_tokens_details=CompletionTokensDetailsWrapper(
reasoning_tokens=reasoning_tokens
),
),
model=self.model,
)
cost = litellm.completion_cost(completion_response=mock_response) # type: ignore[attr-defined]
# Attach cost as a custom attribute on the Response object so
# run.py can access it when creating the Usage object.
final_response._litellm_cost = cost # type: ignore[attr-defined]
except Exception:
# If cost calculation fails (e.g., unknown model), continue
# without cost.
pass

yield chunk

if tracing.include_data() and final_response:
span_generation.span_data.output = [final_response.model_dump()]
6 changes: 6 additions & 0 deletions src/agents/model_settings.py
@@ -120,6 +120,12 @@ class ModelSettings:
"""Whether to include usage chunk.
Only available for Chat Completions API."""

track_cost: bool = False
"""Whether to track and calculate cost for model calls.
When enabled, the SDK will populate `Usage.cost` with cost estimates.
Currently only supported for LiteLLM models. For other providers, cost will remain None.
Defaults to False."""

# TODO: revisit ResponseIncludable | str if ResponseIncludable covers more cases
# We've added str to support missing ones like
# "web_search_call.action.sources" etc.
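
For reference, a minimal sketch of how a caller might opt in to the new setting and read the resulting cost. The `Agent`/`Runner`/`LitellmModel` wiring below comes from the SDK's public API and is not part of this diff; the model string and API key are placeholders.

```python
import asyncio

from agents import Agent, ModelSettings, Runner
from agents.extensions.models.litellm_model import LitellmModel


async def main() -> None:
    agent = Agent(
        name="Assistant",
        instructions="Reply concisely.",
        # Placeholder LiteLLM model string and key.
        model=LitellmModel(model="anthropic/claude-3-5-sonnet-20240620", api_key="sk-..."),
        model_settings=ModelSettings(track_cost=True),  # opt in; defaults to False
    )
    result = await Runner.run(agent, "Say hello.")
    usage = result.context_wrapper.usage
    # usage.cost is a float (USD) when LiteLLM reports a cost, otherwise None.
    print(usage.total_tokens, usage.cost)


if __name__ == "__main__":
    asyncio.run(main())
```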
4 changes: 4 additions & 0 deletions src/agents/run.py
@@ -1104,6 +1104,9 @@ async def _run_single_turn_streamed(
prompt=prompt_config,
):
if isinstance(event, ResponseCompletedEvent):
# Extract cost if it was attached by LiteLLM model.
cost = getattr(event.response, "_litellm_cost", None)
Member:
Sorry, but I still hesitate to rely on LiteLLM's underscored property.

Contributor Author:
How about storing the cost in a _last_stream_cost attribute on the LitellmModel instance instead, and have run.py extract it from there?


usage = (
Usage(
requests=1,
@@ -1112,6 +1115,7 @@ async def _run_single_turn_streamed(
total_tokens=event.response.usage.total_tokens,
input_tokens_details=event.response.usage.input_tokens_details,
output_tokens_details=event.response.usage.output_tokens_details,
cost=cost,
)
if event.response.usage
else Usage()
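
To make the alternative floated in the thread above concrete, here is a rough, self-contained sketch of keeping the streamed cost on the model instance and reading it back after the stream finishes. The class, attribute name, and wiring are illustrative assumptions, not part of this diff.

```python
from collections.abc import AsyncIterator
from typing import Optional


class StreamingModelSketch:
    """Toy stand-in for LitellmModel illustrating the _last_stream_cost idea."""

    def __init__(self, model: str) -> None:
        self.model = model
        self._last_stream_cost: Optional[float] = None  # assumed attribute name

    async def stream_response(self, prompt: str) -> AsyncIterator[str]:
        # ... yield streamed chunks from the provider ...
        yield "final chunk"
        # After the final chunk, remember the cost on the instance
        # (the real model would call litellm.completion_cost(...) here).
        self._last_stream_cost = 0.00042


async def run_streamed(model: StreamingModelSketch, prompt: str) -> Optional[float]:
    async for _chunk in model.stream_response(prompt):
        pass
    # run.py would read the cost off the model once streaming completes.
    return getattr(model, "_last_stream_cost", None)
```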
9 changes: 9 additions & 0 deletions src/agents/usage.py
@@ -1,4 +1,5 @@
from dataclasses import field
from typing import Optional

from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
from pydantic.dataclasses import dataclass
@@ -27,6 +28,10 @@ class Usage:
total_tokens: int = 0
"""Total tokens sent and received, across all requests."""

cost: Optional[float] = None
"""Total cost in USD for the requests. Only available for certain model providers
(e.g., LiteLLM). Will be None for models that don't provide cost information."""

def add(self, other: "Usage") -> None:
self.requests += other.requests if other.requests else 0
self.input_tokens += other.input_tokens if other.input_tokens else 0
@@ -41,3 +46,7 @@ def add(self, other: "Usage") -> None:
reasoning_tokens=self.output_tokens_details.reasoning_tokens
+ other.output_tokens_details.reasoning_tokens
)

# Aggregate cost if either has a value.
if self.cost is not None or other.cost is not None:
self.cost = (self.cost or 0.0) + (other.cost or 0.0)
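
A small sketch of how the aggregation above behaves across turns, assuming the `Usage` fields shown in this diff; the token counts and cost values are illustrative.

```python
from agents.usage import Usage

turn_1 = Usage(requests=1, input_tokens=100, output_tokens=50, total_tokens=150, cost=0.001)
turn_2 = Usage(requests=1, input_tokens=80, output_tokens=40, total_tokens=120)  # no cost reported

total = Usage()
total.add(turn_1)
total.add(turn_2)
print(total.total_tokens)  # 270
print(total.cost)          # 0.001 -- turns without a cost contribute 0.0
```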
181 changes: 181 additions & 0 deletions tests/models/test_litellm_cost_tracking.py
@@ -0,0 +1,181 @@
"""Tests for LiteLLM cost tracking functionality."""

import litellm
import pytest
from litellm.types.utils import Choices, Message, ModelResponse, Usage

from agents.extensions.models.litellm_model import LitellmModel
from agents.model_settings import ModelSettings
from agents.models.interface import ModelTracing


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_extracted_when_track_cost_enabled(monkeypatch):
"""Test that cost is extracted from LiteLLM response when track_cost=True."""

async def fake_acompletion(model, messages=None, **kwargs):
msg = Message(role="assistant", content="Test response")
choice = Choices(index=0, message=msg)
response = ModelResponse(
choices=[choice],
usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
# Simulate LiteLLM's hidden params with cost.
response._hidden_params = {"response_cost": 0.00042}
return response

monkeypatch.setattr(litellm, "acompletion", fake_acompletion)

model = LitellmModel(model="test-model", api_key="test-key")
result = await model.get_response(
system_instructions=None,
input=[],
model_settings=ModelSettings(track_cost=True), # Enable cost tracking
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
)

# Verify that cost was extracted.
assert result.usage.cost == 0.00042


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_none_when_track_cost_disabled(monkeypatch):
"""Test that cost is None when track_cost=False (default)."""

async def fake_acompletion(model, messages=None, **kwargs):
msg = Message(role="assistant", content="Test response")
choice = Choices(index=0, message=msg)
response = ModelResponse(
choices=[choice],
usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
# Even if LiteLLM provides cost, it should be ignored.
response._hidden_params = {"response_cost": 0.00042}
return response

monkeypatch.setattr(litellm, "acompletion", fake_acompletion)

model = LitellmModel(model="test-model", api_key="test-key")
result = await model.get_response(
system_instructions=None,
input=[],
model_settings=ModelSettings(track_cost=False), # Disabled (default)
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
)

# Verify that cost is None when tracking is disabled.
assert result.usage.cost is None


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_none_when_not_provided(monkeypatch):
"""Test that cost is None when LiteLLM doesn't provide it."""

async def fake_acompletion(model, messages=None, **kwargs):
msg = Message(role="assistant", content="Test response")
choice = Choices(index=0, message=msg)
response = ModelResponse(
choices=[choice],
usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
# No _hidden_params or no cost in hidden params.
return response

monkeypatch.setattr(litellm, "acompletion", fake_acompletion)

model = LitellmModel(model="test-model", api_key="test-key")
result = await model.get_response(
system_instructions=None,
input=[],
model_settings=ModelSettings(track_cost=True),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
)

# Verify that cost is None when not provided.
assert result.usage.cost is None


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_with_empty_hidden_params(monkeypatch):
"""Test that cost extraction handles empty _hidden_params gracefully."""

async def fake_acompletion(model, messages=None, **kwargs):
msg = Message(role="assistant", content="Test response")
choice = Choices(index=0, message=msg)
response = ModelResponse(
choices=[choice],
usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
# Empty hidden params.
response._hidden_params = {}
return response

monkeypatch.setattr(litellm, "acompletion", fake_acompletion)

model = LitellmModel(model="test-model", api_key="test-key")
result = await model.get_response(
system_instructions=None,
input=[],
model_settings=ModelSettings(track_cost=True),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
)

# Verify that cost is None with empty hidden params.
assert result.usage.cost is None


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_extraction_preserves_other_usage_fields(monkeypatch):
"""Test that cost extraction doesn't affect other usage fields."""

async def fake_acompletion(model, messages=None, **kwargs):
msg = Message(role="assistant", content="Test response")
choice = Choices(index=0, message=msg)
response = ModelResponse(
choices=[choice],
usage=Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150),
)
response._hidden_params = {"response_cost": 0.001}
return response

monkeypatch.setattr(litellm, "acompletion", fake_acompletion)

model = LitellmModel(model="test-model", api_key="test-key")
result = await model.get_response(
system_instructions=None,
input=[],
model_settings=ModelSettings(track_cost=True),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
)

# Verify all usage fields are correct.
assert result.usage.input_tokens == 100
assert result.usage.output_tokens == 50
assert result.usage.total_tokens == 150
assert result.usage.cost == 0.001
assert result.usage.requests == 1