Skip to content

Commit e6ec2fc

Browse files
authored
Merge pull request testing-cabal#587 from stephenfin/typing
More type improvements
2 parents a2c63cc + ddf9764 commit e6ec2fc

4 files changed

Lines changed: 48 additions & 19 deletions

File tree

testtools/matchers/_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import operator
2020
import re
21-
from collections.abc import Callable
21+
from collections.abc import Callable, Sized
2222
from pprint import pformat
2323
from typing import Any, Generic, TypeVar
2424

@@ -459,7 +459,7 @@ def match(self, value: str) -> Mismatch | None:
459459
return None
460460

461461

462-
def has_len(x: Any, y: int) -> bool:
462+
def has_len(x: Sized, y: int) -> bool:
463463
return len(x) == y
464464

465465

testtools/monkey.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
"""Helpers for monkey-patching Python code."""
44

55
from collections.abc import Callable
6-
from typing import Any
6+
from typing import ParamSpec, TypeVar
7+
8+
_P = ParamSpec("_P")
9+
_R = TypeVar("_R")
710

811
__all__ = [
912
"MonkeyPatcher",
@@ -71,7 +74,9 @@ def restore(self) -> None:
7174
else:
7275
setattr(obj, name, value)
7376

74-
def run_with_patches(self, f: Callable[..., Any], *args: Any, **kw: Any) -> Any:
77+
def run_with_patches(
78+
self, f: Callable[_P, _R], *args: _P.args, **kw: _P.kwargs
79+
) -> _R:
7580
"""Run 'f' with the given args and kwargs with all patches applied.
7681
7782
Restores all objects to their original state when finished.

testtools/runtest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import sys
1111
from collections.abc import Callable
12-
from typing import TYPE_CHECKING, Any
12+
from typing import TYPE_CHECKING, Any, NoReturn
1313

1414
from testtools.testresult import (
1515
ExcInfo,
@@ -252,7 +252,7 @@ def _got_user_exception(
252252
return self.exception_caught
253253

254254

255-
def _raise_force_fail_error() -> None:
255+
def _raise_force_fail_error() -> NoReturn:
256256
raise AssertionError("Forced Test Failure")
257257

258258

testtools/testcase.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@
2323
import types
2424
import unittest
2525
from collections.abc import Callable, Iterator
26-
from typing import TYPE_CHECKING, NoReturn, ParamSpec, TypeVar, cast, overload
26+
from typing import TYPE_CHECKING, Generic, NoReturn, ParamSpec, TypeVar, cast, overload
2727
from unittest.case import SkipTest
2828

2929
T = TypeVar("T")
3030
U = TypeVar("U")
31+
_E = TypeVar("_E", bound=BaseException)
32+
_E2 = TypeVar("_E2", bound=BaseException)
33+
_E3 = TypeVar("_E3", bound=BaseException)
3134

3235
# ruff: noqa: E402 - TypeVars must be defined before importing testtools modules
3336
from testtools import content
@@ -513,17 +516,38 @@ def assertRaises(
513516
@overload # type: ignore[override]
514517
def assertRaises(
515518
self,
516-
expected_exception: type[BaseException] | tuple[type[BaseException]],
519+
expected_exception: type[_E],
517520
callable: None = ...,
518-
) -> "_AssertRaisesContext": ...
521+
) -> "_AssertRaisesContext[_E]": ...
519522

520-
def assertRaises( # type: ignore[override]
523+
@overload # type: ignore[override]
524+
def assertRaises(
521525
self,
522-
expected_exception: type[BaseException] | tuple[type[BaseException]],
526+
expected_exception: tuple[type[_E], type[_E2]],
527+
callable: None = ...,
528+
) -> "_AssertRaisesContext[_E | _E2]": ...
529+
530+
@overload # type: ignore[override]
531+
def assertRaises(
532+
self,
533+
expected_exception: tuple[type[_E], type[_E2], type[_E3]],
534+
callable: None = ...,
535+
) -> "_AssertRaisesContext[_E | _E2 | _E3]": ...
536+
537+
@overload # type: ignore[override]
538+
def assertRaises(
539+
self,
540+
expected_exception: tuple[type[BaseException], ...],
541+
callable: None = ...,
542+
) -> "_AssertRaisesContext[BaseException]": ...
543+
544+
def assertRaises( # type: ignore[override, misc]
545+
self,
546+
expected_exception: type[BaseException] | tuple[type[BaseException], ...],
523547
callable: Callable[_P, _R] | None = None,
524548
*args: _P.args,
525549
**kwargs: _P.kwargs,
526-
) -> "_AssertRaisesContext | BaseException":
550+
) -> "_AssertRaisesContext[BaseException] | BaseException":
527551
"""Fail unless an exception of class expected_exception is thrown
528552
by callable when invoked with arguments args and keyword
529553
arguments kwargs. If a different type of exception is
@@ -1185,7 +1209,7 @@ def decorator(test_item: _F) -> _F:
11851209
if not isinstance(test_item, class_types):
11861210

11871211
@functools.wraps(test_item)
1188-
def skip_wrapper(*args: object, **kwargs: object) -> None:
1212+
def skip_wrapper(*args: object, **kwargs: object) -> NoReturn:
11891213
raise TestCase.skipException(reason)
11901214

11911215
test_item = cast(_F, skip_wrapper)
@@ -1223,15 +1247,15 @@ def _id(obj: _F) -> _F:
12231247
return _id
12241248

12251249

1226-
class _AssertRaisesContext:
1250+
class _AssertRaisesContext(Generic[_E]):
12271251
"""A context manager to handle expected exceptions for assertRaises.
12281252
12291253
This provides compatibility with unittest's assertRaises context manager.
12301254
"""
12311255

12321256
def __init__(
12331257
self,
1234-
expected: type[BaseException] | tuple[type[BaseException]],
1258+
expected: type[_E] | tuple[type[BaseException], ...],
12351259
test_case: TestCase,
12361260
msg: str | None = None,
12371261
) -> None:
@@ -1244,7 +1268,7 @@ def __init__(
12441268
self.expected = expected
12451269
self.test_case = test_case
12461270
self.msg = msg
1247-
self.exception: BaseException | None = None
1271+
self.exception: _E | None = None
12481272

12491273
def __enter__(self) -> "Self":
12501274
return self
@@ -1274,7 +1298,7 @@ def __exit__(
12741298
# let unexpected exceptions pass through
12751299
return False
12761300
# store exception for later retrieval
1277-
self.exception = exc_value
1301+
self.exception = cast(_E, exc_value)
12781302
return True
12791303

12801304

@@ -1341,7 +1365,7 @@ def __exit__(
13411365
return True
13421366

13431367

1344-
class Nullary:
1368+
class Nullary(Generic[_R]):
13451369
"""Turn a callable into a nullary callable.
13461370
13471371
The advantage of this over ``lambda: f(*args, **kwargs)`` is that it
@@ -1358,7 +1382,7 @@ def __init__(
13581382
self._args = args
13591383
self._kwargs = kwargs
13601384

1361-
def __call__(self) -> object:
1385+
def __call__(self) -> _R:
13621386
return self._callable_object(*self._args, **self._kwargs)
13631387

13641388
def __repr__(self) -> str:

0 commit comments

Comments
 (0)