Skip to content

Commit e4450cb

Browse files
authored
Fix run sync (#124)
1 parent 1cc4064 commit e4450cb

File tree

9 files changed

+71
-43
lines changed

9 files changed

+71
-43
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def run_sync(
206206
) -> result.RunResult[ResultData]:
207207
"""Run the agent with a user prompt synchronously.
208208
209-
This is a convenience method that wraps `self.run` with `asyncio.run()`.
209+
This is a convenience method that wraps `self.run` with `loop.run_until_complete()`.
210210
211211
Args:
212212
user_prompt: User input to start/continue the conversation.
@@ -217,7 +217,8 @@ def run_sync(
217217
Returns:
218218
The result of the run.
219219
"""
220-
return asyncio.run(self.run(user_prompt, message_history=message_history, model=model, deps=deps))
220+
loop = asyncio.get_event_loop()
221+
return loop.run_until_complete(self.run(user_prompt, message_history=message_history, model=model, deps=deps))
221222

222223
@asynccontextmanager
223224
async def run_stream(

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations as _annotations
22

3+
import asyncio
34
import importlib.util
45
import os
56
import re
@@ -162,3 +163,11 @@ def check_import() -> bool:
162163
pass
163164
else:
164165
import_success = True
166+
167+
168+
@pytest.fixture
169+
def set_event_loop() -> Iterator[None]:
170+
new_loop = asyncio.new_event_loop()
171+
asyncio.set_event_loop(new_loop)
172+
yield
173+
new_loop.close()

tests/models/test_model_function.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ async def return_last(messages: list[Message], _: AgentInfo) -> ModelAnyResponse
3838
return ModelTextResponse(' '.join(f'{k}={v!r}' for k, v in response.items()))
3939

4040

41-
def test_simple():
41+
def test_simple(set_event_loop: None):
4242
agent = Agent(FunctionModel(return_last))
4343
result = agent.run_sync('Hello')
4444
assert result.data == snapshot("content='Hello' role='user' message_count=1")
@@ -129,7 +129,7 @@ async def get_weather(_: RunContext[None], lat: int, lng: int):
129129
return 'Sunny'
130130

131131

132-
def test_weather():
132+
def test_weather(set_event_loop: None):
133133
result = weather_agent.run_sync('London')
134134
assert result.data == 'Raining in London'
135135
assert result.all_messages() == snapshot(
@@ -206,7 +206,7 @@ def get_var_args(ctx: RunContext[int], *args: int):
206206
return json.dumps({'args': args})
207207

208208

209-
def test_var_args():
209+
def test_var_args(set_event_loop: None):
210210
result = var_args_agent.run_sync('{"function": "get_var_args", "arguments": {"args": [1, 2, 3]}}', deps=123)
211211
response_data = json.loads(result.data)
212212
# Can't parse ISO timestamps with trailing 'Z' in older versions of python:
@@ -231,7 +231,7 @@ async def call_tool(messages: list[Message], info: AgentInfo) -> ModelAnyRespons
231231
return ModelTextResponse('final response')
232232

233233

234-
def test_deps_none():
234+
def test_deps_none(set_event_loop: None):
235235
agent = Agent(FunctionModel(call_tool))
236236

237237
@agent.tool
@@ -251,7 +251,7 @@ async def get_none(ctx: RunContext[None]):
251251
assert called
252252

253253

254-
def test_deps_init():
254+
def test_deps_init(set_event_loop: None):
255255
def get_check_foobar(ctx: RunContext[tuple[str, str]]) -> str:
256256
nonlocal called
257257

@@ -266,7 +266,7 @@ def get_check_foobar(ctx: RunContext[tuple[str, str]]) -> str:
266266
assert called
267267

268268

269-
def test_model_arg():
269+
def test_model_arg(set_event_loop: None):
270270
agent = Agent()
271271
result = agent.run_sync('Hello', model=FunctionModel(return_last))
272272
assert result.data == snapshot("content='Hello' role='user' message_count=1")
@@ -308,7 +308,7 @@ def spam() -> str:
308308
return 'foobar'
309309

310310

311-
def test_register_all():
311+
def test_register_all(set_event_loop: None):
312312
async def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
313313
return ModelTextResponse(
314314
f'messages={len(messages)} allow_text_result={info.allow_text_result} tools={len(info.function_tools)}'
@@ -318,7 +318,7 @@ async def f(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
318318
assert result.data == snapshot('messages=2 allow_text_result=True tools=5')
319319

320320

321-
def test_call_all():
321+
def test_call_all(set_event_loop: None):
322322
result = agent_all.run_sync('Hello', model=TestModel())
323323
assert result.data == snapshot('{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}')
324324
assert result.all_messages() == snapshot(
@@ -347,7 +347,7 @@ def test_call_all():
347347
)
348348

349349

350-
def test_retry_str():
350+
def test_retry_str(set_event_loop: None):
351351
call_count = 0
352352

353353
async def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse:
@@ -369,7 +369,7 @@ async def validate_result(r: str) -> str:
369369
assert result.data == snapshot('2')
370370

371371

372-
def test_retry_result_type():
372+
def test_retry_result_type(set_event_loop: None):
373373
call_count = 0
374374

375375
async def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse:

tests/models/test_model_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ..conftest import IsNow
2525

2626

27-
def test_call_one():
27+
def test_call_one(set_event_loop: None):
2828
agent = Agent()
2929
calls: list[str] = []
3030

@@ -43,7 +43,7 @@ async def ret_b(x: str) -> str: # pragma: no cover
4343
assert calls == ['a']
4444

4545

46-
def test_custom_result_text():
46+
def test_custom_result_text(set_event_loop: None):
4747
agent = Agent()
4848
result = agent.run_sync('x', model=TestModel(custom_result_text='custom'))
4949
assert result.data == snapshot('custom')
@@ -52,13 +52,13 @@ def test_custom_result_text():
5252
agent.run_sync('x', model=TestModel(custom_result_text='custom'))
5353

5454

55-
def test_custom_result_args():
55+
def test_custom_result_args(set_event_loop: None):
5656
agent = Agent(result_type=tuple[str, str])
5757
result = agent.run_sync('x', model=TestModel(custom_result_args=['a', 'b']))
5858
assert result.data == ('a', 'b')
5959

6060

61-
def test_custom_result_args_model():
61+
def test_custom_result_args_model(set_event_loop: None):
6262
class Foo(BaseModel):
6363
foo: str
6464
bar: int
@@ -68,13 +68,13 @@ class Foo(BaseModel):
6868
assert result.data == Foo(foo='a', bar=1)
6969

7070

71-
def test_result_type():
71+
def test_result_type(set_event_loop: None):
7272
agent = Agent(result_type=tuple[str, str])
7373
result = agent.run_sync('x', model=TestModel())
7474
assert result.data == ('a', 'a')
7575

7676

77-
def test_tool_retry():
77+
def test_tool_retry(set_event_loop: None):
7878
agent = Agent()
7979
call_count = 0
8080

tests/test_agent.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020
ToolReturn,
2121
UserPrompt,
2222
)
23+
from pydantic_ai.models import cached_async_http_client
2324
from pydantic_ai.models.function import AgentInfo, FunctionModel
2425
from pydantic_ai.models.test import TestModel
2526
from pydantic_ai.result import Cost, RunResult
2627

2728
from .conftest import IsNow, TestEnv
2829

2930

30-
def test_result_tuple():
31+
def test_result_tuple(set_event_loop: None):
3132
def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
3233
assert info.result_tools is not None
3334
args_json = '{"response": ["foo", "bar"]}'
@@ -44,7 +45,7 @@ class Foo(BaseModel):
4445
b: str
4546

4647

47-
def test_result_pydantic_model():
48+
def test_result_pydantic_model(set_event_loop: None):
4849
def return_model(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
4950
assert info.result_tools is not None
5051
args_json = '{"a": 1, "b": "foo"}'
@@ -57,7 +58,7 @@ def return_model(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
5758
assert result.data.model_dump() == {'a': 1, 'b': 'foo'}
5859

5960

60-
def test_result_pydantic_model_retry():
61+
def test_result_pydantic_model_retry(set_event_loop: None):
6162
def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
6263
assert info.result_tools is not None
6364
if len(messages) == 1:
@@ -99,7 +100,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
99100
assert result.all_messages_json().startswith(b'[{"content":"Hello"')
100101

101102

102-
def test_result_validator():
103+
def test_result_validator(set_event_loop: None):
103104
def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
104105
assert info.result_tools is not None
105106
if len(messages) == 1:
@@ -135,7 +136,7 @@ def validate_result(ctx: RunContext[None], r: Foo) -> Foo:
135136
)
136137

137138

138-
def test_plain_response():
139+
def test_plain_response(set_event_loop: None):
139140
call_index = 0
140141

141142
def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
@@ -170,7 +171,7 @@ def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
170171
)
171172

172173

173-
def test_response_tuple():
174+
def test_response_tuple(set_event_loop: None):
174175
m = TestModel()
175176

176177
agent = Agent(m, result_type=tuple[str, str])
@@ -215,7 +216,7 @@ def test_response_tuple():
215216
[lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str],
216217
ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str'],
217218
)
218-
def test_response_union_allow_str(input_union_callable: Callable[[], Any]):
219+
def test_response_union_allow_str(set_event_loop: None, input_union_callable: Callable[[], Any]):
219220
try:
220221
union = input_union_callable()
221222
except TypeError:
@@ -277,7 +278,7 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any:
277278
),
278279
],
279280
)
280-
def test_response_multiple_return_tools(create_module: Callable[[str], Any], union_code: str):
281+
def test_response_multiple_return_tools(set_event_loop: None, create_module: Callable[[str], Any], union_code: str):
281282
module_code = f'''
282283
from pydantic import BaseModel
283284
from typing import Union
@@ -356,7 +357,7 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any:
356357
assert got_tool_call_name == snapshot('final_result_Bar')
357358

358359

359-
def test_run_with_history_new():
360+
def test_run_with_history_new(set_event_loop: None):
360361
m = TestModel()
361362

362363
agent = Agent(m, system_prompt='Foobar')
@@ -440,7 +441,7 @@ async def ret_a(x: str) -> str:
440441
)
441442

442443

443-
def test_empty_tool_calls():
444+
def test_empty_tool_calls(set_event_loop: None):
444445
def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
445446
return ModelStructuredResponse(calls=[])
446447

@@ -450,7 +451,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
450451
agent.run_sync('Hello')
451452

452453

453-
def test_unknown_tool():
454+
def test_unknown_tool(set_event_loop: None):
454455
def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
455456
return ModelStructuredResponse(calls=[ToolCall.from_json('foobar', '{}')])
456457

@@ -472,7 +473,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelAnyResponse:
472473
)
473474

474475

475-
def test_unknown_tool_fix():
476+
def test_unknown_tool_fix(set_event_loop: None):
476477
def empty(m: list[Message], _info: AgentInfo) -> ModelAnyResponse:
477478
if len(m) > 1:
478479
return ModelTextResponse(content='success')
@@ -495,15 +496,15 @@ def empty(m: list[Message], _info: AgentInfo) -> ModelAnyResponse:
495496
)
496497

497498

498-
def test_model_requests_blocked(env: TestEnv):
499+
def test_model_requests_blocked(env: TestEnv, set_event_loop: None):
499500
env.set('GEMINI_API_KEY', 'foobar')
500501
agent = Agent('gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True)
501502

502503
with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'):
503504
agent.run_sync('Hello')
504505

505506

506-
def test_override_model(env: TestEnv):
507+
def test_override_model(env: TestEnv, set_event_loop: None):
507508
env.set('GEMINI_API_KEY', 'foobar')
508509
agent = Agent('gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True)
509510

@@ -512,9 +513,25 @@ def test_override_model(env: TestEnv):
512513
assert result.data == snapshot((0, 'a'))
513514

514515

515-
def test_override_model_no_model():
516+
def test_override_model_no_model(set_event_loop: None):
516517
agent = Agent()
517518

518519
with pytest.raises(UserError, match=r'`model` must be set either.+Even when `override\(model=...\)` is customiz'):
519520
with agent.override(model='test'):
520521
agent.run_sync('Hello')
522+
523+
524+
def test_run_sync_multiple(set_event_loop: None):
525+
agent = Agent('test')
526+
527+
@agent.tool_plain
528+
async def make_request() -> str:
529+
# raised a `RuntimeError: Event loop is closed` on repeat runs when we used `asyncio.run()`
530+
client = cached_async_http_client()
531+
# use this as I suspect it's about the fastest globally available endpoint
532+
response = await client.get('https://cloudflare.com/cdn-cgi/trace')
533+
return str(response.status_code)
534+
535+
for _ in range(2):
536+
result = agent.run_sync('Hello')
537+
assert result.data == '{"make_request":"200"}'

tests/test_deps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ async def example_tool(ctx: RunContext[MyDeps]) -> str:
1818
return f'{ctx.deps}'
1919

2020

21-
def test_deps_used():
21+
def test_deps_used(set_event_loop: None):
2222
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
2323
assert result.data == '{"example_tool":"MyDeps(foo=1, bar=2)"}'
2424

2525

26-
def test_deps_override():
26+
def test_deps_override(set_event_loop: None):
2727
with agent.override(deps=MyDeps(foo=3, bar=4)):
2828
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
2929
assert result.data == '{"example_tool":"MyDeps(foo=3, bar=4)"}'

tests/test_examples.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def test_docs_examples(
6161
client_with_handler: ClientWithHandler,
6262
env: TestEnv,
6363
tmp_path: Path,
64+
set_event_loop: None,
6465
):
6566
mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model)
6667
mocker.patch('pydantic_ai._utils.group_by_temporal', side_effect=mock_group_by_temporal)

tests/test_logfire.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def get_summary() -> LogfireSummary:
5959

6060

6161
@pytest.mark.skipif(not logfire_installed, reason='logfire not installed')
62-
def test_logfire(get_logfire_summary: Callable[[], LogfireSummary]) -> None:
62+
def test_logfire(get_logfire_summary: Callable[[], LogfireSummary], set_event_loop: None) -> None:
6363
agent = Agent(model=TestModel())
6464

6565
@agent.tool_plain

0 commit comments

Comments
 (0)