Skip to content

Commit 90d548a

Browse files
committed
[refactor] PytestAsyncioFunction.substitute returns the specialized subclass rather than the instance.
Previously, PytestAsyncioFunction.substitute returned the Item instance unchanged, when no substitution occured. This change allows for different code branches based on whether the substitution happened or not. Signed-off-by: Michael Seifert <[email protected]>
1 parent fac9092 commit 90d548a

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

pytest_asyncio/plugin.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Literal,
2222
Optional,
2323
Set,
24+
Type,
2425
TypeVar,
2526
Union,
2627
overload,
@@ -365,18 +366,19 @@ class PytestAsyncioFunction(Function):
365366
"""Base class for all test functions managed by pytest-asyncio."""
366367

367368
@classmethod
368-
def substitute(cls, item: Function, /) -> Function:
369+
def item_subclass_for(
370+
cls, item: Function, /
371+
) -> Union[Type["PytestAsyncioFunction"], None]:
369372
"""
370-
Returns a PytestAsyncioFunction if there is an implementation that can handle
371-
the specified function item.
373+
Returns a subclass of PytestAsyncioFunction if there is a specialized subclass
374+
for the specified function item.
372375
373-
If no implementation of PytestAsyncioFunction can handle the specified item,
374-
the item is returned unchanged.
376+
Return None if no specialized subclass exists for the specified item.
375377
"""
376378
for subclass in cls.__subclasses__():
377379
if subclass._can_substitute(item):
378-
return subclass._from_function(item)
379-
return item
380+
return subclass
381+
return None
380382

381383
@classmethod
382384
def _from_function(cls, function: Function, /) -> Function:
@@ -535,7 +537,9 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass(
535537
for node in node_iterator:
536538
updated_item = node
537539
if isinstance(node, Function):
538-
updated_item = PytestAsyncioFunction.substitute(node)
540+
specialized_item_class = PytestAsyncioFunction.item_subclass_for(node)
541+
if specialized_item_class:
542+
updated_item = specialized_item_class._from_function(node)
539543
updated_node_collection.append(updated_item)
540544

541545
hook_result.force_result(updated_node_collection)

0 commit comments

Comments
 (0)