49
49
PytestPluginManager ,
50
50
)
51
51
52
+ from typing import Callable
52
53
if sys .version_info >= (3 , 10 ):
53
54
from typing import ParamSpec
54
55
else :
@@ -116,6 +117,7 @@ def fixture(
116
117
* ,
117
118
scope : _ScopeName | Callable [[str , Config ], _ScopeName ] = ...,
118
119
loop_scope : _ScopeName | None = ...,
120
+ loop_factory : _ScopeName | Callable [[], AbstractEventLoop ] = ...,
119
121
params : Iterable [object ] | None = ...,
120
122
autouse : bool = ...,
121
123
ids : (
@@ -133,6 +135,7 @@ def fixture(
133
135
* ,
134
136
scope : _ScopeName | Callable [[str , Config ], _ScopeName ] = ...,
135
137
loop_scope : _ScopeName | None = ...,
138
+ loop_factory : _ScopeName | Callable [[], AbstractEventLoop ] = ...,
136
139
params : Iterable [object ] | None = ...,
137
140
autouse : bool = ...,
138
141
ids : (
@@ -147,20 +150,21 @@ def fixture(
147
150
def fixture (
148
151
fixture_function : FixtureFunction [_P , _R ] | None = None ,
149
152
loop_scope : _ScopeName | None = None ,
153
+ loop_factory : _ScopeName | Callable [[], AbstractEventLoop ] = ...,
150
154
** kwargs : Any ,
151
155
) -> (
152
156
FixtureFunction [_P , _R ]
153
157
| Callable [[FixtureFunction [_P , _R ]], FixtureFunction [_P , _R ]]
154
158
):
155
159
if fixture_function is not None :
156
- _make_asyncio_fixture_function (fixture_function , loop_scope )
160
+ _make_asyncio_fixture_function (fixture_function , loop_scope , loop_factory )
157
161
return pytest .fixture (fixture_function , ** kwargs )
158
162
159
163
else :
160
164
161
165
@functools .wraps (fixture )
162
166
def inner (fixture_function : FixtureFunction [_P , _R ]) -> FixtureFunction [_P , _R ]:
163
- return fixture (fixture_function , loop_scope = loop_scope , ** kwargs )
167
+ return fixture (fixture_function , loop_factory = loop_factory , loop_scope = loop_scope , ** kwargs )
164
168
165
169
return inner
166
170
@@ -170,12 +174,13 @@ def _is_asyncio_fixture_function(obj: Any) -> bool:
170
174
return getattr (obj , "_force_asyncio_fixture" , False )
171
175
172
176
173
- def _make_asyncio_fixture_function (obj : Any , loop_scope : _ScopeName | None ) -> None :
177
+ def _make_asyncio_fixture_function (obj : Any , loop_scope : _ScopeName | None , loop_factory : _ScopeName | None ) -> None :
174
178
if hasattr (obj , "__func__" ):
175
179
# instance method, check the function object
176
180
obj = obj .__func__
177
181
obj ._force_asyncio_fixture = True
178
182
obj ._loop_scope = loop_scope
183
+ obj ._loop_factory = loop_factory
179
184
180
185
181
186
def _is_coroutine_or_asyncgen (obj : Any ) -> bool :
@@ -234,14 +239,14 @@ def pytest_report_header(config: Config) -> list[str]:
234
239
235
240
236
241
def _fixture_synchronizer (
237
- fixturedef : FixtureDef , runner : Runner , request : FixtureRequest
242
+ fixturedef : FixtureDef , runner : Runner , request : FixtureRequest , loop_factory : Callable [[], AbstractEventLoop ]
238
243
) -> Callable :
239
244
"""Returns a synchronous function evaluating the specified fixture."""
240
245
fixture_function = resolve_fixture_function (fixturedef , request )
241
246
if inspect .isasyncgenfunction (fixturedef .func ):
242
- return _wrap_asyncgen_fixture (fixture_function , runner , request ) # type: ignore[arg-type]
247
+ return _wrap_asyncgen_fixture (fixture_function , runner , request , loop_factory ) # type: ignore[arg-type]
243
248
elif inspect .iscoroutinefunction (fixturedef .func ):
244
- return _wrap_async_fixture (fixture_function , runner , request ) # type: ignore[arg-type]
249
+ return _wrap_async_fixture (fixture_function , runner , request , loop_factory ) # type: ignore[arg-type]
245
250
else :
246
251
return fixturedef .func
247
252
@@ -256,6 +261,7 @@ def _wrap_asyncgen_fixture(
256
261
],
257
262
runner : Runner ,
258
263
request : FixtureRequest ,
264
+ loop_factory :Callable [[], AbstractEventLoop ]
259
265
) -> Callable [AsyncGenFixtureParams , AsyncGenFixtureYieldType ]:
260
266
@functools .wraps (fixture_function )
261
267
def _asyncgen_fixture_wrapper (
@@ -285,6 +291,9 @@ async def async_finalizer() -> None:
285
291
msg = "Async generator fixture didn't stop."
286
292
msg += "Yield only once."
287
293
raise ValueError (msg )
294
+ if loop_factory :
295
+ _loop = loop_factory ()
296
+ asyncio .set_event_loop (_loop )
288
297
289
298
runner .run (async_finalizer (), context = context )
290
299
if reset_contextvars is not None :
@@ -306,6 +315,7 @@ def _wrap_async_fixture(
306
315
],
307
316
runner : Runner ,
308
317
request : FixtureRequest ,
318
+ loop_factory : Callable [[], AbstractEventLoop ] | None = None
309
319
) -> Callable [AsyncFixtureParams , AsyncFixtureReturnType ]:
310
320
311
321
@functools .wraps (fixture_function ) # type: ignore[arg-type]
@@ -318,8 +328,12 @@ async def setup():
318
328
return res
319
329
320
330
context = contextvars .copy_context ()
321
- result = runner .run (setup (), context = context )
322
331
332
+ # ensure loop_factory gets ran before we start running...
333
+ if loop_factory :
334
+ asyncio .set_event_loop (loop_factory ())
335
+
336
+ result = runner .run (setup (), context = context )
323
337
# Copy the context vars modified by the setup task into the current
324
338
# context, and (if needed) add a finalizer to reset them.
325
339
#
@@ -372,8 +386,6 @@ def restore_contextvars():
372
386
class PytestAsyncioFunction (Function ):
373
387
"""Base class for all test functions managed by pytest-asyncio."""
374
388
375
- loop_factory : Callable [[], AbstractEventLoop ] | None
376
-
377
389
@classmethod
378
390
def item_subclass_for (cls , item : Function , / ) -> type [PytestAsyncioFunction ] | None :
379
391
"""
@@ -388,18 +400,12 @@ def item_subclass_for(cls, item: Function, /) -> type[PytestAsyncioFunction] | N
388
400
return None
389
401
390
402
@classmethod
391
- def _from_function (
392
- cls ,
393
- function : Function ,
394
- loop_factory : Callable [[], AbstractEventLoop ] | None = None ,
395
- / ,
396
- ) -> Function :
403
+ def _from_function (cls , function : Function , / ) -> Function :
397
404
"""
398
405
Instantiates this specific PytestAsyncioFunction type from the specified
399
406
Function item.
400
407
"""
401
408
assert function .get_closest_marker ("asyncio" )
402
-
403
409
subclass_instance = cls .from_parent (
404
410
function .parent ,
405
411
name = function .name ,
@@ -409,7 +415,6 @@ def _from_function(
409
415
keywords = function .keywords ,
410
416
originalname = function .originalname ,
411
417
)
412
- subclass_instance .loop_factory = loop_factory
413
418
subclass_instance .own_markers = function .own_markers
414
419
assert subclass_instance .own_markers == function .own_markers
415
420
return subclass_instance
@@ -429,7 +434,8 @@ def _can_substitute(item: Function) -> bool:
429
434
return inspect .iscoroutinefunction (func )
430
435
431
436
def runtest (self ) -> None :
432
- synchronized_obj = wrap_in_sync (self .obj )
437
+ # print(self.obj.pytestmark[0].__dict__)
438
+ synchronized_obj = wrap_in_sync (self .obj , self .obj .pytestmark [0 ].kwargs .get ('loop_factory' , None ))
433
439
with MonkeyPatch .context () as c :
434
440
c .setattr (self , "obj" , synchronized_obj )
435
441
super ().runtest ()
@@ -534,27 +540,9 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
534
540
node .config
535
541
) == Mode .AUTO and not node .get_closest_marker ("asyncio" ):
536
542
node .add_marker ("asyncio" )
537
- if asyncio_marker := node .get_closest_marker ("asyncio" ):
538
- if loop_factory := asyncio_marker .kwargs .get ("loop_factory" , None ):
539
- # multiply if loop_factory is an iterable object of factories
540
- if hasattr (loop_factory , "__iter__" ):
541
- updated_item = [
542
- specialized_item_class ._from_function (node , lf )
543
- for lf in loop_factory
544
- ]
545
- else :
546
- updated_item = specialized_item_class ._from_function (
547
- node , loop_factory
548
- )
549
- else :
550
- updated_item = specialized_item_class ._from_function (node )
551
-
552
- # we could have multiple factroies to test if so,
553
- # multiply the number of functions for us...
554
- if isinstance (updated_item , list ):
555
- updated_node_collection .extend (updated_item )
556
- else :
557
- updated_node_collection .append (updated_item )
543
+ if node .get_closest_marker ("asyncio" ):
544
+ updated_item = specialized_item_class ._from_function (node )
545
+ updated_node_collection .append (updated_item )
558
546
hook_result .force_result (updated_node_collection )
559
547
560
548
@@ -654,43 +642,46 @@ def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
654
642
655
643
def wrap_in_sync (
656
644
func : Callable [..., Awaitable [Any ]],
645
+ loop_factory :Callable [[], AbstractEventLoop ] | None = None
657
646
):
658
647
"""
659
648
Return a sync wrapper around an async function executing it in the
660
649
current event loop.
661
650
"""
662
-
663
651
@functools .wraps (func )
664
652
def inner (* args , ** kwargs ):
665
- coro = func (* args , ** kwargs )
666
- _loop = _get_event_loop_no_warn ()
667
- task = asyncio .ensure_future (coro , loop = _loop )
653
+ _last_loop = asyncio .get_event_loop ()
654
+ if loop_factory :
655
+ _loop = loop_factory ()
656
+ asyncio .set_event_loop (_loop )
657
+ else :
658
+ _loop = asyncio .get_event_loop ()
659
+ task = asyncio .ensure_future (func (* args , ** kwargs ), loop = _loop )
668
660
try :
669
661
_loop .run_until_complete (task )
670
662
except BaseException :
663
+
671
664
# run_until_complete doesn't get the result from exceptions
672
665
# that are not subclasses of `Exception`. Consume all
673
666
# exceptions to prevent asyncio's warning from logging.
674
667
if task .done () and not task .cancelled ():
675
668
task .exception ()
676
669
raise
677
670
671
+ asyncio .set_event_loop (_last_loop )
678
672
return inner
679
673
680
674
681
675
def pytest_runtest_setup (item : pytest .Item ) -> None :
682
676
marker = item .get_closest_marker ("asyncio" )
683
677
if marker is None :
684
678
return
685
- getattr (marker , "loop_factory" , None )
686
679
default_loop_scope = _get_default_test_loop_scope (item .config )
687
680
loop_scope = _get_marked_loop_scope (marker , default_loop_scope )
688
681
runner_fixture_id = f"_{ loop_scope } _scoped_runner"
689
- fixturenames : list [str ] = item .fixturenames # type: ignore[attr-defined]
690
-
682
+ fixturenames = item .fixturenames # type: ignore[attr-defined]
691
683
if runner_fixture_id not in fixturenames :
692
684
fixturenames .append (runner_fixture_id )
693
-
694
685
obj = getattr (item , "obj" , None )
695
686
if not getattr (obj , "hypothesis" , False ) and getattr (
696
687
obj , "is_hypothesis_test" , False
@@ -717,17 +708,12 @@ def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None:
717
708
or default_loop_scope
718
709
or fixturedef .scope
719
710
)
720
- # XXX: Currently Confused as to where to debug and harvest and get the runner to use the loop_factory argument.
721
711
loop_factory = getattr (fixturedef .func , "loop_factory" , None )
722
712
723
- print (f"LOOP FACTORY: { loop_factory } { fixturedef .func } " )
724
- sys .stdout .flush ()
725
-
726
713
runner_fixture_id = f"_{ loop_scope } _scoped_runner"
727
- runner : Runner = request .getfixturevalue (runner_fixture_id )
728
-
729
- synchronizer = _fixture_synchronizer (fixturedef , runner , request )
730
- _make_asyncio_fixture_function (synchronizer , loop_scope )
714
+ runner = request .getfixturevalue (runner_fixture_id )
715
+ synchronizer = _fixture_synchronizer (fixturedef , runner , request , loop_factory )
716
+ _make_asyncio_fixture_function (synchronizer , loop_scope , loop_factory )
731
717
with MonkeyPatch .context () as c :
732
718
c .setattr (fixturedef , "func" , synchronizer )
733
719
hook_result = yield
@@ -750,12 +736,9 @@ def _get_marked_loop_scope(
750
736
) -> _ScopeName :
751
737
assert asyncio_marker .name == "asyncio"
752
738
if asyncio_marker .args or (
753
- asyncio_marker .kwargs
754
- and set (asyncio_marker .kwargs ) - {"loop_scope" , "scope" , "loop_factory" }
739
+ asyncio_marker .kwargs and set (asyncio_marker .kwargs ) - {"loop_scope" , "scope" , "loop_factory" }
755
740
):
756
- raise ValueError (
757
- "mark.asyncio accepts only keyword arguments 'loop_scope', 'loop_factory'."
758
- )
741
+ raise ValueError ("mark.asyncio accepts only a keyword arguments 'loop_scope' or 'loop_factory'" )
759
742
if "scope" in asyncio_marker .kwargs :
760
743
if "loop_scope" in asyncio_marker .kwargs :
761
744
raise pytest .UsageError (_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR )
0 commit comments