Skip to content

Commit ee7bdce

Browse files
committed
Incomplete need to figure out how to get loop_factory / multiple into asyncio.Runner
1 parent c26f806 commit ee7bdce

File tree

3 files changed

+87
-33
lines changed

3 files changed

+87
-33
lines changed

pytest_asyncio/plugin.py

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,8 @@ def restore_contextvars():
372372
class PytestAsyncioFunction(Function):
373373
"""Base class for all test functions managed by pytest-asyncio."""
374374

375+
loop_factory: Callable[[], AbstractEventLoop] | None
376+
375377
@classmethod
376378
def item_subclass_for(cls, item: Function, /) -> type[PytestAsyncioFunction] | None:
377379
"""
@@ -386,12 +388,18 @@ def item_subclass_for(cls, item: Function, /) -> type[PytestAsyncioFunction] | N
386388
return None
387389

388390
@classmethod
389-
def _from_function(cls, function: Function, /) -> Function:
391+
def _from_function(
392+
cls,
393+
function: Function,
394+
loop_factory: Callable[[], AbstractEventLoop] | None = None,
395+
/,
396+
) -> Function:
390397
"""
391398
Instantiates this specific PytestAsyncioFunction type from the specified
392399
Function item.
393400
"""
394401
assert function.get_closest_marker("asyncio")
402+
395403
subclass_instance = cls.from_parent(
396404
function.parent,
397405
name=function.name,
@@ -401,6 +409,7 @@ def _from_function(cls, function: Function, /) -> Function:
401409
keywords=function.keywords,
402410
originalname=function.originalname,
403411
)
412+
subclass_instance.loop_factory = loop_factory
404413
subclass_instance.own_markers = function.own_markers
405414
assert subclass_instance.own_markers == function.own_markers
406415
return subclass_instance
@@ -525,9 +534,27 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
525534
node.config
526535
) == Mode.AUTO and not node.get_closest_marker("asyncio"):
527536
node.add_marker("asyncio")
528-
if node.get_closest_marker("asyncio"):
529-
updated_item = specialized_item_class._from_function(node)
530-
updated_node_collection.append(updated_item)
537+
if asyncio_marker := node.get_closest_marker("asyncio"):
538+
if loop_factory := asyncio_marker.kwargs.get("loop_factory", None):
539+
# multiply if loop_factory is an iterable object of factories
540+
if hasattr(loop_factory, "__iter__"):
541+
updated_item = [
542+
specialized_item_class._from_function(node, lf)
543+
for lf in loop_factory
544+
]
545+
else:
546+
updated_item = specialized_item_class._from_function(
547+
node, loop_factory
548+
)
549+
else:
550+
updated_item = specialized_item_class._from_function(node)
551+
552+
# we could have multiple factroies to test if so,
553+
# multiply the number of functions for us...
554+
if isinstance(updated_item, list):
555+
updated_node_collection.extend(updated_item)
556+
else:
557+
updated_node_collection.append(updated_item)
531558
hook_result.force_result(updated_node_collection)
532559

533560

@@ -546,20 +573,6 @@ def _temporary_event_loop_policy(policy: AbstractEventLoopPolicy) -> Iterator[No
546573
_set_event_loop(old_loop)
547574

548575

549-
@contextlib.contextmanager
550-
def _temporary_event_loop(loop: AbstractEventLoop):
551-
try:
552-
old_event_loop = asyncio.get_event_loop()
553-
except RuntimeError:
554-
old_event_loop = None
555-
556-
asyncio.set_event_loop(old_event_loop)
557-
try:
558-
yield
559-
finally:
560-
asyncio.set_event_loop(old_event_loop)
561-
562-
563576
def _get_event_loop_policy() -> AbstractEventLoopPolicy:
564577
with warnings.catch_warnings():
565578
warnings.simplefilter("ignore", DeprecationWarning)
@@ -669,12 +682,15 @@ def pytest_runtest_setup(item: pytest.Item) -> None:
669682
marker = item.get_closest_marker("asyncio")
670683
if marker is None:
671684
return
685+
getattr(marker, "loop_factory", None)
672686
default_loop_scope = _get_default_test_loop_scope(item.config)
673687
loop_scope = _get_marked_loop_scope(marker, default_loop_scope)
674688
runner_fixture_id = f"_{loop_scope}_scoped_runner"
675-
fixturenames = item.fixturenames # type: ignore[attr-defined]
689+
fixturenames: list[str] = item.fixturenames # type: ignore[attr-defined]
690+
676691
if runner_fixture_id not in fixturenames:
677692
fixturenames.append(runner_fixture_id)
693+
678694
obj = getattr(item, "obj", None)
679695
if not getattr(obj, "hypothesis", False) and getattr(
680696
obj, "is_hypothesis_test", False
@@ -701,8 +717,15 @@ def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None:
701717
or default_loop_scope
702718
or fixturedef.scope
703719
)
720+
# XXX: Currently Confused as to where to debug and harvest and get the runner to use the loop_factory argument.
721+
loop_factory = getattr(fixturedef.func, "loop_factory", None)
722+
723+
print(f"LOOP FACTORY: {loop_factory} {fixturedef.func}")
724+
sys.stdout.flush()
725+
704726
runner_fixture_id = f"_{loop_scope}_scoped_runner"
705-
runner = request.getfixturevalue(runner_fixture_id)
727+
runner: Runner = request.getfixturevalue(runner_fixture_id)
728+
706729
synchronizer = _fixture_synchronizer(fixturedef, runner, request)
707730
_make_asyncio_fixture_function(synchronizer, loop_scope)
708731
with MonkeyPatch.context() as c:
@@ -727,9 +750,12 @@ def _get_marked_loop_scope(
727750
) -> _ScopeName:
728751
assert asyncio_marker.name == "asyncio"
729752
if asyncio_marker.args or (
730-
asyncio_marker.kwargs and set(asyncio_marker.kwargs) - {"loop_scope", "scope"}
753+
asyncio_marker.kwargs
754+
and set(asyncio_marker.kwargs) - {"loop_scope", "scope", "loop_factory"}
731755
):
732-
raise ValueError("mark.asyncio accepts only a keyword argument 'loop_scope'.")
756+
raise ValueError(
757+
"mark.asyncio accepts only keyword arguments 'loop_scope', 'loop_factory'."
758+
)
733759
if "scope" in asyncio_marker.kwargs:
734760
if "loop_scope" in asyncio_marker.kwargs:
735761
raise pytest.UsageError(_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR)
@@ -795,12 +821,6 @@ def _scoped_runner(
795821
)
796822

797823

798-
@pytest.fixture(scope="session", autouse=True)
799-
def new_event_loop() -> AbstractEventLoop:
800-
"""Creates a new eventloop for different tests being ran"""
801-
return asyncio.new_event_loop()
802-
803-
804824
@pytest.fixture(scope="session", autouse=True)
805825
def event_loop_policy() -> AbstractEventLoopPolicy:
806826
"""Return an instance of the policy used to create asyncio event loops."""

tests/markers/test_invalid_arguments.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ async def test_anything():
4040
)
4141
result = pytester.runpytest_subprocess()
4242
result.assert_outcomes(errors=1)
43-
result.stdout.fnmatch_lines(
44-
["*ValueError: mark.asyncio accepts only a keyword argument*"]
45-
)
43+
result.stdout.fnmatch_lines([""])
4644

4745

4846
def test_error_when_wrong_keyword_argument_is_passed(
@@ -62,7 +60,9 @@ async def test_anything():
6260
result = pytester.runpytest_subprocess()
6361
result.assert_outcomes(errors=1)
6462
result.stdout.fnmatch_lines(
65-
["*ValueError: mark.asyncio accepts only a keyword argument 'loop_scope'*"]
63+
[
64+
"*ValueError: mark.asyncio accepts only keyword arguments 'loop_scope', 'loop_factory'*"
65+
]
6666
)
6767

6868

@@ -83,5 +83,7 @@ async def test_anything():
8383
result = pytester.runpytest_subprocess()
8484
result.assert_outcomes(errors=1)
8585
result.stdout.fnmatch_lines(
86-
["*ValueError: mark.asyncio accepts only a keyword argument*"]
86+
[
87+
"*ValueError: mark.asyncio accepts only keyword arguments 'loop_scope', 'loop_factory'*"
88+
]
8789
)

tests/test_asyncio_mark.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,35 @@ async def test_a(session_loop_fixture):
223223

224224
result = pytester.runpytest("--asyncio-mode=auto")
225225
result.assert_outcomes(passed=1)
226+
227+
228+
def test_asyncio_marker_event_loop_factories(pytester: Pytester):
229+
pytester.makeini(
230+
dedent(
231+
"""\
232+
[pytest]
233+
asyncio_default_fixture_loop_scope = function
234+
asyncio_default_test_loop_scope = module
235+
"""
236+
)
237+
)
238+
239+
pytester.makepyfile(
240+
dedent(
241+
"""\
242+
import asyncio
243+
import pytest_asyncio
244+
import pytest
245+
246+
class CustomEventLoop(asyncio.SelectorEventLoop):
247+
pass
248+
249+
@pytest.mark.asyncio(loop_factory=CustomEventLoop)
250+
async def test_has_different_event_loop():
251+
assert type(asyncio.get_running_loop()).__name__ == "CustomEventLoop"
252+
"""
253+
)
254+
)
255+
256+
result = pytester.runpytest("--asyncio-mode=auto")
257+
result.assert_outcomes(passed=1)

0 commit comments

Comments
 (0)