Skip to content

Commit 2b541f5

Browse files
committed
[refactor] Moved logic to check whether a subclass of pytest.Function can substitute the Function item into the respective subclasses.
This change allows all subclasses of function items to be treated the same when modifying pytest items. Signed-off-by: Michael Seifert <[email protected]>
1 parent d5d4960 commit 2b541f5

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

pytest_asyncio/plugin.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,12 @@ async def setup():
356356
class AsyncFunction(pytest.Function):
357357
"""Pytest item that is a coroutine or an asynchronous generator"""
358358

359+
@staticmethod
360+
def can_substitute(item: pytest.Function) -> bool:
361+
"""Returns whether the specified function can be replaced by this class"""
362+
func = item.obj
363+
return _is_coroutine_or_asyncgen(func)
364+
359365
@classmethod
360366
def from_function(cls, function: pytest.Function, /) -> Self:
361367
"""
@@ -386,6 +392,14 @@ class AsyncStaticMethod(pytest.Function):
386392
decorated with staticmethod
387393
"""
388394

395+
@staticmethod
396+
def can_substitute(item: pytest.Function) -> bool:
397+
"""Returns whether the specified function can be replaced by this class"""
398+
func = item.obj
399+
return isinstance(func, staticmethod) and _is_coroutine_or_asyncgen(
400+
func.__func__
401+
)
402+
389403
@classmethod
390404
def from_function(cls, function: pytest.Function, /) -> Self:
391405
"""
@@ -416,6 +430,12 @@ class AsyncHypothesisTest(pytest.Function):
416430
@hypothesis.given.
417431
"""
418432

433+
@staticmethod
434+
def can_substitute(item: pytest.Function) -> bool:
435+
"""Returns whether the specified function can be replaced by this class"""
436+
func = item.obj
437+
return _is_hypothesis_test(func) and _hypothesis_test_wraps_coroutine(func)
438+
419439
@classmethod
420440
def from_function(cls, function: pytest.Function, /) -> Self:
421441
"""
@@ -483,13 +503,11 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
483503
for node in node_iterator:
484504
updated_item = node
485505
if isinstance(node, pytest.Function):
486-
if isinstance(obj, staticmethod) and _is_coroutine_or_asyncgen(
487-
obj.__func__
488-
):
506+
if AsyncStaticMethod.can_substitute(node):
489507
updated_item = AsyncStaticMethod.from_function(node)
490-
if _is_coroutine_or_asyncgen(obj):
508+
if AsyncFunction.can_substitute(node):
491509
updated_item = AsyncFunction.from_function(node)
492-
if _is_hypothesis_test(obj) and _hypothesis_test_wraps_coroutine(obj):
510+
if AsyncHypothesisTest.can_substitute(node):
493511
updated_item = AsyncHypothesisTest.from_function(node)
494512
updated_node_collection.append(updated_item)
495513

0 commit comments

Comments
 (0)