Skip to content

Commit ab8f5ce

Browse files
committed
Add new ExceptionInfo.group_contains assertion helper method
Tests if a captured exception group contains an expected exception. Will raise `AssertionError` if the wrapped exception is not an exception group. Supports recursive search into nested exception groups.
1 parent 6c2feb7 commit ab8f5ce

File tree

4 files changed

+121
-6
lines changed

4 files changed

+121
-6
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ Michal Wajszczuk
266266
Michał Zięba
267267
Mickey Pashov
268268
Mihai Capotă
269+
Mihail Milushev
269270
Mike Hoyle (hoylemd)
270271
Mike Lundy
271272
Milan Lesnek

changelog/10441.feature.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Added :func:`ExceptionInfo.group_contains() <pytest.ExceptionInfo.group_contains>`, an assertion
2+
helper that tests if an `ExceptionGroup` contains a matching exception.

src/_pytest/_code/code.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -697,26 +697,79 @@ def getrepr(
697697
)
698698
return fmt.repr_excinfo(self)
699699

700+
def _stringify_exception(self, exc: BaseException) -> str:
701+
return "\n".join(
702+
[
703+
str(exc),
704+
*getattr(exc, "__notes__", []),
705+
]
706+
)
707+
700708
def match(self, regexp: Union[str, Pattern[str]]) -> "Literal[True]":
701709
"""Check whether the regular expression `regexp` matches the string
702710
representation of the exception using :func:`python:re.search`.
703711
704712
If it matches `True` is returned, otherwise an `AssertionError` is raised.
705713
"""
706714
__tracebackhide__ = True
707-
value = "\n".join(
708-
[
709-
str(self.value),
710-
*getattr(self.value, "__notes__", []),
711-
]
712-
)
715+
value = self._stringify_exception(self.value)
713716
msg = f"Regex pattern did not match.\n Regex: {regexp!r}\n Input: {value!r}"
714717
if regexp == value:
715718
msg += "\n Did you mean to `re.escape()` the regex?"
716719
assert re.search(regexp, value), msg
717720
# Return True to allow for "assert excinfo.match()".
718721
return True
719722

723+
def _group_contains(
724+
self,
725+
exc_group: BaseExceptionGroup[BaseException],
726+
expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
727+
match: Union[str, Pattern[str], None],
728+
recursive: bool = False,
729+
) -> bool:
730+
"""Return `True` if a `BaseExceptionGroup` contains a matching exception."""
731+
for exc in exc_group.exceptions:
732+
if recursive and isinstance(exc, BaseExceptionGroup):
733+
if self._group_contains(exc, expected_exception, match, recursive):
734+
return True
735+
if not isinstance(exc, expected_exception):
736+
continue
737+
if match is not None:
738+
value = self._stringify_exception(exc)
739+
if not re.search(match, value):
740+
continue
741+
return True
742+
return False
743+
744+
def group_contains(
745+
self,
746+
expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
747+
match: Union[str, Pattern[str], None] = None,
748+
recursive: bool = False,
749+
) -> bool:
750+
"""Check whether a captured exception group contains a matching exception.
751+
752+
:param Type[BaseException] | Tuple[Type[BaseException]] expected_exception:
753+
The expected exception type, or a tuple if one of multiple possible
754+
exception types are expected.
755+
756+
:param str | Pattern[str] | None match:
757+
If specified, a string containing a regular expression,
758+
or a regular expression object, that is tested against the string
759+
representation of the exception and its `PEP-678 <https://peps.python.org/pep-0678/>` `__notes__`
760+
using :func:`re.search`.
761+
762+
To match a literal string that may contain :ref:`special characters
763+
<re-syntax>`, the pattern can first be escaped with :func:`re.escape`.
764+
765+
:param bool recursive:
766+
If `True`, search will descend recursively into any nested exception groups.
767+
If `False`, only the top exception group will be searched.
768+
"""
769+
msg = "Captured exception is not an instance of `BaseExceptionGroup`"
770+
assert isinstance(self.value, BaseExceptionGroup), msg
771+
return self._group_contains(self.value, expected_exception, match, recursive)
772+
720773

721774
@dataclasses.dataclass
722775
class FormattedExcinfo:

testing/code/test_excinfo.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
if TYPE_CHECKING:
2828
from _pytest._code.code import _TracebackStyle
2929

30+
if sys.version_info[:2] < (3, 11):
31+
from exceptiongroup import ExceptionGroup
32+
3033

3134
@pytest.fixture
3235
def limited_recursion_depth():
@@ -444,6 +447,62 @@ def test_division_zero():
444447
result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match])
445448

446449

450+
class TestGroupContains:
451+
def test_contains_exception_type(self) -> None:
452+
exc_group = ExceptionGroup("", [RuntimeError()])
453+
with pytest.raises(ExceptionGroup) as exc_info:
454+
raise exc_group
455+
assert exc_info.group_contains(RuntimeError)
456+
457+
def test_doesnt_contain_exception_type(self) -> None:
458+
exc_group = ExceptionGroup("", [ValueError()])
459+
with pytest.raises(ExceptionGroup) as exc_info:
460+
raise exc_group
461+
assert not exc_info.group_contains(RuntimeError)
462+
463+
def test_contains_exception_match(self) -> None:
464+
exc_group = ExceptionGroup("", [RuntimeError("exception message")])
465+
with pytest.raises(ExceptionGroup) as exc_info:
466+
raise exc_group
467+
assert exc_info.group_contains(RuntimeError, match=r"^exception message$")
468+
469+
def test_doesnt_contain_exception_match(self) -> None:
470+
exc_group = ExceptionGroup("", [RuntimeError("message that will not match")])
471+
with pytest.raises(ExceptionGroup) as exc_info:
472+
raise exc_group
473+
assert not exc_info.group_contains(RuntimeError, match=r"^exception message$")
474+
475+
def test_contains_exception_type_recursive(self) -> None:
476+
exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])])
477+
with pytest.raises(ExceptionGroup) as exc_info:
478+
raise exc_group
479+
assert exc_info.group_contains(RuntimeError, recursive=True)
480+
481+
def test_doesnt_contain_exception_type_nonrecursive(self) -> None:
482+
exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])])
483+
with pytest.raises(ExceptionGroup) as exc_info:
484+
raise exc_group
485+
assert not exc_info.group_contains(RuntimeError)
486+
487+
def test_contains_exception_match_recursive(self) -> None:
488+
exc_group = ExceptionGroup(
489+
"", [ExceptionGroup("", [RuntimeError("exception message")])]
490+
)
491+
with pytest.raises(ExceptionGroup) as exc_info:
492+
raise exc_group
493+
assert exc_info.group_contains(
494+
RuntimeError, match=r"^exception message$", recursive=True
495+
)
496+
497+
def test_doesnt_contain_exception_match_nonrecursive(self) -> None:
498+
exc_group = ExceptionGroup(
499+
"", [ExceptionGroup("", [RuntimeError("message that will not match")])]
500+
)
501+
with pytest.raises(ExceptionGroup) as exc_info:
502+
raise exc_group
503+
assert not exc_info.group_contains(RuntimeError, match=r"^exception message$")
504+
505+
447506
class TestFormattedExcinfo:
448507
@pytest.fixture
449508
def importasmod(self, tmp_path: Path, _sys_snapshot):

0 commit comments

Comments
 (0)