Skip to content

Commit edfbfef

Browse files
committed
figured out loop_factory :)
1 parent ee7bdce commit edfbfef

File tree

1 file changed

+43
-60
lines changed

1 file changed

+43
-60
lines changed

pytest_asyncio/plugin.py

Lines changed: 43 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
PytestPluginManager,
5050
)
5151

52+
from typing import Callable
5253
if sys.version_info >= (3, 10):
5354
from typing import ParamSpec
5455
else:
@@ -116,6 +117,7 @@ def fixture(
116117
*,
117118
scope: _ScopeName | Callable[[str, Config], _ScopeName] = ...,
118119
loop_scope: _ScopeName | None = ...,
120+
loop_factory: _ScopeName | Callable[[], AbstractEventLoop] = ...,
119121
params: Iterable[object] | None = ...,
120122
autouse: bool = ...,
121123
ids: (
@@ -133,6 +135,7 @@ def fixture(
133135
*,
134136
scope: _ScopeName | Callable[[str, Config], _ScopeName] = ...,
135137
loop_scope: _ScopeName | None = ...,
138+
loop_factory: _ScopeName | Callable[[], AbstractEventLoop] = ...,
136139
params: Iterable[object] | None = ...,
137140
autouse: bool = ...,
138141
ids: (
@@ -147,20 +150,21 @@ def fixture(
147150
def fixture(
148151
fixture_function: FixtureFunction[_P, _R] | None = None,
149152
loop_scope: _ScopeName | None = None,
153+
loop_factory: _ScopeName | Callable[[], AbstractEventLoop] = ...,
150154
**kwargs: Any,
151155
) -> (
152156
FixtureFunction[_P, _R]
153157
| Callable[[FixtureFunction[_P, _R]], FixtureFunction[_P, _R]]
154158
):
155159
if fixture_function is not None:
156-
_make_asyncio_fixture_function(fixture_function, loop_scope)
160+
_make_asyncio_fixture_function(fixture_function, loop_scope, loop_factory)
157161
return pytest.fixture(fixture_function, **kwargs)
158162

159163
else:
160164

161165
@functools.wraps(fixture)
162166
def inner(fixture_function: FixtureFunction[_P, _R]) -> FixtureFunction[_P, _R]:
163-
return fixture(fixture_function, loop_scope=loop_scope, **kwargs)
167+
return fixture(fixture_function, loop_factory=loop_factory, loop_scope=loop_scope, **kwargs)
164168

165169
return inner
166170

@@ -170,12 +174,13 @@ def _is_asyncio_fixture_function(obj: Any) -> bool:
170174
return getattr(obj, "_force_asyncio_fixture", False)
171175

172176

173-
def _make_asyncio_fixture_function(obj: Any, loop_scope: _ScopeName | None) -> None:
177+
def _make_asyncio_fixture_function(obj: Any, loop_scope: _ScopeName | None, loop_factory: _ScopeName | None) -> None:
174178
if hasattr(obj, "__func__"):
175179
# instance method, check the function object
176180
obj = obj.__func__
177181
obj._force_asyncio_fixture = True
178182
obj._loop_scope = loop_scope
183+
obj._loop_factory = loop_factory
179184

180185

181186
def _is_coroutine_or_asyncgen(obj: Any) -> bool:
@@ -234,14 +239,14 @@ def pytest_report_header(config: Config) -> list[str]:
234239

235240

236241
def _fixture_synchronizer(
237-
fixturedef: FixtureDef, runner: Runner, request: FixtureRequest
242+
fixturedef: FixtureDef, runner: Runner, request: FixtureRequest, loop_factory: Callable[[], AbstractEventLoop]
238243
) -> Callable:
239244
"""Returns a synchronous function evaluating the specified fixture."""
240245
fixture_function = resolve_fixture_function(fixturedef, request)
241246
if inspect.isasyncgenfunction(fixturedef.func):
242-
return _wrap_asyncgen_fixture(fixture_function, runner, request) # type: ignore[arg-type]
247+
return _wrap_asyncgen_fixture(fixture_function, runner, request, loop_factory) # type: ignore[arg-type]
243248
elif inspect.iscoroutinefunction(fixturedef.func):
244-
return _wrap_async_fixture(fixture_function, runner, request) # type: ignore[arg-type]
249+
return _wrap_async_fixture(fixture_function, runner, request, loop_factory) # type: ignore[arg-type]
245250
else:
246251
return fixturedef.func
247252

@@ -256,6 +261,7 @@ def _wrap_asyncgen_fixture(
256261
],
257262
runner: Runner,
258263
request: FixtureRequest,
264+
loop_factory:Callable[[], AbstractEventLoop]
259265
) -> Callable[AsyncGenFixtureParams, AsyncGenFixtureYieldType]:
260266
@functools.wraps(fixture_function)
261267
def _asyncgen_fixture_wrapper(
@@ -285,6 +291,9 @@ async def async_finalizer() -> None:
285291
msg = "Async generator fixture didn't stop."
286292
msg += "Yield only once."
287293
raise ValueError(msg)
294+
if loop_factory:
295+
_loop = loop_factory()
296+
asyncio.set_event_loop(_loop)
288297

289298
runner.run(async_finalizer(), context=context)
290299
if reset_contextvars is not None:
@@ -306,6 +315,7 @@ def _wrap_async_fixture(
306315
],
307316
runner: Runner,
308317
request: FixtureRequest,
318+
loop_factory: Callable[[], AbstractEventLoop] | None = None
309319
) -> Callable[AsyncFixtureParams, AsyncFixtureReturnType]:
310320

311321
@functools.wraps(fixture_function) # type: ignore[arg-type]
@@ -318,8 +328,12 @@ async def setup():
318328
return res
319329

320330
context = contextvars.copy_context()
321-
result = runner.run(setup(), context=context)
322331

332+
# ensure loop_factory gets ran before we start running...
333+
if loop_factory:
334+
asyncio.set_event_loop(loop_factory())
335+
336+
result = runner.run(setup(), context=context)
323337
# Copy the context vars modified by the setup task into the current
324338
# context, and (if needed) add a finalizer to reset them.
325339
#
@@ -372,8 +386,6 @@ def restore_contextvars():
372386
class PytestAsyncioFunction(Function):
373387
"""Base class for all test functions managed by pytest-asyncio."""
374388

375-
loop_factory: Callable[[], AbstractEventLoop] | None
376-
377389
@classmethod
378390
def item_subclass_for(cls, item: Function, /) -> type[PytestAsyncioFunction] | None:
379391
"""
@@ -388,18 +400,12 @@ def item_subclass_for(cls, item: Function, /) -> type[PytestAsyncioFunction] | N
388400
return None
389401

390402
@classmethod
391-
def _from_function(
392-
cls,
393-
function: Function,
394-
loop_factory: Callable[[], AbstractEventLoop] | None = None,
395-
/,
396-
) -> Function:
403+
def _from_function(cls, function: Function, /) -> Function:
397404
"""
398405
Instantiates this specific PytestAsyncioFunction type from the specified
399406
Function item.
400407
"""
401408
assert function.get_closest_marker("asyncio")
402-
403409
subclass_instance = cls.from_parent(
404410
function.parent,
405411
name=function.name,
@@ -409,7 +415,6 @@ def _from_function(
409415
keywords=function.keywords,
410416
originalname=function.originalname,
411417
)
412-
subclass_instance.loop_factory = loop_factory
413418
subclass_instance.own_markers = function.own_markers
414419
assert subclass_instance.own_markers == function.own_markers
415420
return subclass_instance
@@ -429,7 +434,8 @@ def _can_substitute(item: Function) -> bool:
429434
return inspect.iscoroutinefunction(func)
430435

431436
def runtest(self) -> None:
432-
synchronized_obj = wrap_in_sync(self.obj)
437+
# print(self.obj.pytestmark[0].__dict__)
438+
synchronized_obj = wrap_in_sync(self.obj, self.obj.pytestmark[0].kwargs.get('loop_factory', None))
433439
with MonkeyPatch.context() as c:
434440
c.setattr(self, "obj", synchronized_obj)
435441
super().runtest()
@@ -534,27 +540,9 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
534540
node.config
535541
) == Mode.AUTO and not node.get_closest_marker("asyncio"):
536542
node.add_marker("asyncio")
537-
if asyncio_marker := node.get_closest_marker("asyncio"):
538-
if loop_factory := asyncio_marker.kwargs.get("loop_factory", None):
539-
# multiply if loop_factory is an iterable object of factories
540-
if hasattr(loop_factory, "__iter__"):
541-
updated_item = [
542-
specialized_item_class._from_function(node, lf)
543-
for lf in loop_factory
544-
]
545-
else:
546-
updated_item = specialized_item_class._from_function(
547-
node, loop_factory
548-
)
549-
else:
550-
updated_item = specialized_item_class._from_function(node)
551-
552-
# we could have multiple factroies to test if so,
553-
# multiply the number of functions for us...
554-
if isinstance(updated_item, list):
555-
updated_node_collection.extend(updated_item)
556-
else:
557-
updated_node_collection.append(updated_item)
543+
if node.get_closest_marker("asyncio"):
544+
updated_item = specialized_item_class._from_function(node)
545+
updated_node_collection.append(updated_item)
558546
hook_result.force_result(updated_node_collection)
559547

560548

@@ -654,43 +642,46 @@ def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
654642

655643
def wrap_in_sync(
656644
func: Callable[..., Awaitable[Any]],
645+
loop_factory:Callable[[], AbstractEventLoop] | None = None
657646
):
658647
"""
659648
Return a sync wrapper around an async function executing it in the
660649
current event loop.
661650
"""
662-
663651
@functools.wraps(func)
664652
def inner(*args, **kwargs):
665-
coro = func(*args, **kwargs)
666-
_loop = _get_event_loop_no_warn()
667-
task = asyncio.ensure_future(coro, loop=_loop)
653+
_last_loop = asyncio.get_event_loop()
654+
if loop_factory:
655+
_loop = loop_factory()
656+
asyncio.set_event_loop(_loop)
657+
else:
658+
_loop = asyncio.get_event_loop()
659+
task = asyncio.ensure_future(func(*args, **kwargs), loop=_loop)
668660
try:
669661
_loop.run_until_complete(task)
670662
except BaseException:
663+
671664
# run_until_complete doesn't get the result from exceptions
672665
# that are not subclasses of `Exception`. Consume all
673666
# exceptions to prevent asyncio's warning from logging.
674667
if task.done() and not task.cancelled():
675668
task.exception()
676669
raise
677670

671+
asyncio.set_event_loop(_last_loop)
678672
return inner
679673

680674

681675
def pytest_runtest_setup(item: pytest.Item) -> None:
682676
marker = item.get_closest_marker("asyncio")
683677
if marker is None:
684678
return
685-
getattr(marker, "loop_factory", None)
686679
default_loop_scope = _get_default_test_loop_scope(item.config)
687680
loop_scope = _get_marked_loop_scope(marker, default_loop_scope)
688681
runner_fixture_id = f"_{loop_scope}_scoped_runner"
689-
fixturenames: list[str] = item.fixturenames # type: ignore[attr-defined]
690-
682+
fixturenames = item.fixturenames # type: ignore[attr-defined]
691683
if runner_fixture_id not in fixturenames:
692684
fixturenames.append(runner_fixture_id)
693-
694685
obj = getattr(item, "obj", None)
695686
if not getattr(obj, "hypothesis", False) and getattr(
696687
obj, "is_hypothesis_test", False
@@ -717,17 +708,12 @@ def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None:
717708
or default_loop_scope
718709
or fixturedef.scope
719710
)
720-
# XXX: Currently Confused as to where to debug and harvest and get the runner to use the loop_factory argument.
721711
loop_factory = getattr(fixturedef.func, "loop_factory", None)
722712

723-
print(f"LOOP FACTORY: {loop_factory} {fixturedef.func}")
724-
sys.stdout.flush()
725-
726713
runner_fixture_id = f"_{loop_scope}_scoped_runner"
727-
runner: Runner = request.getfixturevalue(runner_fixture_id)
728-
729-
synchronizer = _fixture_synchronizer(fixturedef, runner, request)
730-
_make_asyncio_fixture_function(synchronizer, loop_scope)
714+
runner = request.getfixturevalue(runner_fixture_id)
715+
synchronizer = _fixture_synchronizer(fixturedef, runner, request, loop_factory)
716+
_make_asyncio_fixture_function(synchronizer, loop_scope, loop_factory)
731717
with MonkeyPatch.context() as c:
732718
c.setattr(fixturedef, "func", synchronizer)
733719
hook_result = yield
@@ -750,12 +736,9 @@ def _get_marked_loop_scope(
750736
) -> _ScopeName:
751737
assert asyncio_marker.name == "asyncio"
752738
if asyncio_marker.args or (
753-
asyncio_marker.kwargs
754-
and set(asyncio_marker.kwargs) - {"loop_scope", "scope", "loop_factory"}
739+
asyncio_marker.kwargs and set(asyncio_marker.kwargs) - {"loop_scope", "scope", "loop_factory"}
755740
):
756-
raise ValueError(
757-
"mark.asyncio accepts only keyword arguments 'loop_scope', 'loop_factory'."
758-
)
741+
raise ValueError("mark.asyncio accepts only a keyword arguments 'loop_scope' or 'loop_factory'")
759742
if "scope" in asyncio_marker.kwargs:
760743
if "loop_scope" in asyncio_marker.kwargs:
761744
raise pytest.UsageError(_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR)

0 commit comments

Comments
 (0)