@@ -372,6 +372,8 @@ def restore_contextvars():
372
372
class PytestAsyncioFunction (Function ):
373
373
"""Base class for all test functions managed by pytest-asyncio."""
374
374
375
+ loop_factory : Callable [[], AbstractEventLoop ] | None
376
+
375
377
@classmethod
376
378
def item_subclass_for (cls , item : Function , / ) -> type [PytestAsyncioFunction ] | None :
377
379
"""
@@ -386,12 +388,18 @@ def item_subclass_for(cls, item: Function, /) -> type[PytestAsyncioFunction] | N
386
388
return None
387
389
388
390
@classmethod
389
- def _from_function (cls , function : Function , / ) -> Function :
391
+ def _from_function (
392
+ cls ,
393
+ function : Function ,
394
+ loop_factory : Callable [[], AbstractEventLoop ] | None = None ,
395
+ / ,
396
+ ) -> Function :
390
397
"""
391
398
Instantiates this specific PytestAsyncioFunction type from the specified
392
399
Function item.
393
400
"""
394
401
assert function .get_closest_marker ("asyncio" )
402
+
395
403
subclass_instance = cls .from_parent (
396
404
function .parent ,
397
405
name = function .name ,
@@ -401,6 +409,7 @@ def _from_function(cls, function: Function, /) -> Function:
401
409
keywords = function .keywords ,
402
410
originalname = function .originalname ,
403
411
)
412
+ subclass_instance .loop_factory = loop_factory
404
413
subclass_instance .own_markers = function .own_markers
405
414
assert subclass_instance .own_markers == function .own_markers
406
415
return subclass_instance
@@ -525,9 +534,27 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
525
534
node .config
526
535
) == Mode .AUTO and not node .get_closest_marker ("asyncio" ):
527
536
node .add_marker ("asyncio" )
528
- if node .get_closest_marker ("asyncio" ):
529
- updated_item = specialized_item_class ._from_function (node )
530
- updated_node_collection .append (updated_item )
537
+ if asyncio_marker := node .get_closest_marker ("asyncio" ):
538
+ if loop_factory := asyncio_marker .kwargs .get ("loop_factory" , None ):
539
+ # multiply if loop_factory is an iterable object of factories
540
+ if hasattr (loop_factory , "__iter__" ):
541
+ updated_item = [
542
+ specialized_item_class ._from_function (node , lf )
543
+ for lf in loop_factory
544
+ ]
545
+ else :
546
+ updated_item = specialized_item_class ._from_function (
547
+ node , loop_factory
548
+ )
549
+ else :
550
+ updated_item = specialized_item_class ._from_function (node )
551
+
552
+ # we could have multiple factroies to test if so,
553
+ # multiply the number of functions for us...
554
+ if isinstance (updated_item , list ):
555
+ updated_node_collection .extend (updated_item )
556
+ else :
557
+ updated_node_collection .append (updated_item )
531
558
hook_result .force_result (updated_node_collection )
532
559
533
560
@@ -546,20 +573,6 @@ def _temporary_event_loop_policy(policy: AbstractEventLoopPolicy) -> Iterator[No
546
573
_set_event_loop (old_loop )
547
574
548
575
549
- @contextlib .contextmanager
550
- def _temporary_event_loop (loop : AbstractEventLoop ):
551
- try :
552
- old_event_loop = asyncio .get_event_loop ()
553
- except RuntimeError :
554
- old_event_loop = None
555
-
556
- asyncio .set_event_loop (old_event_loop )
557
- try :
558
- yield
559
- finally :
560
- asyncio .set_event_loop (old_event_loop )
561
-
562
-
563
576
def _get_event_loop_policy () -> AbstractEventLoopPolicy :
564
577
with warnings .catch_warnings ():
565
578
warnings .simplefilter ("ignore" , DeprecationWarning )
@@ -669,12 +682,15 @@ def pytest_runtest_setup(item: pytest.Item) -> None:
669
682
marker = item .get_closest_marker ("asyncio" )
670
683
if marker is None :
671
684
return
685
+ getattr (marker , "loop_factory" , None )
672
686
default_loop_scope = _get_default_test_loop_scope (item .config )
673
687
loop_scope = _get_marked_loop_scope (marker , default_loop_scope )
674
688
runner_fixture_id = f"_{ loop_scope } _scoped_runner"
675
- fixturenames = item .fixturenames # type: ignore[attr-defined]
689
+ fixturenames : list [str ] = item .fixturenames # type: ignore[attr-defined]
690
+
676
691
if runner_fixture_id not in fixturenames :
677
692
fixturenames .append (runner_fixture_id )
693
+
678
694
obj = getattr (item , "obj" , None )
679
695
if not getattr (obj , "hypothesis" , False ) and getattr (
680
696
obj , "is_hypothesis_test" , False
@@ -701,8 +717,15 @@ def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None:
701
717
or default_loop_scope
702
718
or fixturedef .scope
703
719
)
720
+ # XXX: Currently Confused as to where to debug and harvest and get the runner to use the loop_factory argument.
721
+ loop_factory = getattr (fixturedef .func , "loop_factory" , None )
722
+
723
+ print (f"LOOP FACTORY: { loop_factory } { fixturedef .func } " )
724
+ sys .stdout .flush ()
725
+
704
726
runner_fixture_id = f"_{ loop_scope } _scoped_runner"
705
- runner = request .getfixturevalue (runner_fixture_id )
727
+ runner : Runner = request .getfixturevalue (runner_fixture_id )
728
+
706
729
synchronizer = _fixture_synchronizer (fixturedef , runner , request )
707
730
_make_asyncio_fixture_function (synchronizer , loop_scope )
708
731
with MonkeyPatch .context () as c :
@@ -727,9 +750,12 @@ def _get_marked_loop_scope(
727
750
) -> _ScopeName :
728
751
assert asyncio_marker .name == "asyncio"
729
752
if asyncio_marker .args or (
730
- asyncio_marker .kwargs and set (asyncio_marker .kwargs ) - {"loop_scope" , "scope" }
753
+ asyncio_marker .kwargs
754
+ and set (asyncio_marker .kwargs ) - {"loop_scope" , "scope" , "loop_factory" }
731
755
):
732
- raise ValueError ("mark.asyncio accepts only a keyword argument 'loop_scope'." )
756
+ raise ValueError (
757
+ "mark.asyncio accepts only keyword arguments 'loop_scope', 'loop_factory'."
758
+ )
733
759
if "scope" in asyncio_marker .kwargs :
734
760
if "loop_scope" in asyncio_marker .kwargs :
735
761
raise pytest .UsageError (_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR )
@@ -795,12 +821,6 @@ def _scoped_runner(
795
821
)
796
822
797
823
798
- @pytest .fixture (scope = "session" , autouse = True )
799
- def new_event_loop () -> AbstractEventLoop :
800
- """Creates a new eventloop for different tests being ran"""
801
- return asyncio .new_event_loop ()
802
-
803
-
804
824
@pytest .fixture (scope = "session" , autouse = True )
805
825
def event_loop_policy () -> AbstractEventLoopPolicy :
806
826
"""Return an instance of the policy used to create asyncio event loops."""
0 commit comments