Skip to content

Commit 9643e2f

Browse files
committed
[refactor] Added static factory method to AsyncFunction which instantiates an AsyncFunction object from a pytest.Function.
Signed-off-by: Michael Seifert <[email protected]>
1 parent 6d06226 commit 9643e2f

File tree

1 file changed

+24
-21
lines changed

1 file changed

+24
-21
lines changed

pytest_asyncio/plugin.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
Session,
3939
StashKey,
4040
)
41+
from typing_extensions import Self
4142

4243
_R = TypeVar("_R")
4344

@@ -356,6 +357,21 @@ async def setup():
356357
class AsyncFunction(pytest.Function):
357358
"""Pytest item that is a coroutine or an asynchronous generator"""
358359

360+
@classmethod
361+
def from_function(cls, function: pytest.Function, /) -> Self:
362+
"""
363+
Instantiates an AsyncFunction from the specified pytest.Function item.
364+
"""
365+
return cls.from_parent(
366+
function.parent,
367+
name=function.name,
368+
callspec=getattr(function, "callspec", None),
369+
callobj=function.obj,
370+
fixtureinfo=function._fixtureinfo,
371+
keywords=function.keywords,
372+
originalname=function.originalname,
373+
)
374+
359375

360376
_HOLDER: Set[FixtureDef] = set()
361377

@@ -396,27 +412,14 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
396412
except TypeError:
397413
# Treat single node as a single-element iterable
398414
node_iterator = iter((node_or_list_of_nodes,))
399-
async_functions = []
400-
for collector_or_item in node_iterator:
401-
if not (
402-
isinstance(collector_or_item, pytest.Function)
403-
and _is_coroutine_or_asyncgen(obj)
404-
):
405-
collector = collector_or_item
406-
async_functions.append(collector)
407-
continue
408-
item = collector_or_item
409-
async_function = AsyncFunction.from_parent(
410-
item.parent,
411-
name=item.name,
412-
callspec=getattr(item, "callspec", None),
413-
callobj=item.obj,
414-
fixtureinfo=item._fixtureinfo,
415-
keywords=item.keywords,
416-
originalname=item.originalname,
417-
)
418-
async_functions.append(async_function)
419-
hook_result.force_result(async_functions)
415+
updated_node_collection = []
416+
for node in node_iterator:
417+
if isinstance(node, pytest.Function) and _is_coroutine_or_asyncgen(obj):
418+
async_function = AsyncFunction.from_function(node)
419+
updated_node_collection.append(async_function)
420+
else:
421+
updated_node_collection.append(node)
422+
hook_result.force_result(updated_node_collection)
420423

421424

422425
_event_loop_fixture_id = StashKey[str]

0 commit comments

Comments
 (0)