feat: #1831 Add opt-in cost tracking for LiteLLM models #1832
@@ -1179,6 +1179,9 @@ async def _run_single_turn_streamed(
                     prompt=prompt_config,
                 ):
                     if isinstance(event, ResponseCompletedEvent):
+                        # Extract cost if it was attached by LiteLLM model.
+                        cost = getattr(event.response, "_litellm_cost", None)
+
                         usage = (
                             Usage(
                                 requests=1,
@@ -1187,6 +1190,7 @@ async def _run_single_turn_streamed(
                                 total_tokens=event.response.usage.total_tokens,
                                 input_tokens_details=event.response.usage.input_tokens_details,
                                 output_tokens_details=event.response.usage.output_tokens_details,
+                                cost=cost,
                             )
                             if event.response.usage
                             else Usage()

Review comment on the `cost = getattr(event.response, "_litellm_cost", None)` line:

Sorry, but I still hesitate to rely on LiteLLM's underscored property.

Reply: How about storing the cost in a …
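The LitellmModel side of this change is not shown in the hunks above, but the tests below (which monkeypatch litellm.acompletion and litellm.completion_cost) imply its rough shape. A minimal sketch under those assumptions — litellm.completion_cost and the _litellm_cost attribute are confirmed by the diff and tests, while the helper name and surrounding structure are illustrative:

    import litellm

    def _maybe_attach_cost(litellm_response, response_obj, track_cost) -> None:
        """Hypothetical helper: compute cost and stash it for the runner."""
        # Only compute cost when the user opted in via ModelSettings(track_cost=True).
        if not track_cost:
            return
        try:
            # completion_cost() looks the model up in LiteLLM's pricing database.
            cost = litellm.completion_cost(completion_response=litellm_response)
        except Exception:
            # Unknown/unpriced models raise; the tests expect usage.cost to be None then.
            cost = None
        # The runner later reads this back with getattr(event.response, "_litellm_cost", None).
        response_obj._litellm_cost = cost

From the caller's side, per the tests, opting in is just ModelSettings(track_cost=True), after which result.usage.cost carries the computed value, or None when tracking is off or the pricing lookup fails.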
New test file:

@@ -0,0 +1,220 @@

    """Tests for LiteLLM cost tracking functionality."""

    import litellm
    import pytest
    from litellm.types.utils import Choices, Message, ModelResponse, Usage

    from agents.extensions.models.litellm_model import LitellmModel
    from agents.model_settings import ModelSettings
    from agents.models.interface import ModelTracing


    @pytest.mark.allow_call_model_methods
    @pytest.mark.asyncio
    async def test_cost_extracted_when_track_cost_enabled(monkeypatch):
        """Test that cost is calculated using LiteLLM's completion_cost API when track_cost=True."""

        async def fake_acompletion(model, messages=None, **kwargs):
            msg = Message(role="assistant", content="Test response")
            choice = Choices(index=0, message=msg)
            response = ModelResponse(
                choices=[choice],
                usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
            )
            return response

        def fake_completion_cost(completion_response):
            """Mock litellm.completion_cost to return a test cost value."""
            return 0.00042

        monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
        monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)

        model = LitellmModel(model="test-model", api_key="test-key")
        result = await model.get_response(
            system_instructions=None,
            input=[],
            model_settings=ModelSettings(track_cost=True),  # Enable cost tracking
            tools=[],
            output_schema=None,
            handoffs=[],
            tracing=ModelTracing.DISABLED,
            previous_response_id=None,
        )

        # Verify that cost was calculated.
        assert result.usage.cost == 0.00042


    @pytest.mark.allow_call_model_methods
    @pytest.mark.asyncio
    async def test_cost_none_when_track_cost_disabled(monkeypatch):
        """Test that cost is None when track_cost is not set (defaults to None/False)."""

        async def fake_acompletion(model, messages=None, **kwargs):
            msg = Message(role="assistant", content="Test response")
            choice = Choices(index=0, message=msg)
            response = ModelResponse(
                choices=[choice],
                usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
            )
            return response

        monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
        # Note: completion_cost should not be called when track_cost is None (default).

        model = LitellmModel(model="test-model", api_key="test-key")
        result = await model.get_response(
            system_instructions=None,
            input=[],
            model_settings=ModelSettings(),  # track_cost defaults to None (disabled)
            tools=[],
            output_schema=None,
            handoffs=[],
            tracing=ModelTracing.DISABLED,
            previous_response_id=None,
        )

        # Verify that cost is None when tracking is disabled.
        assert result.usage.cost is None


    @pytest.mark.allow_call_model_methods
    @pytest.mark.asyncio
    async def test_cost_none_when_not_provided(monkeypatch):
        """Test that cost is None when completion_cost raises an exception."""

        async def fake_acompletion(model, messages=None, **kwargs):
            msg = Message(role="assistant", content="Test response")
            choice = Choices(index=0, message=msg)
            response = ModelResponse(
                choices=[choice],
                usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
            )
            return response

        def fake_completion_cost(completion_response):
            """Mock completion_cost to raise an exception (e.g., unknown model)."""
            raise Exception("Model not found in pricing database")

        monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
        monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)

        model = LitellmModel(model="test-model", api_key="test-key")
        result = await model.get_response(
            system_instructions=None,
            input=[],
            model_settings=ModelSettings(track_cost=True),
            tools=[],
            output_schema=None,
            handoffs=[],
            tracing=ModelTracing.DISABLED,
            previous_response_id=None,
        )

        # Verify that cost is None when completion_cost fails.
        assert result.usage.cost is None


    @pytest.mark.allow_call_model_methods
    @pytest.mark.asyncio
    async def test_cost_zero_when_completion_cost_returns_zero(monkeypatch):
        """Test that cost is 0 when completion_cost returns 0 (e.g., free model)."""

        async def fake_acompletion(model, messages=None, **kwargs):
            msg = Message(role="assistant", content="Test response")
            choice = Choices(index=0, message=msg)
            response = ModelResponse(
                choices=[choice],
                usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
            )
            return response

        def fake_completion_cost(completion_response):
            """Mock completion_cost to return 0 (e.g., free model)."""
            return 0.0

        monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
        monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)

        model = LitellmModel(model="test-model", api_key="test-key")
        result = await model.get_response(
            system_instructions=None,
            input=[],
            model_settings=ModelSettings(track_cost=True),
            tools=[],
            output_schema=None,
            handoffs=[],
            tracing=ModelTracing.DISABLED,
            previous_response_id=None,
        )

        # Verify that cost is 0 for free models.
        assert result.usage.cost == 0.0


    @pytest.mark.allow_call_model_methods
    @pytest.mark.asyncio
    async def test_cost_extraction_preserves_other_usage_fields(monkeypatch):
        """Test that cost calculation doesn't affect other usage fields."""

        async def fake_acompletion(model, messages=None, **kwargs):
            msg = Message(role="assistant", content="Test response")
            choice = Choices(index=0, message=msg)
            response = ModelResponse(
                choices=[choice],
                usage=Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150),
            )
            return response

        def fake_completion_cost(completion_response):
            """Mock litellm.completion_cost to return a test cost value."""
            return 0.001

        monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
        monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)

        model = LitellmModel(model="test-model", api_key="test-key")
        result = await model.get_response(
            system_instructions=None,
            input=[],
            model_settings=ModelSettings(track_cost=True),
            tools=[],
            output_schema=None,
            handoffs=[],
            tracing=ModelTracing.DISABLED,
            previous_response_id=None,
        )

        # Verify all usage fields are correct.
        assert result.usage.input_tokens == 100
        assert result.usage.output_tokens == 50
        assert result.usage.total_tokens == 150
        assert result.usage.cost == 0.001
        assert result.usage.requests == 1


    def test_track_cost_sticky_through_resolve():
        """Test that track_cost=True is not overwritten by resolve() with empty override."""
        base = ModelSettings(track_cost=True, temperature=0.7)
        override = ModelSettings(max_tokens=100)  # Only setting max_tokens, track_cost is None

        resolved = base.resolve(override)

        # track_cost should remain True because override's track_cost is None (not False).
        assert resolved.track_cost is True
        assert resolved.temperature == 0.7
        assert resolved.max_tokens == 100


    def test_track_cost_can_be_explicitly_disabled():
        """Test that track_cost=True can be explicitly overridden to False."""
        base = ModelSettings(track_cost=True, temperature=0.7)
        override = ModelSettings(track_cost=False, max_tokens=100)

        resolved = base.resolve(override)

        # track_cost should be False because override explicitly set it to False.
        assert resolved.track_cost is False
        assert resolved.temperature == 0.7
        assert resolved.max_tokens == 100
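The resolve() implementation itself is not part of this diff; the two tests above pin down the merge semantics it must have. A minimal sketch consistent with them, assuming ModelSettings is a dataclass whose unset fields default to None:

    from dataclasses import fields, replace

    def resolve(base, override):
        # Per-field merge: an override field left as None falls back to the base
        # value, so base's track_cost=True survives an "empty" override, while
        # an explicit track_cost=False in the override wins.
        changes = {
            f.name: getattr(override, f.name)
            for f in fields(override)
            if getattr(override, f.name) is not None
        }
        return replace(base, **changes)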
Review comment on the track_cost setting:

This is only related to LiteLLM use cases. Even if we decided to have this feature, I don't think this property should be here. If we add something like this, we may want to introduce LiteLLMSettings, which has only track_cost so far, and pass it to LiteLLMModel constructor as a new arg.

Reply: I actually like this idea. I thought it could be too big of a change, but if this is welcome I will go ahead with this.
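To make that proposal concrete, the suggested container might look like the sketch below. This illustrates the review suggestion only, not code from this PR; the LiteLLMSettings class and the settings constructor argument do not exist yet:

    from dataclasses import dataclass

    from agents.extensions.models.litellm_model import LitellmModel

    @dataclass
    class LiteLLMSettings:
        """Hypothetical provider-specific settings; only cost tracking so far."""
        track_cost: bool = False

    # Hypothetical new constructor arg, keeping track_cost off the shared ModelSettings:
    model = LitellmModel(
        model="test-model",
        api_key="test-key",
        settings=LiteLLMSettings(track_cost=True),
    )

This would keep provider-specific knobs out of ModelSettings, which is shared across all model backends.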