Skip to content

Commit cd1144c

Browse files
committed
inject at the runner instead however there was a side-effect so I made a comment explaining it.
1 parent b68073a commit cd1144c

File tree

2 files changed

+58
-37
lines changed

2 files changed

+58
-37
lines changed

pytest_asyncio/plugin.py

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def fixture(
116116
*,
117117
scope: _ScopeName | Callable[[str, Config], _ScopeName] = ...,
118118
loop_scope: _ScopeName | None = ...,
119-
loop_factory: _ScopeName | Callable[[], AbstractEventLoop] = ...,
119+
loop_factory: Callable[[], AbstractEventLoop] | None = ...,
120120
params: Iterable[object] | None = ...,
121121
autouse: bool = ...,
122122
ids: (
@@ -134,7 +134,7 @@ def fixture(
134134
*,
135135
scope: _ScopeName | Callable[[str, Config], _ScopeName] = ...,
136136
loop_scope: _ScopeName | None = ...,
137-
loop_factory: _ScopeName | Callable[[], AbstractEventLoop] = ...,
137+
loop_factory: Callable[[], AbstractEventLoop] | None = ...,
138138
params: Iterable[object] | None = ...,
139139
autouse: bool = ...,
140140
ids: (
@@ -149,7 +149,7 @@ def fixture(
149149
def fixture(
150150
fixture_function: FixtureFunction[_P, _R] | None = None,
151151
loop_scope: _ScopeName | None = None,
152-
loop_factory: _ScopeName | Callable[[], AbstractEventLoop] = ...,
152+
loop_factory: Callable[[], AbstractEventLoop] | None = None,
153153
**kwargs: Any,
154154
) -> (
155155
FixtureFunction[_P, _R]
@@ -179,7 +179,9 @@ def _is_asyncio_fixture_function(obj: Any) -> bool:
179179

180180

181181
def _make_asyncio_fixture_function(
182-
obj: Any, loop_scope: _ScopeName | None, loop_factory: _ScopeName | None
182+
obj: Any,
183+
loop_scope: _ScopeName | None,
184+
loop_factory: Callable[[], AbstractEventLoop] | None,
183185
) -> None:
184186
if hasattr(obj, "__func__"):
185187
# instance method, check the function object
@@ -248,14 +250,13 @@ def _fixture_synchronizer(
248250
fixturedef: FixtureDef,
249251
runner: Runner,
250252
request: FixtureRequest,
251-
loop_factory: Callable[[], AbstractEventLoop],
252253
) -> Callable:
253254
"""Returns a synchronous function evaluating the specified fixture."""
254255
fixture_function = resolve_fixture_function(fixturedef, request)
255256
if inspect.isasyncgenfunction(fixturedef.func):
256-
return _wrap_asyncgen_fixture(fixture_function, runner, request, loop_factory) # type: ignore[arg-type]
257+
return _wrap_asyncgen_fixture(fixture_function, runner, request) # type: ignore[arg-type]
257258
elif inspect.iscoroutinefunction(fixturedef.func):
258-
return _wrap_async_fixture(fixture_function, runner, request, loop_factory) # type: ignore[arg-type]
259+
return _wrap_async_fixture(fixture_function, runner, request) # type: ignore[arg-type]
259260
else:
260261
return fixturedef.func
261262

@@ -270,7 +271,6 @@ def _wrap_asyncgen_fixture(
270271
],
271272
runner: Runner,
272273
request: FixtureRequest,
273-
loop_factory: Callable[[], AbstractEventLoop],
274274
) -> Callable[AsyncGenFixtureParams, AsyncGenFixtureYieldType]:
275275
@functools.wraps(fixture_function)
276276
def _asyncgen_fixture_wrapper(
@@ -301,10 +301,6 @@ async def async_finalizer() -> None:
301301
msg += "Yield only once."
302302
raise ValueError(msg)
303303

304-
if loop_factory:
305-
_loop = loop_factory()
306-
asyncio.set_event_loop(_loop)
307-
308304
runner.run(async_finalizer(), context=context)
309305
if reset_contextvars is not None:
310306
reset_contextvars()
@@ -325,9 +321,7 @@ def _wrap_async_fixture(
325321
],
326322
runner: Runner,
327323
request: FixtureRequest,
328-
loop_factory: Callable[[], AbstractEventLoop] | None = None,
329324
) -> Callable[AsyncFixtureParams, AsyncFixtureReturnType]:
330-
331325
@functools.wraps(fixture_function) # type: ignore[arg-type]
332326
def _async_fixture_wrapper(
333327
*args: AsyncFixtureParams.args,
@@ -339,10 +333,6 @@ async def setup():
339333

340334
context = contextvars.copy_context()
341335

342-
# ensure loop_factory gets ran before we start running...
343-
if loop_factory:
344-
asyncio.set_event_loop(loop_factory())
345-
346336
result = runner.run(setup(), context=context)
347337
# Copy the context vars modified by the setup task into the current
348338
# context, and (if needed) add a finalizer to reset them.
@@ -445,9 +435,7 @@ def _can_substitute(item: Function) -> bool:
445435

446436
def runtest(self) -> None:
447437
# print(self.obj.pytestmark[0].__dict__)
448-
synchronized_obj = wrap_in_sync(
449-
self.obj, self.obj.pytestmark[0].kwargs.get("loop_factory", None)
450-
)
438+
synchronized_obj = wrap_in_sync(self.obj)
451439
with MonkeyPatch.context() as c:
452440
c.setattr(self, "obj", synchronized_obj)
453441
super().runtest()
@@ -559,16 +547,32 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
559547

560548

561549
@contextlib.contextmanager
562-
def _temporary_event_loop_policy(policy: AbstractEventLoopPolicy) -> Iterator[None]:
550+
def _temporary_event_loop_policy(
551+
policy: AbstractEventLoopPolicy,
552+
loop_facotry: Callable[..., AbstractEventLoop] | None,
553+
) -> Iterator[None]:
554+
563555
old_loop_policy = _get_event_loop_policy()
564556
try:
565557
old_loop = _get_event_loop_no_warn()
566558
except RuntimeError:
567559
old_loop = None
560+
# XXX: For some reason this function can override runner's
561+
# _loop_factory (At least observed on backported versions of Runner)
562+
# so we need to re-override if existing...
563+
if loop_facotry:
564+
_loop = loop_facotry()
565+
_set_event_loop(_loop)
566+
else:
567+
_loop = None
568+
568569
_set_event_loop_policy(policy)
569570
try:
570571
yield
571572
finally:
573+
if _loop:
574+
# Do not let BaseEventLoop.__del__ complain!
575+
_loop.close()
572576
_set_event_loop_policy(old_loop_policy)
573577
_set_event_loop(old_loop)
574578

@@ -654,7 +658,6 @@ def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
654658

655659
def wrap_in_sync(
656660
func: Callable[..., Awaitable[Any]],
657-
loop_factory: Callable[[], AbstractEventLoop] | None = None,
658661
):
659662
"""
660663
Return a sync wrapper around an async function executing it in the
@@ -663,26 +666,18 @@ def wrap_in_sync(
663666

664667
@functools.wraps(func)
665668
def inner(*args, **kwargs):
666-
_last_loop = asyncio.get_event_loop()
667-
if loop_factory:
668-
_loop = loop_factory()
669-
asyncio.set_event_loop(_loop)
670-
else:
671-
_loop = asyncio.get_event_loop()
669+
_loop = asyncio.get_event_loop()
672670
task = asyncio.ensure_future(func(*args, **kwargs), loop=_loop)
673671
try:
674672
_loop.run_until_complete(task)
675673
except BaseException:
676-
677674
# run_until_complete doesn't get the result from exceptions
678675
# that are not subclasses of `Exception`. Consume all
679676
# exceptions to prevent asyncio's warning from logging.
680677
if task.done() and not task.cancelled():
681678
task.exception()
682679
raise
683680

684-
asyncio.set_event_loop(_last_loop)
685-
686681
return inner
687682

688683

@@ -726,7 +721,7 @@ def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None:
726721

727722
runner_fixture_id = f"_{loop_scope}_scoped_runner"
728723
runner = request.getfixturevalue(runner_fixture_id)
729-
synchronizer = _fixture_synchronizer(fixturedef, runner, request, loop_factory)
724+
synchronizer = _fixture_synchronizer(fixturedef, runner, request)
730725
_make_asyncio_fixture_function(synchronizer, loop_scope, loop_factory)
731726
with MonkeyPatch.context() as c:
732727
c.setattr(fixturedef, "func", synchronizer)
@@ -754,7 +749,8 @@ def _get_marked_loop_scope(
754749
and set(asyncio_marker.kwargs) - {"loop_scope", "scope", "loop_factory"}
755750
):
756751
raise ValueError(
757-
"mark.asyncio accepts only a keyword arguments 'loop_scope' or 'loop_factory'"
752+
"mark.asyncio accepts only a keyword arguments 'loop_scope'"
753+
" or 'loop_factory'"
758754
)
759755
if "scope" in asyncio_marker.kwargs:
760756
if "loop_scope" in asyncio_marker.kwargs:
@@ -784,17 +780,32 @@ def _get_default_test_loop_scope(config: Config) -> _ScopeName:
784780
"""
785781

786782

783+
def _get_loop_facotry(
784+
request: FixtureRequest,
785+
) -> Callable[[], AbstractEventLoop] | None:
786+
if asyncio_mark := request._pyfuncitem.get_closest_marker("asyncio"):
787+
factory = asyncio_mark.kwargs.get("loop_factory", None)
788+
print(f"FACTORY {factory}")
789+
return factory
790+
else:
791+
return request.obj.__dict__.get("_loop_factory", None) # type: ignore[attr-defined]
792+
793+
787794
def _create_scoped_runner_fixture(scope: _ScopeName) -> Callable:
788795
@pytest.fixture(
789796
scope=scope,
790797
name=f"_{scope}_scoped_runner",
791798
)
792799
def _scoped_runner(
793-
event_loop_policy,
800+
event_loop_policy: AbstractEventLoopPolicy, request: FixtureRequest
794801
) -> Iterator[Runner]:
795802
new_loop_policy = event_loop_policy
796-
with _temporary_event_loop_policy(new_loop_policy):
797-
runner = Runner().__enter__()
803+
804+
# We need to get the factory now because
805+
# _temporary_event_loop_policy can override the Runner
806+
factory = _get_loop_facotry(request)
807+
with _temporary_event_loop_policy(new_loop_policy, factory):
808+
runner = Runner(loop_factory=factory).__enter__()
798809
try:
799810
yield runner
800811
except Exception as e:

tests/test_asyncio_mark.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,16 @@ class CustomEventLoop(asyncio.SelectorEventLoop):
249249
@pytest.mark.asyncio(loop_factory=CustomEventLoop)
250250
async def test_has_different_event_loop():
251251
assert type(asyncio.get_running_loop()).__name__ == "CustomEventLoop"
252+
253+
@pytest_asyncio.fixture(loop_factory=CustomEventLoop)
254+
async def custom_fixture():
255+
yield asyncio.get_running_loop()
256+
257+
async def test_with_fixture(custom_fixture):
258+
# Both of these should be the same...
259+
type(asyncio.get_running_loop()).__name__ == "CustomEventLoop"
260+
type(custom_fixture).__name__ == "CustomEventLoop"
261+
252262
"""
253263
)
254264
)

0 commit comments

Comments
 (0)