|
6 | 6 | from dataclasses import dataclass, field
|
7 | 7 | from datetime import timezone
|
8 | 8 | from functools import cached_property
|
9 |
| -from typing import Any, TypeVar, Union, cast |
| 9 | +from typing import Any, Callable, TypeVar, Union, cast |
10 | 10 |
|
11 | 11 | import httpx
|
12 | 12 | import pytest
|
|
52 | 52 | )
|
53 | 53 | from anthropic.types.raw_message_delta_event import Delta
|
54 | 54 |
|
55 |
| - from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings |
| 55 | + from pydantic_ai.models.anthropic import ( |
| 56 | + AnthropicModel, |
| 57 | + AnthropicModelSettings, |
| 58 | + _map_usage, # pyright: ignore[reportPrivateUsage] |
| 59 | + ) |
56 | 60 | from pydantic_ai.providers.anthropic import AnthropicProvider
|
57 | 61 |
|
58 | 62 | # note: we use Union here so that casting works with Python 3.9
|
@@ -921,3 +925,74 @@ def simple_instructions():
|
921 | 925 | ),
|
922 | 926 | ]
|
923 | 927 | )
|
| 928 | + |
| 929 | + |
| 930 | +def anth_msg(usage: AnthropicUsage) -> AnthropicMessage: |
| 931 | + return AnthropicMessage( |
| 932 | + id='x', |
| 933 | + content=[], |
| 934 | + model='claude-3-7-sonnet-latest', |
| 935 | + role='assistant', |
| 936 | + type='message', |
| 937 | + usage=usage, |
| 938 | + ) |
| 939 | + |
| 940 | + |
| 941 | +@pytest.mark.parametrize( |
| 942 | + 'message_callback,usage', |
| 943 | + [ |
| 944 | + pytest.param( |
| 945 | + lambda: anth_msg(AnthropicUsage(input_tokens=1, output_tokens=1)), |
| 946 | + snapshot( |
| 947 | + Usage( |
| 948 | + request_tokens=1, response_tokens=1, total_tokens=2, details={'input_tokens': 1, 'output_tokens': 1} |
| 949 | + ) |
| 950 | + ), |
| 951 | + id='AnthropicMessage', |
| 952 | + ), |
| 953 | + pytest.param( |
| 954 | + lambda: anth_msg( |
| 955 | + AnthropicUsage( |
| 956 | + input_tokens=1, output_tokens=1, cache_creation_input_tokens=2, cache_read_input_tokens=3 |
| 957 | + ) |
| 958 | + ), |
| 959 | + snapshot( |
| 960 | + Usage( |
| 961 | + request_tokens=6, |
| 962 | + response_tokens=1, |
| 963 | + total_tokens=7, |
| 964 | + details={ |
| 965 | + 'cache_creation_input_tokens': 2, |
| 966 | + 'cache_read_input_tokens': 3, |
| 967 | + 'input_tokens': 1, |
| 968 | + 'output_tokens': 1, |
| 969 | + }, |
| 970 | + ) |
| 971 | + ), |
| 972 | + id='AnthropicMessage-cached', |
| 973 | + ), |
| 974 | + pytest.param( |
| 975 | + lambda: RawMessageStartEvent( |
| 976 | + message=anth_msg(AnthropicUsage(input_tokens=1, output_tokens=1)), type='message_start' |
| 977 | + ), |
| 978 | + snapshot( |
| 979 | + Usage( |
| 980 | + request_tokens=1, response_tokens=1, total_tokens=2, details={'input_tokens': 1, 'output_tokens': 1} |
| 981 | + ) |
| 982 | + ), |
| 983 | + id='RawMessageStartEvent', |
| 984 | + ), |
| 985 | + pytest.param( |
| 986 | + lambda: RawMessageDeltaEvent( |
| 987 | + delta=Delta(), |
| 988 | + usage=MessageDeltaUsage(output_tokens=5), |
| 989 | + type='message_delta', |
| 990 | + ), |
| 991 | + snapshot(Usage(response_tokens=5, total_tokens=5, details={'output_tokens': 5})), |
| 992 | + id='RawMessageDeltaEvent', |
| 993 | + ), |
| 994 | + pytest.param(lambda: RawMessageStopEvent(type='message_stop'), snapshot(Usage()), id='RawMessageStopEvent'), |
| 995 | + ], |
| 996 | +) |
| 997 | +def test_usage(message_callback: Callable[[], AnthropicMessage | RawMessageStreamEvent], usage: Usage): |
| 998 | + assert _map_usage(message_callback()) == usage |
0 commit comments