Skip to content

Commit cd206c5

Browse files
committed
add tests
1 parent 6d52ecf commit cd206c5

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

pydantic_ai_slim/pydantic_ai/models/test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class TestModel(Model):
8383
"""
8484
_model_name: str = field(default='test', repr=False)
8585
_system: str = field(default='test', repr=False)
86+
__provider_response_id: str = field(default='resp_test', repr=False)
8687

8788
def __init__(
8889
self,
@@ -132,6 +133,7 @@ async def request_stream(
132133
_structured_response=model_response,
133134
_messages=messages,
134135
_provider_name=self._system,
136+
_provider_response_id=self.__provider_response_id,
135137
)
136138

137139
@property
@@ -285,6 +287,7 @@ class TestStreamedResponse(StreamedResponse):
285287
_structured_response: ModelResponse
286288
_messages: InitVar[Iterable[ModelMessage]]
287289
_provider_name: str
290+
_provider_response_id: str
288291
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
289292

290293
def __post_init__(self, _messages: Iterable[ModelMessage]):
@@ -327,6 +330,11 @@ def model_name(self) -> str:
327330
"""Get the model name of the response."""
328331
return self._model_name
329332

333+
@property
334+
def provider_response_id(self) -> str:
335+
"""Get the provider name."""
336+
return self._provider_response_id
337+
330338
@property
331339
def provider_name(self) -> str:
332340
"""Get the provider name."""

tests/test_streaming.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ async def ret_a(x: str) -> str:
6868
model_name='test',
6969
timestamp=IsNow(tz=timezone.utc),
7070
provider_name='test',
71+
provider_response_id='resp_test',
7172
),
7273
ModelRequest(
7374
parts=[
@@ -98,6 +99,7 @@ async def ret_a(x: str) -> str:
9899
model_name='test',
99100
timestamp=IsNow(tz=timezone.utc),
100101
provider_name='test',
102+
provider_response_id='resp_test',
101103
),
102104
ModelRequest(
103105
parts=[
@@ -112,6 +114,7 @@ async def ret_a(x: str) -> str:
112114
model_name='test',
113115
timestamp=IsNow(tz=timezone.utc),
114116
provider_name='test',
117+
provider_response_id='resp_test',
115118
),
116119
]
117120
)
@@ -230,48 +233,55 @@ def upcase(text: str) -> str:
230233
model_name='test',
231234
timestamp=IsNow(tz=timezone.utc),
232235
provider_name='test',
236+
provider_response_id='resp_test',
233237
),
234238
ModelResponse(
235239
parts=[TextPart(content='The cat ')],
236240
usage=RequestUsage(input_tokens=51, output_tokens=2),
237241
model_name='test',
238242
timestamp=IsNow(tz=timezone.utc),
239243
provider_name='test',
244+
provider_response_id='resp_test',
240245
),
241246
ModelResponse(
242247
parts=[TextPart(content='The cat sat ')],
243248
usage=RequestUsage(input_tokens=51, output_tokens=3),
244249
model_name='test',
245250
timestamp=IsNow(tz=timezone.utc),
246251
provider_name='test',
252+
provider_response_id='resp_test',
247253
),
248254
ModelResponse(
249255
parts=[TextPart(content='The cat sat on ')],
250256
usage=RequestUsage(input_tokens=51, output_tokens=4),
251257
model_name='test',
252258
timestamp=IsNow(tz=timezone.utc),
253259
provider_name='test',
260+
provider_response_id='resp_test',
254261
),
255262
ModelResponse(
256263
parts=[TextPart(content='The cat sat on the ')],
257264
usage=RequestUsage(input_tokens=51, output_tokens=5),
258265
model_name='test',
259266
timestamp=IsNow(tz=timezone.utc),
260267
provider_name='test',
268+
provider_response_id='resp_test',
261269
),
262270
ModelResponse(
263271
parts=[TextPart(content='The cat sat on the mat.')],
264272
usage=RequestUsage(input_tokens=51, output_tokens=7),
265273
model_name='test',
266274
timestamp=IsNow(tz=timezone.utc),
267275
provider_name='test',
276+
provider_response_id='resp_test',
268277
),
269278
ModelResponse(
270279
parts=[TextPart(content='The cat sat on the mat.')],
271280
usage=RequestUsage(input_tokens=51, output_tokens=7),
272281
model_name='test',
273282
timestamp=IsNow(tz=timezone.utc),
274283
provider_name='test',
284+
provider_response_id='resp_test',
275285
),
276286
]
277287
)
@@ -796,6 +806,7 @@ def regular_tool(x: int) -> int:
796806
model_name='test',
797807
timestamp=IsNow(tz=timezone.utc),
798808
provider_name='test',
809+
provider_response_id='resp_test',
799810
),
800811
ModelRequest(
801812
parts=[
@@ -810,6 +821,7 @@ def regular_tool(x: int) -> int:
810821
model_name='test',
811822
timestamp=IsNow(tz=timezone.utc),
812823
provider_name='test',
824+
provider_response_id='resp_test',
813825
),
814826
ModelRequest(
815827
parts=[
@@ -914,6 +926,7 @@ def output_validator_simple(data: str) -> str:
914926
timestamp=IsNow(tz=timezone.utc),
915927
kind='response',
916928
provider_name='test',
929+
provider_response_id='resp_test',
917930
)
918931
for text in [
919932
'',
@@ -1197,6 +1210,7 @@ def my_tool(x: int) -> int:
11971210
model_name='test',
11981211
timestamp=IsDatetime(),
11991212
provider_name='test',
1213+
provider_response_id='resp_test',
12001214
)
12011215
]
12021216
)

tests/test_usage_limits.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ async def ret_a(x: str) -> str:
9797
model_name='test',
9898
timestamp=IsNow(tz=timezone.utc),
9999
provider_name='test',
100+
provider_response_id='resp_test',
100101
),
101102
ModelRequest(
102103
parts=[

0 commit comments

Comments
 (0)