Skip to content

Commit a47fcb4

Browse files
Mihail Milushevlanzz
authored andcommitted
code review: kwarg-only match, replace recursive with depth
1 parent ab8f5ce commit a47fcb4

File tree

2 files changed

+60
-17
lines changed

2 files changed

+60
-17
lines changed

src/_pytest/_code/code.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -725,13 +725,22 @@ def _group_contains(
725725
exc_group: BaseExceptionGroup[BaseException],
726726
expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
727727
match: Union[str, Pattern[str], None],
728-
recursive: bool = False,
728+
target_depth: Optional[int] = None,
729+
current_depth: int = 1,
729730
) -> bool:
730731
"""Return `True` if a `BaseExceptionGroup` contains a matching exception."""
732+
if (target_depth is not None) and (current_depth > target_depth):
733+
# already descended past the target depth
734+
return False
731735
for exc in exc_group.exceptions:
732-
if recursive and isinstance(exc, BaseExceptionGroup):
733-
if self._group_contains(exc, expected_exception, match, recursive):
736+
if isinstance(exc, BaseExceptionGroup):
737+
if self._group_contains(
738+
exc, expected_exception, match, target_depth, current_depth + 1
739+
):
734740
return True
741+
if (target_depth is not None) and (current_depth != target_depth):
742+
# not at the target depth, no match
743+
continue
735744
if not isinstance(exc, expected_exception):
736745
continue
737746
if match is not None:
@@ -744,8 +753,9 @@ def _group_contains(
744753
def group_contains(
745754
self,
746755
expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
756+
*,
747757
match: Union[str, Pattern[str], None] = None,
748-
recursive: bool = False,
758+
depth: Optional[int] = None,
749759
) -> bool:
750760
"""Check whether a captured exception group contains a matching exception.
751761
@@ -762,13 +772,16 @@ def group_contains(
762772
To match a literal string that may contain :ref:`special characters
763773
<re-syntax>`, the pattern can first be escaped with :func:`re.escape`.
764774
765-
:param bool recursive:
766-
If `True`, search will descend recursively into any nested exception groups.
767-
If `False`, only the top exception group will be searched.
775+
:param Optional[int] depth:
776+
If `None`, will search for a matching exception at any nesting depth.
777+
If >= 1, will only match an exception if it's at the specified depth (depth = 1 being
778+
the exceptions contained within the topmost exception group).
768779
"""
769780
msg = "Captured exception is not an instance of `BaseExceptionGroup`"
770781
assert isinstance(self.value, BaseExceptionGroup), msg
771-
return self._group_contains(self.value, expected_exception, match, recursive)
782+
msg = "`depth` must be >= 1 if specified"
783+
assert (depth is None) or (depth >= 1), msg
784+
return self._group_contains(self.value, expected_exception, match, depth)
772785

773786

774787
@dataclasses.dataclass

testing/code/test_excinfo.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -472,35 +472,65 @@ def test_doesnt_contain_exception_match(self) -> None:
472472
raise exc_group
473473
assert not exc_info.group_contains(RuntimeError, match=r"^exception message$")
474474

475-
def test_contains_exception_type_recursive(self) -> None:
475+
def test_contains_exception_type_unlimited_depth(self) -> None:
476476
exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])])
477477
with pytest.raises(ExceptionGroup) as exc_info:
478478
raise exc_group
479-
assert exc_info.group_contains(RuntimeError, recursive=True)
479+
assert exc_info.group_contains(RuntimeError)
480480

481-
def test_doesnt_contain_exception_type_nonrecursive(self) -> None:
481+
def test_contains_exception_type_at_depth_1(self) -> None:
482+
exc_group = ExceptionGroup("", [RuntimeError()])
483+
with pytest.raises(ExceptionGroup) as exc_info:
484+
raise exc_group
485+
assert exc_info.group_contains(RuntimeError, depth=1)
486+
487+
def test_doesnt_contain_exception_type_past_depth(self) -> None:
482488
exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])])
483489
with pytest.raises(ExceptionGroup) as exc_info:
484490
raise exc_group
485-
assert not exc_info.group_contains(RuntimeError)
491+
assert not exc_info.group_contains(RuntimeError, depth=1)
486492

487-
def test_contains_exception_match_recursive(self) -> None:
493+
def test_contains_exception_type_specific_depth(self) -> None:
494+
exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])])
495+
with pytest.raises(ExceptionGroup) as exc_info:
496+
raise exc_group
497+
assert exc_info.group_contains(RuntimeError, depth=2)
498+
499+
def test_contains_exception_match_unlimited_depth(self) -> None:
488500
exc_group = ExceptionGroup(
489501
"", [ExceptionGroup("", [RuntimeError("exception message")])]
490502
)
503+
with pytest.raises(ExceptionGroup) as exc_info:
504+
raise exc_group
505+
assert exc_info.group_contains(RuntimeError, match=r"^exception message$")
506+
507+
def test_contains_exception_match_at_depth_1(self) -> None:
508+
exc_group = ExceptionGroup("", [RuntimeError("exception message")])
491509
with pytest.raises(ExceptionGroup) as exc_info:
492510
raise exc_group
493511
assert exc_info.group_contains(
494-
RuntimeError, match=r"^exception message$", recursive=True
512+
RuntimeError, match=r"^exception message$", depth=1
495513
)
496514

497-
def test_doesnt_contain_exception_match_nonrecursive(self) -> None:
515+
def test_doesnt_contain_exception_match_past_depth(self) -> None:
498516
exc_group = ExceptionGroup(
499-
"", [ExceptionGroup("", [RuntimeError("message that will not match")])]
517+
"", [ExceptionGroup("", [RuntimeError("exception message")])]
500518
)
501519
with pytest.raises(ExceptionGroup) as exc_info:
502520
raise exc_group
503-
assert not exc_info.group_contains(RuntimeError, match=r"^exception message$")
521+
assert not exc_info.group_contains(
522+
RuntimeError, match=r"^exception message$", depth=1
523+
)
524+
525+
def test_contains_exception_match_specific_depth(self) -> None:
526+
exc_group = ExceptionGroup(
527+
"", [ExceptionGroup("", [RuntimeError("exception message")])]
528+
)
529+
with pytest.raises(ExceptionGroup) as exc_info:
530+
raise exc_group
531+
assert exc_info.group_contains(
532+
RuntimeError, match=r"^exception message$", depth=2
533+
)
504534

505535

506536
class TestFormattedExcinfo:

0 commit comments

Comments
 (0)