@@ -353,11 +353,17 @@ async def setup():
353
353
fixturedef .func = _async_fixture_wrapper
354
354
355
355
356
+ class AsyncFunction (pytest .Function ):
357
+ """Pytest item that is a coroutine or an asynchronous generator"""
358
+
359
+
356
360
_HOLDER : Set [FixtureDef ] = set ()
357
361
358
362
359
- @pytest .hookimpl (tryfirst = True )
360
- def pytest_pycollect_makeitem (
363
+ # The function name needs to start with "pytest_"
364
+ # see https://github.com/pytest-dev/pytest/issues/11307
365
+ @pytest .hookimpl (specname = "pytest_pycollect_makeitem" , tryfirst = True )
366
+ def pytest_pycollect_makeitem_preprocess_async_fixtures (
361
367
collector : Union [pytest .Module , pytest .Class ], name : str , obj : object
362
368
) -> Union [
363
369
pytest .Item , pytest .Collector , List [Union [pytest .Item , pytest .Collector ]], None
@@ -369,6 +375,50 @@ def pytest_pycollect_makeitem(
369
375
return None
370
376
371
377
378
+ # The function name needs to start with "pytest_"
379
+ # see https://github.com/pytest-dev/pytest/issues/11307
380
+ @pytest .hookimpl (specname = "pytest_pycollect_makeitem" , hookwrapper = True )
381
+ def pytest_pycollect_makeitem_convert_async_functions_to_subclass (
382
+ collector : Union [pytest .Module , pytest .Class ], name : str , obj : object
383
+ ) -> Union [
384
+ pytest .Item , pytest .Collector , List [Union [pytest .Item , pytest .Collector ]], None
385
+ ]:
386
+ """
387
+ Converts coroutines and async generators collected as pytest.Functions
388
+ to AsyncFunction items.
389
+ """
390
+ hook_result = yield
391
+ node_or_list_of_nodes = hook_result .get_result ()
392
+ if not node_or_list_of_nodes :
393
+ return
394
+ try :
395
+ node_iterator = iter (node_or_list_of_nodes )
396
+ except TypeError :
397
+ # Treat single node as a single-element iterable
398
+ node_iterator = iter ((node_or_list_of_nodes ,))
399
+ async_functions = []
400
+ for collector_or_item in node_iterator :
401
+ if not (
402
+ isinstance (collector_or_item , pytest .Function )
403
+ and _is_coroutine_or_asyncgen (obj )
404
+ ):
405
+ collector = collector_or_item
406
+ async_functions .append (collector )
407
+ continue
408
+ item = collector_or_item
409
+ async_function = AsyncFunction .from_parent (
410
+ item .parent ,
411
+ name = item .name ,
412
+ callspec = getattr (item , "callspec" , None ),
413
+ callobj = item .obj ,
414
+ fixtureinfo = item ._fixtureinfo ,
415
+ keywords = item .keywords ,
416
+ originalname = item .originalname ,
417
+ )
418
+ async_functions .append (async_function )
419
+ hook_result .force_result (async_functions )
420
+
421
+
372
422
_event_loop_fixture_id = StashKey [str ]
373
423
374
424
0 commit comments