feat: #1831 Add opt-in cost tracking for LiteLLM models #1832
base: main
Changes from 6 commits
@@ -1104,6 +1104,9 @@ async def _run_single_turn_streamed(
             prompt=prompt_config,
         ):
             if isinstance(event, ResponseCompletedEvent):
+                # Extract cost if it was attached by LiteLLM model.
+                cost = getattr(event.response, "_litellm_cost", None)

Review comment: Sorry, but I still hesitate to rely on LiteLLM's underscored property.

Review comment (reply): How about storing the cost in a …

                 usage = (
                     Usage(
                         requests=1,
@@ -1112,6 +1115,7 @@ async def _run_single_turn_streamed(
                         total_tokens=event.response.usage.total_tokens,
                         input_tokens_details=event.response.usage.input_tokens_details,
                         output_tokens_details=event.response.usage.output_tokens_details,
+                        cost=cost,
                     )
                     if event.response.usage
                     else Usage()
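
The model-side change that attaches `_litellm_cost` is not part of this excerpt, but the tests below monkeypatch `litellm.completion_cost`, so the logic presumably looks something like the sketch here. This is a hedged reconstruction, not the PR's actual code; the helper name `_compute_cost` is hypothetical.

import litellm


def _compute_cost(response, track_cost):
    """Hypothetical helper mirroring what the tests below exercise."""
    if not track_cost:
        # Cost tracking is opt-in; by default usage.cost stays None.
        return None
    try:
        # litellm.completion_cost() estimates the USD cost of a response
        # from its model name and token usage.
        return litellm.completion_cost(completion_response=response)
    except Exception:
        # Models missing from LiteLLM's pricing database raise; report
        # "no cost" rather than failing the run.
        return None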
@@ -0,0 +1,194 @@
"""Tests for LiteLLM cost tracking functionality."""

import litellm
import pytest
from litellm.types.utils import Choices, Message, ModelResponse, Usage

from agents.extensions.models.litellm_model import LitellmModel
from agents.model_settings import ModelSettings
from agents.models.interface import ModelTracing


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_extracted_when_track_cost_enabled(monkeypatch):
    """Test that cost is calculated using LiteLLM's completion_cost API when track_cost=True."""

    async def fake_acompletion(model, messages=None, **kwargs):
        msg = Message(role="assistant", content="Test response")
        choice = Choices(index=0, message=msg)
        response = ModelResponse(
            choices=[choice],
            usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )
        return response

    def fake_completion_cost(completion_response):
        """Mock litellm.completion_cost to return a test cost value."""
        return 0.00042

    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
    monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)

    model = LitellmModel(model="test-model", api_key="test-key")
    result = await model.get_response(
        system_instructions=None,
        input=[],
        model_settings=ModelSettings(track_cost=True),  # Enable cost tracking.
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )

    # Verify that cost was calculated.
    assert result.usage.cost == 0.00042


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_none_when_track_cost_disabled(monkeypatch):
    """Test that cost is None when track_cost=False (the default)."""

    async def fake_acompletion(model, messages=None, **kwargs):
        msg = Message(role="assistant", content="Test response")
        choice = Choices(index=0, message=msg)
        response = ModelResponse(
            choices=[choice],
            usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )
        return response

    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
    # Note: completion_cost should not be called when track_cost=False.

    model = LitellmModel(model="test-model", api_key="test-key")
    result = await model.get_response(
        system_instructions=None,
        input=[],
        model_settings=ModelSettings(track_cost=False),  # Disabled (default).
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )

    # Verify that cost is None when tracking is disabled.
    assert result.usage.cost is None


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_none_when_not_provided(monkeypatch):
    """Test that cost is None when completion_cost raises an exception."""

    async def fake_acompletion(model, messages=None, **kwargs):
        msg = Message(role="assistant", content="Test response")
        choice = Choices(index=0, message=msg)
        response = ModelResponse(
            choices=[choice],
            usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )
        return response

    def fake_completion_cost(completion_response):
        """Mock completion_cost to raise an exception (e.g., unknown model)."""
        raise Exception("Model not found in pricing database")

    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
    monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)

    model = LitellmModel(model="test-model", api_key="test-key")
    result = await model.get_response(
        system_instructions=None,
        input=[],
        model_settings=ModelSettings(track_cost=True),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )

    # Verify that cost is None when completion_cost fails.
    assert result.usage.cost is None


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_zero_when_completion_cost_returns_zero(monkeypatch):
    """Test that cost is 0 when completion_cost returns 0 (e.g., a free model)."""

    async def fake_acompletion(model, messages=None, **kwargs):
        msg = Message(role="assistant", content="Test response")
        choice = Choices(index=0, message=msg)
        response = ModelResponse(
            choices=[choice],
            usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )
        return response

    def fake_completion_cost(completion_response):
        """Mock completion_cost to return 0 (e.g., a free model)."""
        return 0.0

    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
    monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)

    model = LitellmModel(model="test-model", api_key="test-key")
    result = await model.get_response(
        system_instructions=None,
        input=[],
        model_settings=ModelSettings(track_cost=True),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )

    # Verify that cost is 0 for free models.
    assert result.usage.cost == 0.0


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_cost_extraction_preserves_other_usage_fields(monkeypatch):
    """Test that cost calculation doesn't affect other usage fields."""

    async def fake_acompletion(model, messages=None, **kwargs):
        msg = Message(role="assistant", content="Test response")
        choice = Choices(index=0, message=msg)
        response = ModelResponse(
            choices=[choice],
            usage=Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150),
        )
        return response

    def fake_completion_cost(completion_response):
        """Mock litellm.completion_cost to return a test cost value."""
        return 0.001

    monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
    monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)

    model = LitellmModel(model="test-model", api_key="test-key")
    result = await model.get_response(
        system_instructions=None,
        input=[],
        model_settings=ModelSettings(track_cost=True),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )

    # Verify all usage fields are correct.
    assert result.usage.input_tokens == 100
    assert result.usage.output_tokens == 50
    assert result.usage.total_tokens == 150
    assert result.usage.cost == 0.001
    assert result.usage.requests == 1
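
For completeness, a minimal sketch of how a caller would opt in, pieced together from the test setup above. The model name and API key are placeholders, and the surrounding plumbing (tools, handoffs, tracing) is stubbed out exactly as the tests do; a real application would typically go through the SDK's runner rather than calling get_response directly.

import asyncio

from agents.extensions.models.litellm_model import LitellmModel
from agents.model_settings import ModelSettings
from agents.models.interface import ModelTracing


async def main():
    # Placeholder model name and key; any LiteLLM-routed model works.
    model = LitellmModel(model="gpt-4o-mini", api_key="sk-...")
    result = await model.get_response(
        system_instructions=None,
        input="Say hello.",
        model_settings=ModelSettings(track_cost=True),  # Opt in; default is off.
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )
    # usage.cost is a float (USD) when LiteLLM knows the model's pricing,
    # and None when tracking is off or the price lookup fails.
    print(result.usage.cost)


asyncio.run(main())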