Skip to content

Commit 56dcc9e

Browse files
committed
Type-annotate pytest.raises
1 parent 55a570e commit 56dcc9e

File tree

1 file changed

+56
-7
lines changed

1 file changed

+56
-7
lines changed

src/_pytest/python_api.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77
from decimal import Decimal
88
from itertools import filterfalse
99
from numbers import Number
10+
from types import TracebackType
11+
from typing import Any
12+
from typing import Callable
13+
from typing import Optional
14+
from typing import overload
15+
from typing import Pattern
16+
from typing import Tuple
1017
from typing import Union
1118

1219
from more_itertools.more import always_iterable
@@ -15,6 +22,9 @@
1522
from _pytest.compat import STRING_TYPES
1623
from _pytest.outcomes import fail
1724

25+
if False: # TYPE_CHECKING
26+
from typing import Type # noqa: F401 (used in type string)
27+
1828
BASE_TYPE = (type, STRING_TYPES)
1929

2030

@@ -528,7 +538,32 @@ def _is_numpy_array(obj):
528538
# builtin pytest.raises helper
529539

530540

531-
def raises(expected_exception, *args, match=None, **kwargs):
541+
@overload
542+
def raises(
543+
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
544+
*,
545+
match: Optional[Union[str, Pattern]] = ...
546+
) -> "RaisesContext":
547+
... # pragma: no cover
548+
549+
550+
@overload
551+
def raises(
552+
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
553+
func: Callable,
554+
*args: Any,
555+
match: Optional[str] = ...,
556+
**kwargs: Any
557+
) -> Optional[_pytest._code.ExceptionInfo]:
558+
... # pragma: no cover
559+
560+
561+
def raises(
562+
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
563+
*args: Any,
564+
match: Optional[Union[str, Pattern]] = None,
565+
**kwargs: Any
566+
) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]:
532567
r"""
533568
Assert that a code block/function call raises ``expected_exception``
534569
or raise a failure exception otherwise.
@@ -676,21 +711,35 @@ def raises(expected_exception, *args, match=None, **kwargs):
676711

677712

678713
class RaisesContext:
679-
def __init__(self, expected_exception, message, match_expr):
714+
def __init__(
715+
self,
716+
expected_exception: Union[
717+
"Type[BaseException]", Tuple["Type[BaseException]", ...]
718+
],
719+
message: str,
720+
match_expr: Optional[Union[str, Pattern]] = None,
721+
) -> None:
680722
self.expected_exception = expected_exception
681723
self.message = message
682724
self.match_expr = match_expr
683-
self.excinfo = None
725+
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo]
684726

685-
def __enter__(self):
727+
def __enter__(self) -> _pytest._code.ExceptionInfo:
686728
self.excinfo = _pytest._code.ExceptionInfo.for_later()
687729
return self.excinfo
688730

689-
def __exit__(self, *tp):
731+
def __exit__(
732+
self,
733+
exc_type: Optional["Type[BaseException]"],
734+
exc_val: Optional[BaseException],
735+
exc_tb: Optional[TracebackType],
736+
) -> bool:
690737
__tracebackhide__ = True
691-
if tp[0] is None:
738+
if exc_type is None:
692739
fail(self.message)
693-
self.excinfo.__init__(tp)
740+
assert self.excinfo is not None
741+
# Type ignored because mypy doesn't like calling __init__ directly like this.
742+
self.excinfo.__init__((exc_type, exc_val, exc_tb)) # type: ignore
694743
suppress_exception = issubclass(self.excinfo.type, self.expected_exception)
695744
if self.match_expr is not None and suppress_exception:
696745
self.excinfo.match(self.match_expr)

0 commit comments

Comments
 (0)