Skip to content

Commit e60c10e

Browse files
committed
refactor: Deduplicate runtest logic for PytestAsyncioFunction subclasses.
1 parent 24d0f51 commit e60c10e

File tree

1 file changed

+27
-36
lines changed

1 file changed

+27
-36
lines changed

pytest_asyncio/plugin.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -437,15 +437,6 @@ def _can_substitute(item: Function) -> bool:
437437
"""Returns whether the specified function can be replaced by this class"""
438438
raise NotImplementedError()
439439

440-
441-
class Coroutine(PytestAsyncioFunction):
442-
"""Pytest item created by a coroutine"""
443-
444-
@staticmethod
445-
def _can_substitute(item: Function) -> bool:
446-
func = item.obj
447-
return inspect.iscoroutinefunction(func)
448-
449440
def runtest(self) -> None:
450441
marker = self.get_closest_marker("asyncio")
451442
assert marker is not None
@@ -454,11 +445,33 @@ def runtest(self) -> None:
454445
runner_fixture_id = f"_{loop_scope}_scoped_runner"
455446
runner = self._request.getfixturevalue(runner_fixture_id)
456447
context = contextvars.copy_context()
457-
synchronized_obj = wrap_in_sync(self.obj, runner, context)
448+
synchronized_obj = wrap_in_sync(
449+
getattr(*self._synchronization_target_attr), runner, context
450+
)
458451
with MonkeyPatch.context() as c:
459-
c.setattr(self, "obj", synchronized_obj)
452+
c.setattr(*self._synchronization_target_attr, synchronized_obj)
460453
super().runtest()
461454

455+
@property
456+
def _synchronization_target_attr(self) -> tuple[object, str]:
457+
"""
458+
Return the coroutine that needs to be synchronized during the test run.
459+
460+
This method is inteded to be overwritten by subclasses when they need to apply
461+
the coroutine synchronizer to a value that's different from self.obj
462+
e.g. the AsyncHypothesisTest subclass.
463+
"""
464+
return self, "obj"
465+
466+
467+
class Coroutine(PytestAsyncioFunction):
468+
"""Pytest item created by a coroutine"""
469+
470+
@staticmethod
471+
def _can_substitute(item: Function) -> bool:
472+
func = item.obj
473+
return inspect.iscoroutinefunction(func)
474+
462475

463476
class AsyncGenerator(PytestAsyncioFunction):
464477
"""Pytest item created by an asynchronous generator"""
@@ -495,19 +508,6 @@ def _can_substitute(item: Function) -> bool:
495508
func.__func__
496509
)
497510

498-
def runtest(self) -> None:
499-
marker = self.get_closest_marker("asyncio")
500-
assert marker is not None
501-
default_loop_scope = _get_default_test_loop_scope(self.config)
502-
loop_scope = _get_marked_loop_scope(marker, default_loop_scope)
503-
runner_fixture_id = f"_{loop_scope}_scoped_runner"
504-
runner = self._request.getfixturevalue(runner_fixture_id)
505-
context = contextvars.copy_context()
506-
synchronized_obj = wrap_in_sync(self.obj, runner, context=context)
507-
with MonkeyPatch.context() as c:
508-
c.setattr(self, "obj", synchronized_obj)
509-
super().runtest()
510-
511511

512512
class AsyncHypothesisTest(PytestAsyncioFunction):
513513
"""
@@ -524,18 +524,9 @@ def _can_substitute(item: Function) -> bool:
524524
and inspect.iscoroutinefunction(func.hypothesis.inner_test)
525525
)
526526

527-
def runtest(self) -> None:
528-
marker = self.get_closest_marker("asyncio")
529-
assert marker is not None
530-
default_loop_scope = _get_default_test_loop_scope(self.config)
531-
loop_scope = _get_marked_loop_scope(marker, default_loop_scope)
532-
runner_fixture_id = f"_{loop_scope}_scoped_runner"
533-
runner = self._request.getfixturevalue(runner_fixture_id)
534-
context = contextvars.copy_context()
535-
synchronized_obj = wrap_in_sync(self.obj.hypothesis.inner_test, runner, context)
536-
with MonkeyPatch.context() as c:
537-
c.setattr(self.obj.hypothesis, "inner_test", synchronized_obj)
538-
super().runtest()
527+
@property
528+
def _synchronization_target_attr(self) -> tuple[object, str]:
529+
return self.obj.hypothesis, "inner_test"
539530

540531

541532
# The function name needs to start with "pytest_"

0 commit comments

Comments
 (0)