Skip to content

Commit 14bf4cd

Browse files
committed
Make ExceptionInfo generic in the exception type
This way, in with pytest.raises(ValueError) as cm: ... cm.value is a ValueError and not a BaseException.
1 parent 56dcc9e commit 14bf4cd

File tree

2 files changed

+33
-21
lines changed

2 files changed

+33
-21
lines changed

src/_pytest/_code/code.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
from inspect import CO_VARKEYWORDS
77
from traceback import format_exception_only
88
from types import TracebackType
9+
from typing import Generic
910
from typing import Optional
1011
from typing import Pattern
1112
from typing import Tuple
13+
from typing import TypeVar
1214
from typing import Union
1315
from weakref import ref
1416

@@ -379,22 +381,25 @@ def recursionindex(self):
379381
)
380382

381383

384+
_E = TypeVar("_E", bound=BaseException)
385+
386+
382387
@attr.s(repr=False)
383-
class ExceptionInfo:
388+
class ExceptionInfo(Generic[_E]):
384389
""" wraps sys.exc_info() objects and offers
385390
help for navigating the traceback.
386391
"""
387392

388393
_assert_start_repr = "AssertionError('assert "
389394

390-
_excinfo = attr.ib(
391-
type=Optional[Tuple["Type[BaseException]", BaseException, TracebackType]]
392-
)
395+
_excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]])
393396
_striptext = attr.ib(type=str, default="")
394397
_traceback = attr.ib(type=Optional[Traceback], default=None)
395398

396399
@classmethod
397-
def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo":
400+
def from_current(
401+
cls, exprinfo: Optional[str] = None
402+
) -> "ExceptionInfo[BaseException]":
398403
"""returns an ExceptionInfo matching the current traceback
399404
400405
.. warning::
@@ -422,21 +427,21 @@ def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo":
422427
return cls(tup, _striptext)
423428

424429
@classmethod
425-
def for_later(cls) -> "ExceptionInfo":
430+
def for_later(cls) -> "ExceptionInfo[_E]":
426431
"""return an unfilled ExceptionInfo
427432
"""
428433
return cls(None)
429434

430435
@property
431-
def type(self) -> "Type[BaseException]":
436+
def type(self) -> "Type[_E]":
432437
"""the exception class"""
433438
assert (
434439
self._excinfo is not None
435440
), ".type can only be used after the context manager exits"
436441
return self._excinfo[0]
437442

438443
@property
439-
def value(self) -> BaseException:
444+
def value(self) -> _E:
440445
"""the exception value"""
441446
assert (
442447
self._excinfo is not None

src/_pytest/python_api.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
from types import TracebackType
1111
from typing import Any
1212
from typing import Callable
13+
from typing import cast
14+
from typing import Generic
1315
from typing import Optional
1416
from typing import overload
1517
from typing import Pattern
1618
from typing import Tuple
19+
from typing import TypeVar
1720
from typing import Union
1821

1922
from more_itertools.more import always_iterable
@@ -537,33 +540,35 @@ def _is_numpy_array(obj):
537540

538541
# builtin pytest.raises helper
539542

543+
_E = TypeVar("_E", bound=BaseException)
544+
540545

541546
@overload
542547
def raises(
543-
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
548+
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
544549
*,
545550
match: Optional[Union[str, Pattern]] = ...
546-
) -> "RaisesContext":
551+
) -> "RaisesContext[_E]":
547552
... # pragma: no cover
548553

549554

550555
@overload
551556
def raises(
552-
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
557+
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
553558
func: Callable,
554559
*args: Any,
555560
match: Optional[str] = ...,
556561
**kwargs: Any
557-
) -> Optional[_pytest._code.ExceptionInfo]:
562+
) -> Optional[_pytest._code.ExceptionInfo[_E]]:
558563
... # pragma: no cover
559564

560565

561566
def raises(
562-
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
567+
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
563568
*args: Any,
564569
match: Optional[Union[str, Pattern]] = None,
565570
**kwargs: Any
566-
) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]:
571+
) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]:
567572
r"""
568573
Assert that a code block/function call raises ``expected_exception``
569574
or raise a failure exception otherwise.
@@ -703,28 +708,30 @@ def raises(
703708
try:
704709
func(*args[1:], **kwargs)
705710
except expected_exception:
706-
return _pytest._code.ExceptionInfo.from_current()
711+
# Cast to narrow the type to expected_exception (_E).
712+
return cast(
713+
_pytest._code.ExceptionInfo[_E],
714+
_pytest._code.ExceptionInfo.from_current(),
715+
)
707716
fail(message)
708717

709718

710719
raises.Exception = fail.Exception # type: ignore
711720

712721

713-
class RaisesContext:
722+
class RaisesContext(Generic[_E]):
714723
def __init__(
715724
self,
716-
expected_exception: Union[
717-
"Type[BaseException]", Tuple["Type[BaseException]", ...]
718-
],
725+
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
719726
message: str,
720727
match_expr: Optional[Union[str, Pattern]] = None,
721728
) -> None:
722729
self.expected_exception = expected_exception
723730
self.message = message
724731
self.match_expr = match_expr
725-
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo]
732+
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]
726733

727-
def __enter__(self) -> _pytest._code.ExceptionInfo:
734+
def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
728735
self.excinfo = _pytest._code.ExceptionInfo.for_later()
729736
return self.excinfo
730737

0 commit comments

Comments
 (0)