
Commit 2e652bf

Add LiteLLM cost tracking to Usage objects via the track_cost model setting
1 parent: 14d7d59

File tree

5 files changed (+136, -2 lines)

  src/agents/extensions/models/litellm_model.py
  src/agents/model_settings.py
  src/agents/run.py
  src/agents/usage.py
  tests/test_usage.py

src/agents/extensions/models/litellm_model.py

69 additions, 2 deletions

@@ -124,6 +124,15 @@ async def get_response(
 
             if hasattr(response, "usage"):
                 response_usage = response.usage
+
+                # Extract cost from LiteLLM's hidden params if cost tracking is enabled.
+                cost = None
+                if model_settings.track_cost:
+                    if hasattr(response, "_hidden_params") and isinstance(
+                        response._hidden_params, dict
+                    ):
+                        cost = response._hidden_params.get("response_cost")
+
                 usage = (
                     Usage(
                         requests=1,
@@ -142,6 +151,7 @@ async def get_response(
                             )
                             or 0
                         ),
+                        cost=cost,
                     )
                     if response.usage
                     else Usage()
@@ -201,10 +211,67 @@ async def stream_response(
 
             final_response: Response | None = None
             async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
-                yield chunk
-
+                # Intercept the response.completed event to calculate and attach cost.
                 if chunk.type == "response.completed":
                     final_response = chunk.response
+                    # Calculate cost using LiteLLM's completion_cost function if enabled.
+                    # Streaming responses don't include cost in _hidden_params, so we
+                    # calculate it from the final token counts.
+                    if model_settings.track_cost and final_response.usage:
+                        try:
+                            # Create a mock ModelResponse for cost calculation.
+                            # Include token details (cached, reasoning) for accurate pricing.
+                            from litellm.types.utils import (
+                                Choices as LitellmChoices,
+                                CompletionTokensDetailsWrapper,
+                                Message as LitellmMessage,
+                                ModelResponse as LitellmModelResponse,
+                                PromptTokensDetailsWrapper,
+                                Usage as LitellmUsage,
+                            )
+
+                            # Extract token details for accurate cost calculation.
+                            cached_tokens = (
+                                final_response.usage.input_tokens_details.cached_tokens
+                                if final_response.usage.input_tokens_details
+                                else 0
+                            )
+                            reasoning_tokens = (
+                                final_response.usage.output_tokens_details.reasoning_tokens
+                                if final_response.usage.output_tokens_details
+                                else 0
+                            )
+
+                            mock_response = LitellmModelResponse(
+                                choices=[
+                                    LitellmChoices(
+                                        index=0,
+                                        message=LitellmMessage(role="assistant", content=""),
+                                    )
+                                ],
+                                usage=LitellmUsage(
+                                    prompt_tokens=final_response.usage.input_tokens,
+                                    completion_tokens=final_response.usage.output_tokens,
+                                    total_tokens=final_response.usage.total_tokens,
+                                    prompt_tokens_details=PromptTokensDetailsWrapper(
+                                        cached_tokens=cached_tokens
+                                    ),
+                                    completion_tokens_details=CompletionTokensDetailsWrapper(
+                                        reasoning_tokens=reasoning_tokens
+                                    ),
+                                ),
+                                model=self.model,
+                            )
+                            cost = litellm.completion_cost(completion_response=mock_response)
+                            # Attach cost as a custom attribute on the Response object so
+                            # run.py can access it when creating the Usage object.
+                            final_response._litellm_cost = cost
+                        except Exception:
+                            # If cost calculation fails (e.g., unknown model), continue
+                            # without cost.
+                            pass
+
+                yield chunk
 
             if tracing.include_data() and final_response:
                 span_generation.span_data.output = [final_response.model_dump()]
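
For context, a minimal standalone sketch of the two LiteLLM cost paths this change relies on: non-streaming responses carry a cost estimate in LiteLLM's _hidden_params (what get_response() reads), while streaming responses have to be priced after the fact via litellm.completion_cost(). _hidden_params is a LiteLLM internal and the model name below is purely illustrative, so treat this as a sketch rather than a stable API:

import litellm

# Non-streaming: LiteLLM attaches its own cost estimate to the response's
# hidden params; this is the value get_response() picks up above.
resp = litellm.completion(
    model="gpt-4o-mini",  # illustrative model; needs a valid API key for its provider
    messages=[{"role": "user", "content": "hi"}],
)
cost = resp._hidden_params.get("response_cost")  # float in USD, or None

# Streaming: chunks carry no cost, so stream_response() above rebuilds a
# ModelResponse from the final token counts and asks LiteLLM to price it.
recomputed = litellm.completion_cost(completion_response=resp)
print(cost, recomputed)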

src/agents/model_settings.py

6 additions, 0 deletions

@@ -120,6 +120,12 @@ class ModelSettings:
     """Whether to include usage chunk.
     Only available for Chat Completions API."""
 
+    track_cost: bool = False
+    """Whether to track and calculate cost for model calls.
+    When enabled, the SDK will populate `Usage.cost` with cost estimates.
+    Currently only supported for LiteLLM models. For other providers, cost will remain None.
+    Defaults to False."""
+
     # TODO: revisit ResponseIncludable | str if ResponseIncludable covers more cases
     # We've added str to support missing ones like
     # "web_search_call.action.sources" etc.

src/agents/run.py

4 additions, 0 deletions

@@ -1104,6 +1104,9 @@ async def _run_single_turn_streamed(
            prompt=prompt_config,
        ):
            if isinstance(event, ResponseCompletedEvent):
+                # Extract cost if it was attached by LiteLLM model.
+                cost = getattr(event.response, "_litellm_cost", None)
+
                usage = (
                    Usage(
                        requests=1,
@@ -1112,6 +1115,7 @@
                        total_tokens=event.response.usage.total_tokens,
                        input_tokens_details=event.response.usage.input_tokens_details,
                        output_tokens_details=event.response.usage.output_tokens_details,
+                        cost=cost,
                    )
                    if event.response.usage
                    else Usage()
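
For the streaming path, a rough sketch of how the attached cost reaches the caller, assuming the SDK's Runner.run_streamed / stream_events surface and the same illustrative model as above; the Usage built here (including cost) is what run.py adds to context_wrapper.usage:

import asyncio

from agents import Agent, ModelSettings, Runner
from agents.extensions.models.litellm_model import LitellmModel

async def main():
    agent = Agent(
        name="Assistant",
        model=LitellmModel(model="anthropic/claude-3-5-haiku-20241022"),  # illustrative
        model_settings=ModelSettings(track_cost=True),
    )
    result = Runner.run_streamed(agent, "Hello")
    async for _event in result.stream_events():
        pass  # drain the stream; usage (and cost) is recorded at response.completed
    print(result.context_wrapper.usage.cost)  # summed across model calls, or None

asyncio.run(main())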

src/agents/usage.py

8 additions, 0 deletions

@@ -27,6 +27,10 @@ class Usage:
     total_tokens: int = 0
     """Total tokens sent and received, across all requests."""
 
+    cost: float | None = None
+    """Total cost in USD for the requests. Only available for certain model providers
+    (e.g., LiteLLM). Will be None for models that don't provide cost information."""
+
     def add(self, other: "Usage") -> None:
         self.requests += other.requests if other.requests else 0
         self.input_tokens += other.input_tokens if other.input_tokens else 0
@@ -41,3 +45,7 @@ def add(self, other: "Usage") -> None:
            reasoning_tokens=self.output_tokens_details.reasoning_tokens
            + other.output_tokens_details.reasoning_tokens
        )
+
+        # Aggregate cost if either has a value.
+        if self.cost is not None or other.cost is not None:
+            self.cost = (self.cost or 0.0) + (other.cost or 0.0)
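
The None handling in add() is what lets a run that mixes providers keep the costs it does know about: a LiteLLM call contributes a figure while another provider reports nothing, and the known portion survives aggregation. A tiny sketch with made-up numbers:

from agents.usage import Usage

total = Usage()
total.add(Usage(requests=1, input_tokens=100, output_tokens=50, total_tokens=150, cost=0.0004))
total.add(Usage(requests=1, input_tokens=80, output_tokens=40, total_tokens=120))  # no cost reported
assert total.cost == 0.0004  # known cost preserved; missing cost treated as 0.0
assert total.requests == 2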

tests/test_usage.py

49 additions, 0 deletions

@@ -50,3 +50,52 @@ def test_usage_add_aggregates_with_none_values():
     assert u1.total_tokens == 15
     assert u1.input_tokens_details.cached_tokens == 4
     assert u1.output_tokens_details.reasoning_tokens == 6
+
+
+def test_usage_cost_defaults_to_none():
+    """Test that cost field defaults to None."""
+    usage = Usage()
+    assert usage.cost is None
+
+
+def test_usage_add_with_cost():
+    """Test that cost is aggregated correctly when both usages have cost."""
+    u1 = Usage(requests=1, input_tokens=10, output_tokens=20, total_tokens=30, cost=0.001)
+    u2 = Usage(requests=1, input_tokens=15, output_tokens=25, total_tokens=40, cost=0.002)
+
+    u1.add(u2)
+
+    assert u1.cost == 0.003
+    assert u1.requests == 2
+    assert u1.total_tokens == 70
+
+
+def test_usage_add_with_partial_cost():
+    """Test that cost is preserved when only one usage has cost."""
+    u1 = Usage(requests=1, input_tokens=10, output_tokens=20, total_tokens=30, cost=0.001)
+    u2 = Usage(requests=1, input_tokens=15, output_tokens=25, total_tokens=40)  # no cost
+
+    u1.add(u2)
+
+    assert u1.cost == 0.001
+    assert u1.requests == 2
+
+
+def test_usage_add_with_cost_none_plus_value():
+    """Test that cost aggregation works when first usage has no cost."""
+    u1 = Usage(requests=1, input_tokens=10, output_tokens=20, total_tokens=30)
+    u2 = Usage(requests=1, input_tokens=15, output_tokens=25, total_tokens=40, cost=0.002)
+
+    u1.add(u2)
+
+    assert u1.cost == 0.002
+
+
+def test_usage_add_with_both_cost_none():
+    """Test that cost remains None when neither usage has cost."""
+    u1 = Usage(requests=1, input_tokens=10, output_tokens=20, total_tokens=30)
+    u2 = Usage(requests=1, input_tokens=15, output_tokens=25, total_tokens=40)
+
+    u1.add(u2)
+
+    assert u1.cost is None
