Skip to content

Commit e6af102

Browse files
authored
Autouse the set_event_loop fixture (#747)
1 parent 28e1d5a commit e6af102

File tree

9 files changed

+74
-75
lines changed

9 files changed

+74
-75
lines changed

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def check_import() -> bool:
173173
import_success = True
174174

175175

176-
@pytest.fixture
176+
@pytest.fixture(autouse=True)
177177
def set_event_loop() -> Iterator[None]:
178178
new_loop = asyncio.new_event_loop()
179179
asyncio.set_event_loop(new_loop)

tests/models/test_model_function.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ async def return_last(messages: list[ModelMessage], _: AgentInfo) -> ModelRespon
5858
return ModelResponse(parts=[TextPart(' '.join(f'{k}={v!r}' for k, v in response.items()))])
5959

6060

61-
def test_simple(set_event_loop: None):
61+
def test_simple():
6262
agent = Agent(FunctionModel(return_last))
6363
result = agent.run_sync('Hello')
6464
assert result.data == snapshot("content='Hello' part_kind='user-prompt' message_count=1")
@@ -143,7 +143,7 @@ async def get_weather(_: RunContext[None], lat: int, lng: int):
143143
return 'Sunny'
144144

145145

146-
def test_weather(set_event_loop: None):
146+
def test_weather():
147147
result = weather_agent.run_sync('London')
148148
assert result.data == 'Raining in London'
149149
assert result.all_messages() == snapshot(
@@ -214,7 +214,7 @@ def get_var_args(ctx: RunContext[int], *args: int):
214214
return json.dumps({'args': args})
215215

216216

217-
def test_var_args(set_event_loop: None):
217+
def test_var_args():
218218
result = var_args_agent.run_sync('{"function": "get_var_args", "arguments": {"args": [1, 2, 3]}}', deps=123)
219219
response_data = json.loads(result.data)
220220
# Can't parse ISO timestamps with trailing 'Z' in older versions of python:
@@ -239,7 +239,7 @@ async def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRespo
239239
return ModelResponse(parts=[TextPart('final response')])
240240

241241

242-
def test_deps_none(set_event_loop: None):
242+
def test_deps_none():
243243
agent = Agent(FunctionModel(call_tool))
244244

245245
@agent.tool
@@ -259,7 +259,7 @@ async def get_none(ctx: RunContext[None]):
259259
assert called
260260

261261

262-
def test_deps_init(set_event_loop: None):
262+
def test_deps_init():
263263
def get_check_foobar(ctx: RunContext[tuple[str, str]]) -> str:
264264
nonlocal called
265265

@@ -274,7 +274,7 @@ def get_check_foobar(ctx: RunContext[tuple[str, str]]) -> str:
274274
assert called
275275

276276

277-
def test_model_arg(set_event_loop: None):
277+
def test_model_arg():
278278
agent = Agent()
279279
result = agent.run_sync('Hello', model=FunctionModel(return_last))
280280
assert result.data == snapshot("content='Hello' part_kind='user-prompt' message_count=1")
@@ -316,7 +316,7 @@ def spam() -> str:
316316
return 'foobar'
317317

318318

319-
def test_register_all(set_event_loop: None):
319+
def test_register_all():
320320
async def f(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
321321
return ModelResponse(
322322
parts=[
@@ -330,7 +330,7 @@ async def f(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
330330
assert result.data == snapshot('messages=1 allow_text_result=True tools=5')
331331

332332

333-
def test_call_all(set_event_loop: None):
333+
def test_call_all():
334334
result = agent_all.run_sync('Hello', model=TestModel())
335335
assert result.data == snapshot('{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}')
336336
assert result.all_messages() == snapshot(
@@ -370,7 +370,7 @@ def test_call_all(set_event_loop: None):
370370
)
371371

372372

373-
def test_retry_str(set_event_loop: None):
373+
def test_retry_str():
374374
call_count = 0
375375

376376
async def try_again(msgs_: list[ModelMessage], _agent_info: AgentInfo) -> ModelResponse:
@@ -392,7 +392,7 @@ async def validate_result(r: str) -> str:
392392
assert result.data == snapshot('2')
393393

394394

395-
def test_retry_result_type(set_event_loop: None):
395+
def test_retry_result_type():
396396
call_count = 0
397397

398398
async def try_again(messages: list[ModelMessage], _: AgentInfo) -> ModelResponse:

tests/models/test_model_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ..conftest import IsNow
2727

2828

29-
def test_call_one(set_event_loop: None):
29+
def test_call_one():
3030
agent = Agent()
3131
calls: list[str] = []
3232

@@ -45,7 +45,7 @@ async def ret_b(x: str) -> str: # pragma: no cover
4545
assert calls == ['a']
4646

4747

48-
def test_custom_result_text(set_event_loop: None):
48+
def test_custom_result_text():
4949
agent = Agent()
5050
result = agent.run_sync('x', model=TestModel(custom_result_text='custom'))
5151
assert result.data == snapshot('custom')
@@ -54,13 +54,13 @@ def test_custom_result_text(set_event_loop: None):
5454
agent.run_sync('x', model=TestModel(custom_result_text='custom'))
5555

5656

57-
def test_custom_result_args(set_event_loop: None):
57+
def test_custom_result_args():
5858
agent = Agent(result_type=tuple[str, str])
5959
result = agent.run_sync('x', model=TestModel(custom_result_args=['a', 'b']))
6060
assert result.data == ('a', 'b')
6161

6262

63-
def test_custom_result_args_model(set_event_loop: None):
63+
def test_custom_result_args_model():
6464
class Foo(BaseModel):
6565
foo: str
6666
bar: int
@@ -70,13 +70,13 @@ class Foo(BaseModel):
7070
assert result.data == Foo(foo='a', bar=1)
7171

7272

73-
def test_result_type(set_event_loop: None):
73+
def test_result_type():
7474
agent = Agent(result_type=tuple[str, str])
7575
result = agent.run_sync('x', model=TestModel())
7676
assert result.data == ('a', 'a')
7777

7878

79-
def test_tool_retry(set_event_loop: None):
79+
def test_tool_retry():
8080
agent = Agent()
8181
call_count = 0
8282

@@ -120,7 +120,7 @@ async def my_ret(x: int) -> str:
120120
)
121121

122122

123-
def test_result_tool_retry_error_handled(set_event_loop: None):
123+
def test_result_tool_retry_error_handled():
124124
class ResultModel(BaseModel):
125125
x: int
126126
y: str

0 commit comments

Comments
 (0)