|
7 | 7 | from decimal import Decimal
|
8 | 8 | from itertools import filterfalse
|
9 | 9 | from numbers import Number
|
| 10 | +from types import TracebackType |
| 11 | +from typing import Any |
| 12 | +from typing import Callable |
| 13 | +from typing import Optional |
| 14 | +from typing import overload |
| 15 | +from typing import Pattern |
| 16 | +from typing import Tuple |
10 | 17 | from typing import Union
|
11 | 18 |
|
12 | 19 | from more_itertools.more import always_iterable
|
|
15 | 22 | from _pytest.compat import STRING_TYPES
|
16 | 23 | from _pytest.outcomes import fail
|
17 | 24 |
|
| 25 | +if False: # TYPE_CHECKING |
| 26 | + from typing import Type # noqa: F401 (used in type string) |
| 27 | + |
18 | 28 | BASE_TYPE = (type, STRING_TYPES)
|
19 | 29 |
|
20 | 30 |
|
@@ -528,7 +538,32 @@ def _is_numpy_array(obj):
|
528 | 538 | # builtin pytest.raises helper
|
529 | 539 |
|
530 | 540 |
|
531 |
| -def raises(expected_exception, *args, match=None, **kwargs): |
| 541 | +@overload |
| 542 | +def raises( |
| 543 | + expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], |
| 544 | + *, |
| 545 | + match: Optional[Union[str, Pattern]] = ... |
| 546 | +) -> "RaisesContext": |
| 547 | + ... # pragma: no cover |
| 548 | + |
| 549 | + |
| 550 | +@overload |
| 551 | +def raises( |
| 552 | + expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], |
| 553 | + func: Callable, |
| 554 | + *args: Any, |
| 555 | + match: Optional[str] = ..., |
| 556 | + **kwargs: Any |
| 557 | +) -> Optional[_pytest._code.ExceptionInfo]: |
| 558 | + ... # pragma: no cover |
| 559 | + |
| 560 | + |
| 561 | +def raises( |
| 562 | + expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], |
| 563 | + *args: Any, |
| 564 | + match: Optional[Union[str, Pattern]] = None, |
| 565 | + **kwargs: Any |
| 566 | +) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]: |
532 | 567 | r"""
|
533 | 568 | Assert that a code block/function call raises ``expected_exception``
|
534 | 569 | or raise a failure exception otherwise.
|
@@ -676,21 +711,35 @@ def raises(expected_exception, *args, match=None, **kwargs):
|
676 | 711 |
|
677 | 712 |
|
678 | 713 | class RaisesContext:
|
679 |
| - def __init__(self, expected_exception, message, match_expr): |
| 714 | + def __init__( |
| 715 | + self, |
| 716 | + expected_exception: Union[ |
| 717 | + "Type[BaseException]", Tuple["Type[BaseException]", ...] |
| 718 | + ], |
| 719 | + message: str, |
| 720 | + match_expr: Optional[Union[str, Pattern]] = None, |
| 721 | + ) -> None: |
680 | 722 | self.expected_exception = expected_exception
|
681 | 723 | self.message = message
|
682 | 724 | self.match_expr = match_expr
|
683 |
| - self.excinfo = None |
| 725 | + self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo] |
684 | 726 |
|
685 |
| - def __enter__(self): |
| 727 | + def __enter__(self) -> _pytest._code.ExceptionInfo: |
686 | 728 | self.excinfo = _pytest._code.ExceptionInfo.for_later()
|
687 | 729 | return self.excinfo
|
688 | 730 |
|
689 |
| - def __exit__(self, *tp): |
| 731 | + def __exit__( |
| 732 | + self, |
| 733 | + exc_type: Optional["Type[BaseException]"], |
| 734 | + exc_val: Optional[BaseException], |
| 735 | + exc_tb: Optional[TracebackType], |
| 736 | + ) -> bool: |
690 | 737 | __tracebackhide__ = True
|
691 |
| - if tp[0] is None: |
| 738 | + if exc_type is None: |
692 | 739 | fail(self.message)
|
693 |
| - self.excinfo.__init__(tp) |
| 740 | + assert self.excinfo is not None |
| 741 | + # Type ignored because mypy doesn't like calling __init__ directly like this. |
| 742 | + self.excinfo.__init__((exc_type, exc_val, exc_tb)) # type: ignore |
694 | 743 | suppress_exception = issubclass(self.excinfo.type, self.expected_exception)
|
695 | 744 | if self.match_expr is not None and suppress_exception:
|
696 | 745 | self.excinfo.match(self.match_expr)
|
|
0 commit comments