Skip to content

Commit e10ed7a

Browse files
committed
code review: use litellm.completion_cost() for non-streaming
1 parent 6a25c79 commit e10ed7a

File tree

2 files changed

+37
-22
lines changed

2 files changed

+37
-22
lines changed

src/agents/extensions/models/litellm_model.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,15 @@ async def get_response(
125125
if hasattr(response, "usage"):
126126
response_usage = response.usage
127127

128-
# Extract cost from LiteLLM's hidden params if cost tracking is enabled.
128+
# Calculate cost using LiteLLM's completion_cost function if cost tracking is enabled. # noqa: E501
129129
cost = None
130130
if model_settings.track_cost:
131-
if hasattr(response, "_hidden_params") and isinstance(
132-
response._hidden_params, dict
133-
):
134-
cost = response._hidden_params.get("response_cost")
131+
try:
132+
# Use LiteLLM's public API to calculate cost from the response.
133+
cost = litellm.completion_cost(completion_response=response) # type: ignore[attr-defined]
134+
except Exception:
135+
# If cost calculation fails (e.g., unknown model), continue without cost.
136+
pass
135137

136138
usage = (
137139
Usage(

tests/models/test_litellm_cost_tracking.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
@pytest.mark.allow_call_model_methods
1313
@pytest.mark.asyncio
1414
async def test_cost_extracted_when_track_cost_enabled(monkeypatch):
15-
"""Test that cost is extracted from LiteLLM response when track_cost=True."""
15+
"""Test that cost is calculated using LiteLLM's completion_cost API when track_cost=True."""
1616

1717
async def fake_acompletion(model, messages=None, **kwargs):
1818
msg = Message(role="assistant", content="Test response")
@@ -21,11 +21,14 @@ async def fake_acompletion(model, messages=None, **kwargs):
2121
choices=[choice],
2222
usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
2323
)
24-
# Simulate LiteLLM's hidden params with cost.
25-
response._hidden_params = {"response_cost": 0.00042}
2624
return response
2725

26+
def fake_completion_cost(completion_response):
27+
"""Mock litellm.completion_cost to return a test cost value."""
28+
return 0.00042
29+
2830
monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
31+
monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)
2932

3033
model = LitellmModel(model="test-model", api_key="test-key")
3134
result = await model.get_response(
@@ -39,7 +42,7 @@ async def fake_acompletion(model, messages=None, **kwargs):
3942
previous_response_id=None,
4043
)
4144

42-
# Verify that cost was extracted.
45+
# Verify that cost was calculated.
4346
assert result.usage.cost == 0.00042
4447

4548

@@ -55,11 +58,10 @@ async def fake_acompletion(model, messages=None, **kwargs):
5558
choices=[choice],
5659
usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
5760
)
58-
# Even if LiteLLM provides cost, it should be ignored.
59-
response._hidden_params = {"response_cost": 0.00042}
6061
return response
6162

6263
monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
64+
# Note: completion_cost should not be called when track_cost=False
6365

6466
model = LitellmModel(model="test-model", api_key="test-key")
6567
result = await model.get_response(
@@ -80,7 +82,7 @@ async def fake_acompletion(model, messages=None, **kwargs):
8082
@pytest.mark.allow_call_model_methods
8183
@pytest.mark.asyncio
8284
async def test_cost_none_when_not_provided(monkeypatch):
83-
"""Test that cost is None when LiteLLM doesn't provide it."""
85+
"""Test that cost is None when completion_cost raises an exception."""
8486

8587
async def fake_acompletion(model, messages=None, **kwargs):
8688
msg = Message(role="assistant", content="Test response")
@@ -89,10 +91,14 @@ async def fake_acompletion(model, messages=None, **kwargs):
8991
choices=[choice],
9092
usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
9193
)
92-
# No _hidden_params or no cost in hidden params.
9394
return response
9495

96+
def fake_completion_cost(completion_response):
97+
"""Mock completion_cost to raise an exception (e.g., unknown model)."""
98+
raise Exception("Model not found in pricing database")
99+
95100
monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
101+
monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)
96102

97103
model = LitellmModel(model="test-model", api_key="test-key")
98104
result = await model.get_response(
@@ -106,14 +112,14 @@ async def fake_acompletion(model, messages=None, **kwargs):
106112
previous_response_id=None,
107113
)
108114

109-
# Verify that cost is None when not provided.
115+
# Verify that cost is None when completion_cost fails.
110116
assert result.usage.cost is None
111117

112118

113119
@pytest.mark.allow_call_model_methods
114120
@pytest.mark.asyncio
115-
async def test_cost_with_empty_hidden_params(monkeypatch):
116-
"""Test that cost extraction handles empty _hidden_params gracefully."""
121+
async def test_cost_zero_when_completion_cost_returns_zero(monkeypatch):
122+
"""Test that cost is 0 when completion_cost returns 0 (e.g., free model)."""
117123

118124
async def fake_acompletion(model, messages=None, **kwargs):
119125
msg = Message(role="assistant", content="Test response")
@@ -122,11 +128,14 @@ async def fake_acompletion(model, messages=None, **kwargs):
122128
choices=[choice],
123129
usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
124130
)
125-
# Empty hidden params.
126-
response._hidden_params = {}
127131
return response
128132

133+
def fake_completion_cost(completion_response):
134+
"""Mock completion_cost to return 0 (e.g., free model)."""
135+
return 0.0
136+
129137
monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
138+
monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)
130139

131140
model = LitellmModel(model="test-model", api_key="test-key")
132141
result = await model.get_response(
@@ -140,14 +149,14 @@ async def fake_acompletion(model, messages=None, **kwargs):
140149
previous_response_id=None,
141150
)
142151

143-
# Verify that cost is None with empty hidden params.
144-
assert result.usage.cost is None
152+
# Verify that cost is 0 for free models.
153+
assert result.usage.cost == 0.0
145154

146155

147156
@pytest.mark.allow_call_model_methods
148157
@pytest.mark.asyncio
149158
async def test_cost_extraction_preserves_other_usage_fields(monkeypatch):
150-
"""Test that cost extraction doesn't affect other usage fields."""
159+
"""Test that cost calculation doesn't affect other usage fields."""
151160

152161
async def fake_acompletion(model, messages=None, **kwargs):
153162
msg = Message(role="assistant", content="Test response")
@@ -156,10 +165,14 @@ async def fake_acompletion(model, messages=None, **kwargs):
156165
choices=[choice],
157166
usage=Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150),
158167
)
159-
response._hidden_params = {"response_cost": 0.001}
160168
return response
161169

170+
def fake_completion_cost(completion_response):
171+
"""Mock litellm.completion_cost to return a test cost value."""
172+
return 0.001
173+
162174
monkeypatch.setattr(litellm, "acompletion", fake_acompletion)
175+
monkeypatch.setattr(litellm, "completion_cost", fake_completion_cost)
163176

164177
model = LitellmModel(model="test-model", api_key="test-key")
165178
result = await model.get_response(

0 commit comments

Comments
 (0)