73 changes: 71 additions & 2 deletions src/agents/extensions/models/litellm_model.py
@@ -124,6 +124,17 @@ async def get_response(

if hasattr(response, "usage"):
response_usage = response.usage

# Calculate cost using LiteLLM's completion_cost function if cost tracking is enabled. # noqa: E501
cost = None
if model_settings.track_cost:
try:
# Use LiteLLM's public API to calculate cost from the response.
cost = litellm.completion_cost(completion_response=response) # type: ignore[attr-defined]
except Exception:
# If cost calculation fails (e.g., unknown model), continue without cost.
pass

usage = (
Usage(
requests=1,
@@ -142,6 +153,7 @@
)
or 0
),
cost=cost,
)
if response.usage
else Usage()
@@ -201,10 +213,67 @@ async def stream_response(

final_response: Response | None = None
async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
yield chunk

# Intercept the response.completed event to calculate and attach cost.
if chunk.type == "response.completed":
final_response = chunk.response
# Calculate cost using LiteLLM's completion_cost function if enabled.
# Streaming responses don't include cost in _hidden_params, so we
# calculate it from the final token counts.
if model_settings.track_cost and final_response.usage:
try:
# Create a mock ModelResponse for cost calculation.
# Include token details (cached, reasoning) for accurate pricing.
from litellm.types.utils import (
Choices as LitellmChoices,
CompletionTokensDetailsWrapper,
Message as LitellmMessage,
ModelResponse as LitellmModelResponse,
PromptTokensDetailsWrapper,
Usage as LitellmUsage,
)

# Extract token details for accurate cost calculation.
cached_tokens = (
final_response.usage.input_tokens_details.cached_tokens
if final_response.usage.input_tokens_details
else 0
)
reasoning_tokens = (
final_response.usage.output_tokens_details.reasoning_tokens
if final_response.usage.output_tokens_details
else 0
)

mock_response = LitellmModelResponse(
choices=[
LitellmChoices(
index=0,
message=LitellmMessage(role="assistant", content=""),
)
],
usage=LitellmUsage(
prompt_tokens=final_response.usage.input_tokens,
completion_tokens=final_response.usage.output_tokens,
total_tokens=final_response.usage.total_tokens,
prompt_tokens_details=PromptTokensDetailsWrapper(
cached_tokens=cached_tokens
),
completion_tokens_details=CompletionTokensDetailsWrapper(
reasoning_tokens=reasoning_tokens
),
),
model=self.model,
)
cost = litellm.completion_cost(completion_response=mock_response) # type: ignore[attr-defined]
# Attach cost as a custom attribute on the Response object so
# run.py can access it when creating the Usage object.
final_response._litellm_cost = cost # type: ignore[attr-defined]
except Exception:
# If cost calculation fails (e.g., unknown model), continue
# without cost.
pass

yield chunk

if tracing.include_data() and final_response:
span_generation.span_data.output = [final_response.model_dump()]
6 changes: 6 additions & 0 deletions src/agents/model_settings.py
@@ -120,6 +120,12 @@ class ModelSettings:
"""Whether to include usage chunk.
Only available for Chat Completions API."""

track_cost: bool = False
"""Whether to track and calculate cost for model calls.
When enabled, the SDK will populate `Usage.cost` with cost estimates.
Currently only supported for LiteLLM models. For other providers, cost will remain None.
Defaults to False."""

# TODO: revisit ResponseIncludable | str if ResponseIncludable covers more cases
# We've added str to support missing ones like
# "web_search_call.action.sources" etc.
4 changes: 4 additions & 0 deletions src/agents/run.py
@@ -1104,6 +1104,9 @@ async def _run_single_turn_streamed(
prompt=prompt_config,
):
if isinstance(event, ResponseCompletedEvent):
# Extract cost if it was attached by LiteLLM model.
cost = getattr(event.response, "_litellm_cost", None)
Member:

Sorry, but I still hesitate to rely on LiteLLM's underscored property.

Contributor Author:

How about storing the cost in a _last_stream_cost attribute on the LitellmModel instance instead, and having run.py extract it from there?
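
A minimal sketch of that alternative, assuming a `_last_stream_cost` attribute set by the model and that the model instance (`model` below) is in scope where run.py builds the `Usage` — neither is part of this diff:

```python
# Hypothetical sketch, not part of this PR.
# Inside LitellmModel.stream_response, after the cost is computed for the
# final chunk, the model would store it on itself:
#     self._last_stream_cost = cost
#
# run.py could then read it from the model instance instead of relying on a
# private attribute attached to the Response object:
cost = getattr(model, "_last_stream_cost", None)  # None if the model never set it

usage = (
    Usage(
        requests=1,
        input_tokens=event.response.usage.input_tokens,
        output_tokens=event.response.usage.output_tokens,
        total_tokens=event.response.usage.total_tokens,
        input_tokens_details=event.response.usage.input_tokens_details,
        output_tokens_details=event.response.usage.output_tokens_details,
        cost=cost,
    )
    if event.response.usage
    else Usage()
)
```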


usage = (
Usage(
requests=1,
@@ -1112,6 +1115,7 @@
total_tokens=event.response.usage.total_tokens,
input_tokens_details=event.response.usage.input_tokens_details,
output_tokens_details=event.response.usage.output_tokens_details,
cost=cost,
)
if event.response.usage
else Usage()
9 changes: 9 additions & 0 deletions src/agents/usage.py
@@ -1,4 +1,5 @@
from dataclasses import field
from typing import Optional

from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
from pydantic.dataclasses import dataclass
@@ -27,6 +28,10 @@ class Usage:
total_tokens: int = 0
"""Total tokens sent and received, across all requests."""

cost: Optional[float] = None
"""Total cost in USD for the requests. Only available for certain model providers
(e.g., LiteLLM). Will be None for models that don't provide cost information."""

def add(self, other: "Usage") -> None:
self.requests += other.requests if other.requests else 0
self.input_tokens += other.input_tokens if other.input_tokens else 0
@@ -41,3 +46,7 @@ def add(self, other: "Usage") -> None:
reasoning_tokens=self.output_tokens_details.reasoning_tokens
+ other.output_tokens_details.reasoning_tokens
)

# Aggregate cost if either has a value.
if self.cost is not None or other.cost is not None:
self.cost = (self.cost or 0.0) + (other.cost or 0.0)
194 changes: 194 additions & 0 deletions tests/models/test_litellm_cost_tracking.py
@@ -0,0 +1,194 @@
"""Tests for LiteLLM cost tracking functionality."""

import litellm
import pytest
from litellm.types.utils import Choices, Message, ModelResponse, Usage

from agents.extensions.models.litellm_model import LitellmModel
from agents.model_settings import ModelSettings
from agents.models.interface import ModelTracing


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_extracted_when_track_cost_enabled(monkeypatch):
"""Test that cost is calculated using LiteLLM's completion_cost API when track_cost=True."""

async def fake_acompletion(model, messages=None, **kwargs):
msg = Message(role="assistant", content="Test response")
choice = Choices(index=0, message=msg)
response = ModelResponse(
choices=[choice],
usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
return response

def fake_completion_cost(completion_response):
"""Mock litellm.completion_cost to return a test cost value."""
return 0.00042

monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)

model = LitellmModel(model="test-model", api_key="test-key")
result = await model.get_response(
system_instructions=None,
input=[],
model_settings=ModelSettings(track_cost=True), # Enable cost tracking
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
)

# Verify that cost was calculated.
assert result.usage.cost == 0.00042


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_none_when_track_cost_disabled(monkeypatch):
"""Test that cost is None when track_cost=False (default)."""

async def fake_acompletion(model, messages=None, **kwargs):
msg = Message(role="assistant", content="Test response")
choice = Choices(index=0, message=msg)
response = ModelResponse(
choices=[choice],
usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
return response

monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
# Note: completion_cost should not be called when track_cost=False

model = LitellmModel(model="test-model", api_key="test-key")
result = await model.get_response(
system_instructions=None,
input=[],
model_settings=ModelSettings(track_cost=False), # Disabled (default)
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
)

# Verify that cost is None when tracking is disabled.
assert result.usage.cost is None


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_none_when_not_provided(monkeypatch):
"""Test that cost is None when completion_cost raises an exception."""

async def fake_acompletion(model, messages=None, **kwargs):
msg = Message(role="assistant", content="Test response")
choice = Choices(index=0, message=msg)
response = ModelResponse(
choices=[choice],
usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
return response

def fake_completion_cost(completion_response):
"""Mock completion_cost to raise an exception (e.g., unknown model)."""
raise Exception("Model not found in pricing database")

monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)

model = LitellmModel(model="test-model", api_key="test-key")
result = await model.get_response(
system_instructions=None,
input=[],
model_settings=ModelSettings(track_cost=True),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
)

# Verify that cost is None when completion_cost fails.
assert result.usage.cost is None


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_zero_when_completion_cost_returns_zero(monkeypatch):
"""Test that cost is 0 when completion_cost returns 0 (e.g., free model)."""

async def fake_acompletion(model, messages=None, **kwargs):
msg = Message(role="assistant", content="Test response")
choice = Choices(index=0, message=msg)
response = ModelResponse(
choices=[choice],
usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
return response

def fake_completion_cost(completion_response):
"""Mock completion_cost to return 0 (e.g., free model)."""
return 0.0

monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)

model = LitellmModel(model="test-model", api_key="test-key")
result = await model.get_response(
system_instructions=None,
input=[],
model_settings=ModelSettings(track_cost=True),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
)

# Verify that cost is 0 for free models.
assert result.usage.cost == 0.0


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_extraction_preserves_other_usage_fields(monkeypatch):
"""Test that cost calculation doesn't affect other usage fields."""

async def fake_acompletion(model, messages=None, **kwargs):
msg = Message(role="assistant", content="Test response")
choice = Choices(index=0, message=msg)
response = ModelResponse(
choices=[choice],
usage=Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150),
)
return response

def fake_completion_cost(completion_response):
"""Mock litellm.completion_cost to return a test cost value."""
return 0.001

monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)

model = LitellmModel(model="test-model", api_key="test-key")
result = await model.get_response(
system_instructions=None,
input=[],
model_settings=ModelSettings(track_cost=True),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
)

# Verify all usage fields are correct.
assert result.usage.input_tokens == 100
assert result.usage.output_tokens == 50
assert result.usage.total_tokens == 150
assert result.usage.cost == 0.001
assert result.usage.requests == 1