diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py
index a4c8da3ab..43b85a8bf 100644
--- a/src/agents/extensions/models/litellm_model.py
+++ b/src/agents/extensions/models/litellm_model.py
@@ -124,6 +124,17 @@ async def get_response(
 
             if hasattr(response, "usage"):
                 response_usage = response.usage
+
+                # Calculate cost using LiteLLM's completion_cost function if cost tracking is enabled.  # noqa: E501
+                cost = None
+                if model_settings.track_cost:
+                    try:
+                        # Use LiteLLM's public API to calculate cost from the response.
+                        cost = litellm.completion_cost(completion_response=response)  # type: ignore[attr-defined]
+                    except Exception:
+                        # If cost calculation fails (e.g., unknown model), continue without cost.
+                        pass
+
                 usage = (
                     Usage(
                         requests=1,
@@ -142,6 +153,7 @@ async def get_response(
                             )
                             or 0
                         ),
+                        cost=cost,
                     )
                     if response.usage
                     else Usage()
@@ -201,10 +213,67 @@ async def stream_response(
 
             final_response: Response | None = None
             async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
-                yield chunk
-
+                # Intercept the response.completed event to calculate and attach cost.
                 if chunk.type == "response.completed":
                     final_response = chunk.response
+                    # Calculate cost using LiteLLM's completion_cost function if enabled.
+                    # Streaming responses don't include cost in _hidden_params, so we
+                    # calculate it from the final token counts.
+                    if model_settings.track_cost and final_response.usage:
+                        try:
+                            # Create a mock ModelResponse for cost calculation.
+                            # Include token details (cached, reasoning) for accurate pricing.
+                            from litellm.types.utils import (
+                                Choices as LitellmChoices,
+                                CompletionTokensDetailsWrapper,
+                                Message as LitellmMessage,
+                                ModelResponse as LitellmModelResponse,
+                                PromptTokensDetailsWrapper,
+                                Usage as LitellmUsage,
+                            )
+
+                            # Extract token details for accurate cost calculation.
+                            cached_tokens = (
+                                final_response.usage.input_tokens_details.cached_tokens
+                                if final_response.usage.input_tokens_details
+                                else 0
+                            )
+                            reasoning_tokens = (
+                                final_response.usage.output_tokens_details.reasoning_tokens
+                                if final_response.usage.output_tokens_details
+                                else 0
+                            )
+
+                            mock_response = LitellmModelResponse(
+                                choices=[
+                                    LitellmChoices(
+                                        index=0,
+                                        message=LitellmMessage(role="assistant", content=""),
+                                    )
+                                ],
+                                usage=LitellmUsage(
+                                    prompt_tokens=final_response.usage.input_tokens,
+                                    completion_tokens=final_response.usage.output_tokens,
+                                    total_tokens=final_response.usage.total_tokens,
+                                    prompt_tokens_details=PromptTokensDetailsWrapper(
+                                        cached_tokens=cached_tokens
+                                    ),
+                                    completion_tokens_details=CompletionTokensDetailsWrapper(
+                                        reasoning_tokens=reasoning_tokens
+                                    ),
+                                ),
+                                model=self.model,
+                            )
+                            cost = litellm.completion_cost(completion_response=mock_response)  # type: ignore[attr-defined]
+                            # Attach cost as a custom attribute on the Response object so
+                            # run.py can access it when creating the Usage object.
+                            final_response._litellm_cost = cost  # type: ignore[attr-defined]
+                        except Exception:
+                            # If cost calculation fails (e.g., unknown model), continue
+                            # without cost.
+                            pass
+
+                yield chunk
 
             if tracing.include_data() and final_response:
                 span_generation.span_data.output = [final_response.model_dump()]
diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py
index 6a3dbd04c..332a7b850 100644
--- a/src/agents/model_settings.py
+++ b/src/agents/model_settings.py
@@ -120,6 +120,12 @@ class ModelSettings:
     """Whether to include usage chunk.
     Only available for Chat Completions API."""
 
+    track_cost: bool | None = None
+    """Whether to track and calculate cost for model calls.
+    When enabled, the SDK will populate `Usage.cost` with cost estimates.
+    Currently only supported for LiteLLM models. For other providers, cost will remain None.
+    If not provided (i.e., set to None), cost tracking is disabled (defaults to False)."""
+
     # TODO: revisit ResponseIncludable | str if ResponseIncludable covers more cases
     # We've added str to support missing ones like
     # "web_search_call.action.sources" etc.
diff --git a/src/agents/run.py b/src/agents/run.py
index 52d395a13..5b4f5e558 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -1179,6 +1179,9 @@ async def _run_single_turn_streamed(
             prompt=prompt_config,
         ):
             if isinstance(event, ResponseCompletedEvent):
+                # Extract cost if it was attached by LiteLLM model.
+                cost = getattr(event.response, "_litellm_cost", None)
+
                 usage = (
                     Usage(
                         requests=1,
@@ -1187,6 +1190,7 @@ async def _run_single_turn_streamed(
                         total_tokens=event.response.usage.total_tokens,
                        input_tokens_details=event.response.usage.input_tokens_details,
                        output_tokens_details=event.response.usage.output_tokens_details,
+                        cost=cost,
                     )
                     if event.response.usage
                     else Usage()
diff --git a/src/agents/usage.py b/src/agents/usage.py
index 3639cf944..d843779ee 100644
--- a/src/agents/usage.py
+++ b/src/agents/usage.py
@@ -1,4 +1,5 @@
 from dataclasses import field
+from typing import Optional
 
 from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
 from pydantic.dataclasses import dataclass
@@ -27,6 +28,10 @@ class Usage:
     total_tokens: int = 0
     """Total tokens sent and received, across all requests."""
 
+    cost: Optional[float] = None
+    """Total cost in USD for the requests. Only available for certain model providers
+    (e.g., LiteLLM). Will be None for models that don't provide cost information."""
+
     def add(self, other: "Usage") -> None:
         self.requests += other.requests if other.requests else 0
         self.input_tokens += other.input_tokens if other.input_tokens else 0
@@ -41,3 +46,7 @@ def add(self, other: "Usage") -> None:
             reasoning_tokens=self.output_tokens_details.reasoning_tokens
             + other.output_tokens_details.reasoning_tokens
         )
+
+        # Aggregate cost if either has a value.
+        if self.cost is not None or other.cost is not None:
+            self.cost = (self.cost or 0.0) + (other.cost or 0.0)
diff --git a/tests/model_settings/test_serialization.py b/tests/model_settings/test_serialization.py
index f099a1a31..5b4bcbbb6 100644
--- a/tests/model_settings/test_serialization.py
+++ b/tests/model_settings/test_serialization.py
@@ -57,6 +57,7 @@ def test_all_fields_serialization() -> None:
         metadata={"foo": "bar"},
         store=False,
         include_usage=False,
+        track_cost=False,
         response_include=["reasoning.encrypted_content"],
         top_logprobs=1,
         verbosity="low",
diff --git a/tests/models/test_litellm_cost_tracking.py b/tests/models/test_litellm_cost_tracking.py
new file mode 100644
index 000000000..7f9b5f50f
--- /dev/null
+++ b/tests/models/test_litellm_cost_tracking.py
@@ -0,0 +1,220 @@
+"""Tests for LiteLLM cost tracking functionality."""
+
+import litellm
+import pytest
+from litellm.types.utils import Choices, Message, ModelResponse, Usage
+
+from agents.extensions.models.litellm_model import LitellmModel
+from agents.model_settings import ModelSettings
+from agents.models.interface import ModelTracing
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_cost_extracted_when_track_cost_enabled(monkeypatch):
+    """Test that cost is calculated using LiteLLM's completion_cost API when track_cost=True."""
+
+    async def fake_acompletion(model, messages=None, **kwargs):
+        msg = Message(role="assistant", content="Test response")
+        choice = Choices(index=0, message=msg)
+        response = ModelResponse(
+            choices=[choice],
+            usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
+        )
+        return response
+
+    def fake_completion_cost(completion_response):
+        """Mock litellm.completion_cost to return a test cost value."""
+        return 0.00042
+
+    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
+    monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)
+
+    model = LitellmModel(model="test-model", api_key="test-key")
+    result = await model.get_response(
+        system_instructions=None,
+        input=[],
+        model_settings=ModelSettings(track_cost=True),  # Enable cost tracking
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    )
+
+    # Verify that cost was calculated.
+    assert result.usage.cost == 0.00042
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_cost_none_when_track_cost_disabled(monkeypatch):
+    """Test that cost is None when track_cost is not set (defaults to None/False)."""
+
+    async def fake_acompletion(model, messages=None, **kwargs):
+        msg = Message(role="assistant", content="Test response")
+        choice = Choices(index=0, message=msg)
+        response = ModelResponse(
+            choices=[choice],
+            usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
+        )
+        return response
+
+    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
+    # Note: completion_cost should not be called when track_cost is None (default)
+
+    model = LitellmModel(model="test-model", api_key="test-key")
+    result = await model.get_response(
+        system_instructions=None,
+        input=[],
+        model_settings=ModelSettings(),  # track_cost defaults to None (disabled)
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    )
+
+    # Verify that cost is None when tracking is disabled.
+    assert result.usage.cost is None
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_cost_none_when_not_provided(monkeypatch):
+    """Test that cost is None when completion_cost raises an exception."""
+
+    async def fake_acompletion(model, messages=None, **kwargs):
+        msg = Message(role="assistant", content="Test response")
+        choice = Choices(index=0, message=msg)
+        response = ModelResponse(
+            choices=[choice],
+            usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
+        )
+        return response
+
+    def fake_completion_cost(completion_response):
+        """Mock completion_cost to raise an exception (e.g., unknown model)."""
+        raise Exception("Model not found in pricing database")
+
+    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
+    monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)
+
+    model = LitellmModel(model="test-model", api_key="test-key")
+    result = await model.get_response(
+        system_instructions=None,
+        input=[],
+        model_settings=ModelSettings(track_cost=True),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    )
+
+    # Verify that cost is None when completion_cost fails.
+    assert result.usage.cost is None
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_cost_zero_when_completion_cost_returns_zero(monkeypatch):
+    """Test that cost is 0 when completion_cost returns 0 (e.g., free model)."""
+
+    async def fake_acompletion(model, messages=None, **kwargs):
+        msg = Message(role="assistant", content="Test response")
+        choice = Choices(index=0, message=msg)
+        response = ModelResponse(
+            choices=[choice],
+            usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
+        )
+        return response
+
+    def fake_completion_cost(completion_response):
+        """Mock completion_cost to return 0 (e.g., free model)."""
+        return 0.0
+
+    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
+    monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)
+
+    model = LitellmModel(model="test-model", api_key="test-key")
+    result = await model.get_response(
+        system_instructions=None,
+        input=[],
+        model_settings=ModelSettings(track_cost=True),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    )
+
+    # Verify that cost is 0 for free models.
+    assert result.usage.cost == 0.0
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_cost_extraction_preserves_other_usage_fields(monkeypatch):
+    """Test that cost calculation doesn't affect other usage fields."""
+
+    async def fake_acompletion(model, messages=None, **kwargs):
+        msg = Message(role="assistant", content="Test response")
+        choice = Choices(index=0, message=msg)
+        response = ModelResponse(
+            choices=[choice],
+            usage=Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150),
+        )
+        return response
+
+    def fake_completion_cost(completion_response):
+        """Mock litellm.completion_cost to return a test cost value."""
+        return 0.001
+
+    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
+    monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)
+
+    model = LitellmModel(model="test-model", api_key="test-key")
+    result = await model.get_response(
+        system_instructions=None,
+        input=[],
+        model_settings=ModelSettings(track_cost=True),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    )
+
+    # Verify all usage fields are correct.
+    assert result.usage.input_tokens == 100
+    assert result.usage.output_tokens == 50
+    assert result.usage.total_tokens == 150
+    assert result.usage.cost == 0.001
+    assert result.usage.requests == 1
+
+
+def test_track_cost_sticky_through_resolve():
+    """Test that track_cost=True is not overwritten by resolve() with empty override."""
+    base = ModelSettings(track_cost=True, temperature=0.7)
+    override = ModelSettings(max_tokens=100)  # Only setting max_tokens, track_cost is None
+
+    resolved = base.resolve(override)
+
+    # track_cost should remain True because override's track_cost is None (not False)
+    assert resolved.track_cost is True
+    assert resolved.temperature == 0.7
+    assert resolved.max_tokens == 100
+
+
+def test_track_cost_can_be_explicitly_disabled():
+    """Test that track_cost=True can be explicitly overridden to False."""
+    base = ModelSettings(track_cost=True, temperature=0.7)
+    override = ModelSettings(track_cost=False, max_tokens=100)
+
+    resolved = base.resolve(override)
+
+    # track_cost should be False because override explicitly set it to False
+    assert resolved.track_cost is False
+    assert resolved.temperature == 0.7
+    assert resolved.max_tokens == 100
diff --git a/tests/test_cost_in_run.py b/tests/test_cost_in_run.py
new file mode 100644
index 000000000..9dce253fb
--- /dev/null
+++ b/tests/test_cost_in_run.py
@@ -0,0 +1,105 @@
+"""Test cost extraction in run.py for streaming responses."""
+
+from openai.types.responses import Response, ResponseOutputMessage, ResponseUsage
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
+
+from agents.usage import Usage
+
+
+def test_usage_extracts_cost_from_litellm_attribute():
+    """Test that Usage extracts cost from Response._litellm_cost attribute."""
+    # Simulate a Response object with _litellm_cost attached (as done by LitellmModel)
+    response = Response(
+        id="test-id",
+        created_at=123456,
+        model="test-model",
+        object="response",
+        output=[
+            ResponseOutputMessage(
+                id="msg-1",
+                role="assistant",
+                type="message",
+                content=[],
+                status="completed",
+            )
+        ],
+        usage=ResponseUsage(
+            input_tokens=100,
+            output_tokens=50,
+            total_tokens=150,
+            input_tokens_details=InputTokensDetails(cached_tokens=10),
+            output_tokens_details=OutputTokensDetails(reasoning_tokens=5),
+        ),
+        tool_choice="auto",
+        parallel_tool_calls=False,
+        tools=[],
+    )
+
+    # Attach cost as LitellmModel does
+    response._litellm_cost = 0.00123  # type: ignore
+
+    # Simulate what run.py does in ResponseCompletedEvent handling
+    cost = getattr(response, "_litellm_cost", None)
+
+    assert response.usage is not None
+    usage = Usage(
+        requests=1,
+        input_tokens=response.usage.input_tokens,
+        output_tokens=response.usage.output_tokens,
+        total_tokens=response.usage.total_tokens,
+        input_tokens_details=response.usage.input_tokens_details,
+        output_tokens_details=response.usage.output_tokens_details,
+        cost=cost,
+    )
+
+    # Verify cost was extracted
+    assert usage.cost == 0.00123
+    assert usage.input_tokens == 100
+    assert usage.output_tokens == 50
+
+
+def test_usage_cost_none_when_attribute_missing():
+    """Test that Usage.cost is None when _litellm_cost attribute is missing."""
+    # Response without _litellm_cost attribute (normal OpenAI response)
+    response = Response(
+        id="test-id",
+        created_at=123456,
+        model="test-model",
+        object="response",
+        output=[
+            ResponseOutputMessage(
+                id="msg-1",
+                role="assistant",
+                type="message",
+                content=[],
+                status="completed",
+            )
+        ],
+        usage=ResponseUsage(
+            input_tokens=100,
+            output_tokens=50,
+            total_tokens=150,
+            input_tokens_details=InputTokensDetails(cached_tokens=0),
+            output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
+        ),
+        tool_choice="auto",
+        parallel_tool_calls=False,
+        tools=[],
+    )
+
+    # Simulate what run.py does
+    cost = getattr(response, "_litellm_cost", None)
+
+    assert response.usage is not None
+    usage = Usage(
+        requests=1,
+        input_tokens=response.usage.input_tokens,
+        output_tokens=response.usage.output_tokens,
+        total_tokens=response.usage.total_tokens,
+        input_tokens_details=response.usage.input_tokens_details,
+        output_tokens_details=response.usage.output_tokens_details,
+        cost=cost,
+    )
+
+    # Verify cost is None
+    assert usage.cost is None
diff --git a/tests/test_usage.py b/tests/test_usage.py
index 405f99ddf..4b6e74528 100644
--- a/tests/test_usage.py
+++ b/tests/test_usage.py
@@ -50,3 +50,52 @@ def test_usage_add_aggregates_with_none_values():
     assert u1.total_tokens == 15
     assert u1.input_tokens_details.cached_tokens == 4
     assert u1.output_tokens_details.reasoning_tokens == 6
+
+
+def test_usage_cost_defaults_to_none():
+    """Test that cost field defaults to None."""
+    usage = Usage()
+    assert usage.cost is None
+
+
+def test_usage_add_with_cost():
+    """Test that cost is aggregated correctly when both usages have cost."""
+    u1 = Usage(requests=1, input_tokens=10, output_tokens=20, total_tokens=30, cost=0.001)
+    u2 = Usage(requests=1, input_tokens=15, output_tokens=25, total_tokens=40, cost=0.002)
+
+    u1.add(u2)
+
+    assert u1.cost == 0.003
+    assert u1.requests == 2
+    assert u1.total_tokens == 70
+
+
+def test_usage_add_with_partial_cost():
+    """Test that cost is preserved when only one usage has cost."""
+    u1 = Usage(requests=1, input_tokens=10, output_tokens=20, total_tokens=30, cost=0.001)
+    u2 = Usage(requests=1, input_tokens=15, output_tokens=25, total_tokens=40)  # no cost
+
+    u1.add(u2)
+
+    assert u1.cost == 0.001
+    assert u1.requests == 2
+
+
+def test_usage_add_with_cost_none_plus_value():
+    """Test that cost aggregation works when first usage has no cost."""
+    u1 = Usage(requests=1, input_tokens=10, output_tokens=20, total_tokens=30)
+    u2 = Usage(requests=1, input_tokens=15, output_tokens=25, total_tokens=40, cost=0.002)
+
+    u1.add(u2)
+
+    assert u1.cost == 0.002
+
+
+def test_usage_add_with_both_cost_none():
+    """Test that cost remains None when neither usage has cost."""
+    u1 = Usage(requests=1, input_tokens=10, output_tokens=20, total_tokens=30)
+    u2 = Usage(requests=1, input_tokens=15, output_tokens=25, total_tokens=40)
+
+    u1.add(u2)
+
+    assert u1.cost is None
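
Usage sketch (not part of the change set): assuming the SDK's public Agent/Runner entry points and a LiteLLM-routed model, enabling track_cost on ModelSettings should surface the estimated cost on the aggregated Usage; the model name and API key below are placeholders.

# Sketch only: exercises ModelSettings.track_cost as introduced in this diff.
# Assumes the usual Agent/Runner API; the model string and api_key are placeholders.
import asyncio

from agents import Agent, ModelSettings, Runner
from agents.extensions.models.litellm_model import LitellmModel


async def main() -> None:
    agent = Agent(
        name="Assistant",
        instructions="Reply concisely.",
        model=LitellmModel(model="anthropic/claude-3-5-sonnet-20240620", api_key="sk-..."),
        # With track_cost=True, LitellmModel populates Usage.cost via litellm.completion_cost.
        model_settings=ModelSettings(track_cost=True),
    )
    result = await Runner.run(agent, "Say hello.")

    usage = result.context_wrapper.usage
    # cost stays None when tracking is disabled or the model has no pricing info.
    print(f"total_tokens={usage.total_tokens} cost_usd={usage.cost}")


if __name__ == "__main__":
    asyncio.run(main())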