Skip to content

Commit b8fe7ab

Browse files
committed
[refactor] Extracted a base class for all types of pytest-asyncio function items.
Signed-off-by: Michael Seifert <[email protected]>
1 parent 2b541f5 commit b8fe7ab

File tree

1 file changed

+22
-47
lines changed

1 file changed

+22
-47
lines changed

pytest_asyncio/plugin.py

Lines changed: 22 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -353,19 +353,14 @@ async def setup():
353353
fixturedef.func = _async_fixture_wrapper
354354

355355

356-
class AsyncFunction(pytest.Function):
357-
"""Pytest item that is a coroutine or an asynchronous generator"""
358-
359-
@staticmethod
360-
def can_substitute(item: pytest.Function) -> bool:
361-
"""Returns whether the specified function can be replaced by this class"""
362-
func = item.obj
363-
return _is_coroutine_or_asyncgen(func)
356+
class PytestAsyncioFunction(pytest.Function):
357+
"""Base class for all test functions managed by pytest-asyncio."""
364358

365359
@classmethod
366360
def from_function(cls, function: pytest.Function, /) -> Self:
367361
"""
368-
Instantiates an AsyncFunction from the specified pytest.Function item.
362+
Instantiates this specific PytestAsyncioFunction type from the specified
363+
pytest.Function item.
369364
"""
370365
return cls.from_parent(
371366
function.parent,
@@ -377,6 +372,20 @@ def from_function(cls, function: pytest.Function, /) -> Self:
377372
originalname=function.originalname,
378373
)
379374

375+
@staticmethod
376+
def can_substitute(item: pytest.Function) -> bool:
377+
"""Returns whether the specified function can be replaced by this class"""
378+
raise NotImplementedError()
379+
380+
381+
class AsyncFunction(PytestAsyncioFunction):
382+
"""Pytest item that is a coroutine or an asynchronous generator"""
383+
384+
@staticmethod
385+
def can_substitute(item: pytest.Function) -> bool:
386+
func = item.obj
387+
return _is_coroutine_or_asyncgen(func)
388+
380389
def runtest(self) -> None:
381390
if self.get_closest_marker("asyncio"):
382391
self.obj = wrap_in_sync(
@@ -386,35 +395,19 @@ def runtest(self) -> None:
386395
super().runtest()
387396

388397

389-
class AsyncStaticMethod(pytest.Function):
398+
class AsyncStaticMethod(PytestAsyncioFunction):
390399
"""
391400
Pytest item that is a coroutine or an asynchronous generator
392401
decorated with staticmethod
393402
"""
394403

395404
@staticmethod
396405
def can_substitute(item: pytest.Function) -> bool:
397-
"""Returns whether the specified function can be replaced by this class"""
398406
func = item.obj
399407
return isinstance(func, staticmethod) and _is_coroutine_or_asyncgen(
400408
func.__func__
401409
)
402410

403-
@classmethod
404-
def from_function(cls, function: pytest.Function, /) -> Self:
405-
"""
406-
Instantiates an AsyncStaticMethod from the specified pytest.Function item.
407-
"""
408-
return cls.from_parent(
409-
function.parent,
410-
name=function.name,
411-
callspec=getattr(function, "callspec", None),
412-
callobj=function.obj,
413-
fixtureinfo=function._fixtureinfo,
414-
keywords=function.keywords,
415-
originalname=function.originalname,
416-
)
417-
418411
def runtest(self) -> None:
419412
if self.get_closest_marker("asyncio"):
420413
self.obj = wrap_in_sync(
@@ -424,33 +417,17 @@ def runtest(self) -> None:
424417
super().runtest()
425418

426419

427-
class AsyncHypothesisTest(pytest.Function):
420+
class AsyncHypothesisTest(PytestAsyncioFunction):
428421
"""
429422
Pytest item that is coroutine or an asynchronous generator decorated by
430423
@hypothesis.given.
431424
"""
432425

433426
@staticmethod
434427
def can_substitute(item: pytest.Function) -> bool:
435-
"""Returns whether the specified function can be replaced by this class"""
436428
func = item.obj
437429
return _is_hypothesis_test(func) and _hypothesis_test_wraps_coroutine(func)
438430

439-
@classmethod
440-
def from_function(cls, function: pytest.Function, /) -> Self:
441-
"""
442-
Instantiates an AsyncFunction from the specified pytest.Function item.
443-
"""
444-
return cls.from_parent(
445-
function.parent,
446-
name=function.name,
447-
callspec=getattr(function, "callspec", None),
448-
callobj=function.obj,
449-
fixtureinfo=function._fixtureinfo,
450-
keywords=function.keywords,
451-
originalname=function.originalname,
452-
)
453-
454431
def runtest(self) -> None:
455432
if self.get_closest_marker("asyncio"):
456433
self.obj.hypothesis.inner_test = wrap_in_sync(
@@ -589,7 +566,7 @@ def pytest_collection_modifyitems(
589566
if _get_asyncio_mode(config) != Mode.AUTO:
590567
return
591568
for item in items:
592-
if isinstance(item, (AsyncFunction, AsyncHypothesisTest, AsyncStaticMethod)):
569+
if isinstance(item, PytestAsyncioFunction):
593570
item.add_marker("asyncio")
594571

595572

@@ -750,9 +727,7 @@ def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> Optional[object]:
750727
"""
751728
marker = pyfuncitem.get_closest_marker("asyncio")
752729
if marker is not None:
753-
if isinstance(
754-
pyfuncitem, (AsyncFunction, AsyncHypothesisTest, AsyncStaticMethod)
755-
):
730+
if isinstance(pyfuncitem, PytestAsyncioFunction):
756731
pass
757732
else:
758733
pyfuncitem.warn(

0 commit comments

Comments
 (0)