Skip to content

Commit dcfa821

Browse files
authored
Store additional usage details from Anthropic (#1549)
1 parent 7d613c5 commit dcfa821

File tree

2 files changed: 86 additions and 9 deletions

2 files changed: 86 additions and 9 deletions

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -409,13 +409,27 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
409409
if response_usage is None:
410410
return usage.Usage()
411411

412-
request_tokens = getattr(response_usage, 'input_tokens', None)
412+
# Store all integer-typed usage values in the details dict
413+
response_usage_dict = response_usage.model_dump()
414+
details: dict[str, int] = {}
415+
for key, value in response_usage_dict.items():
416+
if isinstance(value, int):
417+
details[key] = value
418+
419+
# Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence the getattr call
420+
# Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
421+
# This approach maintains request_tokens as the count of all input tokens, with cached counts as details
422+
request_tokens = (
423+
getattr(response_usage, 'input_tokens', 0)
424+
+ (getattr(response_usage, 'cache_creation_input_tokens', 0) or 0) # These can be missing, None, or int
425+
+ (getattr(response_usage, 'cache_read_input_tokens', 0) or 0)
426+
)
413427

414428
return usage.Usage(
415-
# Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
416-
request_tokens=request_tokens,
429+
request_tokens=request_tokens or None,
417430
response_tokens=response_usage.output_tokens,
418-
total_tokens=(request_tokens or 0) + response_usage.output_tokens,
431+
total_tokens=request_tokens + response_usage.output_tokens,
432+
details=details or None,
419433
)
420434

421435

tests/models/test_anthropic.py

Lines changed: 68 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -141,14 +141,29 @@ async def test_sync_request_text_response(allow_model_requests: None):
141141

142142
result = await agent.run('hello')
143143
assert result.output == 'world'
144-
assert result.usage() == snapshot(Usage(requests=1, request_tokens=5, response_tokens=10, total_tokens=15))
145-
144+
assert result.usage() == snapshot(
145+
Usage(
146+
requests=1,
147+
request_tokens=5,
148+
response_tokens=10,
149+
total_tokens=15,
150+
details={'input_tokens': 5, 'output_tokens': 10},
151+
)
152+
)
146153
# reset the index so we get the same response again
147154
mock_client.index = 0 # type: ignore
148155

149156
result = await agent.run('hello', message_history=result.new_messages())
150157
assert result.output == 'world'
151-
assert result.usage() == snapshot(Usage(requests=1, request_tokens=5, response_tokens=10, total_tokens=15))
158+
assert result.usage() == snapshot(
159+
Usage(
160+
requests=1,
161+
request_tokens=5,
162+
response_tokens=10,
163+
total_tokens=15,
164+
details={'input_tokens': 5, 'output_tokens': 10},
165+
)
166+
)
152167
assert result.all_messages() == snapshot(
153168
[
154169
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
@@ -167,6 +182,38 @@ async def test_sync_request_text_response(allow_model_requests: None):
167182
)
168183

169184

185+
async def test_async_request_prompt_caching(allow_model_requests: None):
186+
c = completion_message(
187+
[TextBlock(text='world', type='text')],
188+
usage=AnthropicUsage(
189+
input_tokens=3,
190+
output_tokens=5,
191+
cache_creation_input_tokens=4,
192+
cache_read_input_tokens=6,
193+
),
194+
)
195+
mock_client = MockAnthropic.create_mock(c)
196+
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
197+
agent = Agent(m)
198+
199+
result = await agent.run('hello')
200+
assert result.output == 'world'
201+
assert result.usage() == snapshot(
202+
Usage(
203+
requests=1,
204+
request_tokens=13,
205+
response_tokens=5,
206+
total_tokens=18,
207+
details={
208+
'input_tokens': 3,
209+
'output_tokens': 5,
210+
'cache_creation_input_tokens': 4,
211+
'cache_read_input_tokens': 6,
212+
},
213+
)
214+
)
215+
216+
170217
async def test_async_request_text_response(allow_model_requests: None):
171218
c = completion_message(
172219
[TextBlock(text='world', type='text')],
@@ -178,7 +225,15 @@ async def test_async_request_text_response(allow_model_requests: None):
178225

179226
result = await agent.run('hello')
180227
assert result.output == 'world'
181-
assert result.usage() == snapshot(Usage(requests=1, request_tokens=3, response_tokens=5, total_tokens=8))
228+
assert result.usage() == snapshot(
229+
Usage(
230+
requests=1,
231+
request_tokens=3,
232+
response_tokens=5,
233+
total_tokens=8,
234+
details={'input_tokens': 3, 'output_tokens': 5},
235+
)
236+
)
182237

183238

184239
async def test_request_structured_response(allow_model_requests: None):
@@ -551,7 +606,15 @@ async def my_tool(first: str, second: str) -> int:
551606
]
552607
)
553608
assert result.is_complete
554-
assert result.usage() == snapshot(Usage(requests=2, request_tokens=20, response_tokens=5, total_tokens=25))
609+
assert result.usage() == snapshot(
610+
Usage(
611+
requests=2,
612+
request_tokens=20,
613+
response_tokens=5,
614+
total_tokens=25,
615+
details={'input_tokens': 20, 'output_tokens': 5},
616+
)
617+
)
555618
assert tool_called
556619

557620

0 commit comments

Comments
 (0)