|
10 | 10 | from types import TracebackType
|
11 | 11 | from typing import Any
|
12 | 12 | from typing import Callable
|
| 13 | +from typing import cast |
| 14 | +from typing import Generic |
13 | 15 | from typing import Optional
|
14 | 16 | from typing import overload
|
15 | 17 | from typing import Pattern
|
16 | 18 | from typing import Tuple
|
| 19 | +from typing import TypeVar |
17 | 20 | from typing import Union
|
18 | 21 |
|
19 | 22 | from more_itertools.more import always_iterable
|
@@ -537,33 +540,35 @@ def _is_numpy_array(obj):
|
537 | 540 |
|
538 | 541 | # builtin pytest.raises helper
|
539 | 542 |
|
| 543 | +_E = TypeVar("_E", bound=BaseException) |
| 544 | + |
540 | 545 |
|
541 | 546 | @overload
|
542 | 547 | def raises(
|
543 |
| - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], |
| 548 | + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], |
544 | 549 | *,
|
545 | 550 | match: Optional[Union[str, Pattern]] = ...
|
546 |
| -) -> "RaisesContext": |
| 551 | +) -> "RaisesContext[_E]": |
547 | 552 | ... # pragma: no cover
|
548 | 553 |
|
549 | 554 |
|
550 | 555 | @overload
|
551 | 556 | def raises(
|
552 |
| - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], |
| 557 | + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], |
553 | 558 | func: Callable,
|
554 | 559 | *args: Any,
|
555 | 560 | match: Optional[str] = ...,
|
556 | 561 | **kwargs: Any
|
557 |
| -) -> Optional[_pytest._code.ExceptionInfo]: |
| 562 | +) -> Optional[_pytest._code.ExceptionInfo[_E]]: |
558 | 563 | ... # pragma: no cover
|
559 | 564 |
|
560 | 565 |
|
561 | 566 | def raises(
|
562 |
| - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], |
| 567 | + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], |
563 | 568 | *args: Any,
|
564 | 569 | match: Optional[Union[str, Pattern]] = None,
|
565 | 570 | **kwargs: Any
|
566 |
| -) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]: |
| 571 | +) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]: |
567 | 572 | r"""
|
568 | 573 | Assert that a code block/function call raises ``expected_exception``
|
569 | 574 | or raise a failure exception otherwise.
|
@@ -703,28 +708,30 @@ def raises(
|
703 | 708 | try:
|
704 | 709 | func(*args[1:], **kwargs)
|
705 | 710 | except expected_exception:
|
706 |
| - return _pytest._code.ExceptionInfo.from_current() |
| 711 | + # Cast to narrow the type to expected_exception (_E). |
| 712 | + return cast( |
| 713 | + _pytest._code.ExceptionInfo[_E], |
| 714 | + _pytest._code.ExceptionInfo.from_current(), |
| 715 | + ) |
707 | 716 | fail(message)
|
708 | 717 |
|
709 | 718 |
|
710 | 719 | raises.Exception = fail.Exception # type: ignore
|
711 | 720 |
|
712 | 721 |
|
713 |
| -class RaisesContext: |
| 722 | +class RaisesContext(Generic[_E]): |
714 | 723 | def __init__(
|
715 | 724 | self,
|
716 |
| - expected_exception: Union[ |
717 |
| - "Type[BaseException]", Tuple["Type[BaseException]", ...] |
718 |
| - ], |
| 725 | + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], |
719 | 726 | message: str,
|
720 | 727 | match_expr: Optional[Union[str, Pattern]] = None,
|
721 | 728 | ) -> None:
|
722 | 729 | self.expected_exception = expected_exception
|
723 | 730 | self.message = message
|
724 | 731 | self.match_expr = match_expr
|
725 |
| - self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo] |
| 732 | + self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]] |
726 | 733 |
|
727 |
| - def __enter__(self) -> _pytest._code.ExceptionInfo: |
| 734 | + def __enter__(self) -> _pytest._code.ExceptionInfo[_E]: |
728 | 735 | self.excinfo = _pytest._code.ExceptionInfo.for_later()
|
729 | 736 | return self.excinfo
|
730 | 737 |
|
|
0 commit comments