Skip to content

Commit 4fba010

Browse files
committed
[refactor] Introduced new item type "AsyncFunction" which represents pytest.Functions that are coroutines or async generators.
Signed-off-by: Michael Seifert <[email protected]>
1 parent 6012780 commit 4fba010

File tree

1 file changed

+52
-2
lines changed

1 file changed

+52
-2
lines changed

pytest_asyncio/plugin.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,11 +353,17 @@ async def setup():
353353
fixturedef.func = _async_fixture_wrapper
354354

355355

356+
class AsyncFunction(pytest.Function):
357+
"""Pytest item that is a coroutine or an asynchronous generator"""
358+
359+
356360
_HOLDER: Set[FixtureDef] = set()
357361

358362

359-
@pytest.hookimpl(tryfirst=True)
360-
def pytest_pycollect_makeitem(
363+
# The function name needs to start with "pytest_"
364+
# see https://github.com/pytest-dev/pytest/issues/11307
365+
@pytest.hookimpl(specname="pytest_pycollect_makeitem", tryfirst=True)
366+
def pytest_pycollect_makeitem_preprocess_async_fixtures(
361367
collector: Union[pytest.Module, pytest.Class], name: str, obj: object
362368
) -> Union[
363369
pytest.Item, pytest.Collector, List[Union[pytest.Item, pytest.Collector]], None
@@ -369,6 +375,50 @@ def pytest_pycollect_makeitem(
369375
return None
370376

371377

378+
# The function name needs to start with "pytest_"
379+
# see https://github.com/pytest-dev/pytest/issues/11307
380+
@pytest.hookimpl(specname="pytest_pycollect_makeitem", hookwrapper=True)
381+
def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
382+
collector: Union[pytest.Module, pytest.Class], name: str, obj: object
383+
) -> Union[
384+
pytest.Item, pytest.Collector, List[Union[pytest.Item, pytest.Collector]], None
385+
]:
386+
"""
387+
Converts coroutines and async generators collected as pytest.Functions
388+
to AsyncFunction items.
389+
"""
390+
hook_result = yield
391+
node_or_list_of_nodes = hook_result.get_result()
392+
if not node_or_list_of_nodes:
393+
return
394+
try:
395+
node_iterator = iter(node_or_list_of_nodes)
396+
except TypeError:
397+
# Treat single node as a single-element iterable
398+
node_iterator = iter((node_or_list_of_nodes,))
399+
async_functions = []
400+
for collector_or_item in node_iterator:
401+
if not (
402+
isinstance(collector_or_item, pytest.Function)
403+
and _is_coroutine_or_asyncgen(obj)
404+
):
405+
collector = collector_or_item
406+
async_functions.append(collector)
407+
continue
408+
item = collector_or_item
409+
async_function = AsyncFunction.from_parent(
410+
item.parent,
411+
name=item.name,
412+
callspec=getattr(item, "callspec", None),
413+
callobj=item.obj,
414+
fixtureinfo=item._fixtureinfo,
415+
keywords=item.keywords,
416+
originalname=item.originalname,
417+
)
418+
async_functions.append(async_function)
419+
hook_result.force_result(async_functions)
420+
421+
372422
_event_loop_fixture_id = StashKey[str]
373423

374424

0 commit comments

Comments
 (0)