Skip to content

Commit ee10388

Browse files
committed
[refactor] Added a factory method to PytestAsyncioFunction.
This avoids direct references to the subclasses of PytestAsyncioFunction, thus making the code more easily extendable with additional subclasses. Signed-off-by: Michael Seifert <[email protected]>
1 parent b8fe7ab commit ee10388

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

pytest_asyncio/plugin.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,21 @@ class PytestAsyncioFunction(pytest.Function):
357357
"""Base class for all test functions managed by pytest-asyncio."""
358358

359359
@classmethod
360-
def from_function(cls, function: pytest.Function, /) -> Self:
360+
def substitute(cls, item: pytest.Function, /) -> pytest.Function:
361+
"""
362+
Returns a PytestAsyncioFunction if there is an implementation that can handle
363+
the specified function item.
364+
365+
If no implementation of PytestAsyncioFunction can handle the specified item,
366+
the item is returned unchanged.
367+
"""
368+
for subclass in cls.__subclasses__():
369+
if subclass._can_substitute(item):
370+
return subclass._from_function(item)
371+
return item
372+
373+
@classmethod
374+
def _from_function(cls, function: pytest.Function, /) -> Self:
361375
"""
362376
Instantiates this specific PytestAsyncioFunction type from the specified
363377
pytest.Function item.
@@ -373,7 +387,7 @@ def from_function(cls, function: pytest.Function, /) -> Self:
373387
)
374388

375389
@staticmethod
376-
def can_substitute(item: pytest.Function) -> bool:
390+
def _can_substitute(item: pytest.Function) -> bool:
377391
"""Returns whether the specified function can be replaced by this class"""
378392
raise NotImplementedError()
379393

@@ -382,7 +396,7 @@ class AsyncFunction(PytestAsyncioFunction):
382396
"""Pytest item that is a coroutine or an asynchronous generator"""
383397

384398
@staticmethod
385-
def can_substitute(item: pytest.Function) -> bool:
399+
def _can_substitute(item: pytest.Function) -> bool:
386400
func = item.obj
387401
return _is_coroutine_or_asyncgen(func)
388402

@@ -402,7 +416,7 @@ class AsyncStaticMethod(PytestAsyncioFunction):
402416
"""
403417

404418
@staticmethod
405-
def can_substitute(item: pytest.Function) -> bool:
419+
def _can_substitute(item: pytest.Function) -> bool:
406420
func = item.obj
407421
return isinstance(func, staticmethod) and _is_coroutine_or_asyncgen(
408422
func.__func__
@@ -424,7 +438,7 @@ class AsyncHypothesisTest(PytestAsyncioFunction):
424438
"""
425439

426440
@staticmethod
427-
def can_substitute(item: pytest.Function) -> bool:
441+
def _can_substitute(item: pytest.Function) -> bool:
428442
func = item.obj
429443
return _is_hypothesis_test(func) and _hypothesis_test_wraps_coroutine(func)
430444

@@ -480,12 +494,7 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
480494
for node in node_iterator:
481495
updated_item = node
482496
if isinstance(node, pytest.Function):
483-
if AsyncStaticMethod.can_substitute(node):
484-
updated_item = AsyncStaticMethod.from_function(node)
485-
if AsyncFunction.can_substitute(node):
486-
updated_item = AsyncFunction.from_function(node)
487-
if AsyncHypothesisTest.can_substitute(node):
488-
updated_item = AsyncHypothesisTest.from_function(node)
497+
updated_item = PytestAsyncioFunction.substitute(node)
489498
updated_node_collection.append(updated_item)
490499

491500
hook_result.force_result(updated_node_collection)

0 commit comments

Comments
 (0)