20
20
ToolReturn ,
21
21
UserPrompt ,
22
22
)
23
+ from pydantic_ai .models import cached_async_http_client
23
24
from pydantic_ai .models .function import AgentInfo , FunctionModel
24
25
from pydantic_ai .models .test import TestModel
25
26
from pydantic_ai .result import Cost , RunResult
26
27
27
28
from .conftest import IsNow , TestEnv
28
29
29
30
30
- def test_result_tuple ():
31
+ def test_result_tuple (set_event_loop : None ):
31
32
def return_tuple (_ : list [Message ], info : AgentInfo ) -> ModelAnyResponse :
32
33
assert info .result_tools is not None
33
34
args_json = '{"response": ["foo", "bar"]}'
@@ -44,7 +45,7 @@ class Foo(BaseModel):
44
45
b : str
45
46
46
47
47
- def test_result_pydantic_model ():
48
+ def test_result_pydantic_model (set_event_loop : None ):
48
49
def return_model (_ : list [Message ], info : AgentInfo ) -> ModelAnyResponse :
49
50
assert info .result_tools is not None
50
51
args_json = '{"a": 1, "b": "foo"}'
@@ -57,7 +58,7 @@ def return_model(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
57
58
assert result .data .model_dump () == {'a' : 1 , 'b' : 'foo' }
58
59
59
60
60
- def test_result_pydantic_model_retry ():
61
+ def test_result_pydantic_model_retry (set_event_loop : None ):
61
62
def return_model (messages : list [Message ], info : AgentInfo ) -> ModelAnyResponse :
62
63
assert info .result_tools is not None
63
64
if len (messages ) == 1 :
@@ -99,7 +100,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
99
100
assert result .all_messages_json ().startswith (b'[{"content":"Hello"' )
100
101
101
102
102
- def test_result_validator ():
103
+ def test_result_validator (set_event_loop : None ):
103
104
def return_model (messages : list [Message ], info : AgentInfo ) -> ModelAnyResponse :
104
105
assert info .result_tools is not None
105
106
if len (messages ) == 1 :
@@ -135,7 +136,7 @@ def validate_result(ctx: RunContext[None], r: Foo) -> Foo:
135
136
)
136
137
137
138
138
- def test_plain_response ():
139
+ def test_plain_response (set_event_loop : None ):
139
140
call_index = 0
140
141
141
142
def return_tuple (_ : list [Message ], info : AgentInfo ) -> ModelAnyResponse :
@@ -170,7 +171,7 @@ def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
170
171
)
171
172
172
173
173
- def test_response_tuple ():
174
+ def test_response_tuple (set_event_loop : None ):
174
175
m = TestModel ()
175
176
176
177
agent = Agent (m , result_type = tuple [str , str ])
@@ -215,7 +216,7 @@ def test_response_tuple():
215
216
[lambda : Union [str , Foo ], lambda : Union [Foo , str ], lambda : str | Foo , lambda : Foo | str ],
216
217
ids = ['Union[str, Foo]' , 'Union[Foo, str]' , 'str | Foo' , 'Foo | str' ],
217
218
)
218
- def test_response_union_allow_str (input_union_callable : Callable [[], Any ]):
219
+ def test_response_union_allow_str (set_event_loop : None , input_union_callable : Callable [[], Any ]):
219
220
try :
220
221
union = input_union_callable ()
221
222
except TypeError :
@@ -277,7 +278,7 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any:
277
278
),
278
279
],
279
280
)
280
- def test_response_multiple_return_tools (create_module : Callable [[str ], Any ], union_code : str ):
281
+ def test_response_multiple_return_tools (set_event_loop : None , create_module : Callable [[str ], Any ], union_code : str ):
281
282
module_code = f'''
282
283
from pydantic import BaseModel
283
284
from typing import Union
@@ -356,7 +357,7 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any:
356
357
assert got_tool_call_name == snapshot ('final_result_Bar' )
357
358
358
359
359
- def test_run_with_history_new ():
360
+ def test_run_with_history_new (set_event_loop : None ):
360
361
m = TestModel ()
361
362
362
363
agent = Agent (m , system_prompt = 'Foobar' )
@@ -440,7 +441,7 @@ async def ret_a(x: str) -> str:
440
441
)
441
442
442
443
443
- def test_empty_tool_calls ():
444
+ def test_empty_tool_calls (set_event_loop : None ):
444
445
def empty (_ : list [Message ], _info : AgentInfo ) -> ModelAnyResponse :
445
446
return ModelStructuredResponse (calls = [])
446
447
@@ -450,7 +451,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
450
451
agent .run_sync ('Hello' )
451
452
452
453
453
- def test_unknown_tool ():
454
+ def test_unknown_tool (set_event_loop : None ):
454
455
def empty (_ : list [Message ], _info : AgentInfo ) -> ModelAnyResponse :
455
456
return ModelStructuredResponse (calls = [ToolCall .from_json ('foobar' , '{}' )])
456
457
@@ -472,7 +473,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
472
473
)
473
474
474
475
475
- def test_unknown_tool_fix ():
476
+ def test_unknown_tool_fix (set_event_loop : None ):
476
477
def empty (m : list [Message ], _info : AgentInfo ) -> ModelAnyResponse :
477
478
if len (m ) > 1 :
478
479
return ModelTextResponse (content = 'success' )
@@ -495,15 +496,15 @@ def empty(m: list[Message], _info: AgentInfo) -> ModelAnyResponse:
495
496
)
496
497
497
498
498
- def test_model_requests_blocked (env : TestEnv ):
499
+ def test_model_requests_blocked (env : TestEnv , set_event_loop : None ):
499
500
env .set ('GEMINI_API_KEY' , 'foobar' )
500
501
agent = Agent ('gemini-1.5-flash' , result_type = tuple [str , str ], defer_model_check = True )
501
502
502
503
with pytest .raises (RuntimeError , match = 'Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False' ):
503
504
agent .run_sync ('Hello' )
504
505
505
506
506
- def test_override_model (env : TestEnv ):
507
+ def test_override_model (env : TestEnv , set_event_loop : None ):
507
508
env .set ('GEMINI_API_KEY' , 'foobar' )
508
509
agent = Agent ('gemini-1.5-flash' , result_type = tuple [int , str ], defer_model_check = True )
509
510
@@ -512,9 +513,25 @@ def test_override_model(env: TestEnv):
512
513
assert result .data == snapshot ((0 , 'a' ))
513
514
514
515
515
- def test_override_model_no_model ():
516
+ def test_override_model_no_model (set_event_loop : None ):
516
517
agent = Agent ()
517
518
518
519
with pytest .raises (UserError , match = r'`model` must be set either.+Even when `override\(model=...\)` is customiz' ):
519
520
with agent .override (model = 'test' ):
520
521
agent .run_sync ('Hello' )
522
+
523
+
524
+ def test_run_sync_multiple (set_event_loop : None ):
525
+ agent = Agent ('test' )
526
+
527
+ @agent .tool_plain
528
+ async def make_request () -> str :
529
+ # raised a `RuntimeError: Event loop is closed` on repeat runs when we used `asyncio.run()`
530
+ client = cached_async_http_client ()
531
+ # use this as I suspect it's about the fastest globally available endpoint
532
+ response = await client .get ('https://cloudflare.com/cdn-cgi/trace' )
533
+ return str (response .status_code )
534
+
535
+ for _ in range (2 ):
536
+ result = agent .run_sync ('Hello' )
537
+ assert result .data == '{"make_request":"200"}'
0 commit comments