2323import types
2424import unittest
2525from collections .abc import Callable , Iterator
26- from typing import TYPE_CHECKING , NoReturn , ParamSpec , TypeVar , cast , overload
26+ from typing import TYPE_CHECKING , Generic , NoReturn , ParamSpec , TypeVar , cast , overload
2727from unittest .case import SkipTest
2828
2929T = TypeVar ("T" )
3030U = TypeVar ("U" )
31+ _E = TypeVar ("_E" , bound = BaseException )
32+ _E2 = TypeVar ("_E2" , bound = BaseException )
33+ _E3 = TypeVar ("_E3" , bound = BaseException )
3134
3235# ruff: noqa: E402 - TypeVars must be defined before importing testtools modules
3336from testtools import content
@@ -513,17 +516,38 @@ def assertRaises(
513516 @overload # type: ignore[override]
514517 def assertRaises (
515518 self ,
516- expected_exception : type [BaseException ] | tuple [ type [ BaseException ] ],
519+ expected_exception : type [_E ],
517520 callable : None = ...,
518- ) -> "_AssertRaisesContext" : ...
521+ ) -> "_AssertRaisesContext[_E] " : ...
519522
520- def assertRaises ( # type: ignore[override]
523+ @overload # type: ignore[override]
524+ def assertRaises (
521525 self ,
522- expected_exception : type [BaseException ] | tuple [type [BaseException ]],
526+ expected_exception : tuple [type [_E ], type [_E2 ]],
527+ callable : None = ...,
528+ ) -> "_AssertRaisesContext[_E | _E2]" : ...
529+
530+ @overload # type: ignore[override]
531+ def assertRaises (
532+ self ,
533+ expected_exception : tuple [type [_E ], type [_E2 ], type [_E3 ]],
534+ callable : None = ...,
535+ ) -> "_AssertRaisesContext[_E | _E2 | _E3]" : ...
536+
537+ @overload # type: ignore[override]
538+ def assertRaises (
539+ self ,
540+ expected_exception : tuple [type [BaseException ], ...],
541+ callable : None = ...,
542+ ) -> "_AssertRaisesContext[BaseException]" : ...
543+
544+ def assertRaises ( # type: ignore[override, misc]
545+ self ,
546+ expected_exception : type [BaseException ] | tuple [type [BaseException ], ...],
523547 callable : Callable [_P , _R ] | None = None ,
524548 * args : _P .args ,
525549 ** kwargs : _P .kwargs ,
526- ) -> "_AssertRaisesContext | BaseException" :
550+ ) -> "_AssertRaisesContext[BaseException] | BaseException" :
527551 """Fail unless an exception of class expected_exception is thrown
528552 by callable when invoked with arguments args and keyword
529553 arguments kwargs. If a different type of exception is
@@ -1185,7 +1209,7 @@ def decorator(test_item: _F) -> _F:
11851209 if not isinstance (test_item , class_types ):
11861210
11871211 @functools .wraps (test_item )
1188- def skip_wrapper (* args : object , ** kwargs : object ) -> None :
1212+ def skip_wrapper (* args : object , ** kwargs : object ) -> NoReturn :
11891213 raise TestCase .skipException (reason )
11901214
11911215 test_item = cast (_F , skip_wrapper )
@@ -1223,15 +1247,15 @@ def _id(obj: _F) -> _F:
12231247 return _id
12241248
12251249
1226- class _AssertRaisesContext :
1250+ class _AssertRaisesContext ( Generic [ _E ]) :
12271251 """A context manager to handle expected exceptions for assertRaises.
12281252
12291253 This provides compatibility with unittest's assertRaises context manager.
12301254 """
12311255
12321256 def __init__ (
12331257 self ,
1234- expected : type [BaseException ] | tuple [type [BaseException ]],
1258+ expected : type [_E ] | tuple [type [BaseException ], ... ],
12351259 test_case : TestCase ,
12361260 msg : str | None = None ,
12371261 ) -> None :
@@ -1244,7 +1268,7 @@ def __init__(
12441268 self .expected = expected
12451269 self .test_case = test_case
12461270 self .msg = msg
1247- self .exception : BaseException | None = None
1271+ self .exception : _E | None = None
12481272
12491273 def __enter__ (self ) -> "Self" :
12501274 return self
@@ -1274,7 +1298,7 @@ def __exit__(
12741298 # let unexpected exceptions pass through
12751299 return False
12761300 # store exception for later retrieval
1277- self .exception = exc_value
1301+ self .exception = cast ( _E , exc_value )
12781302 return True
12791303
12801304
@@ -1341,7 +1365,7 @@ def __exit__(
13411365 return True
13421366
13431367
1344- class Nullary :
1368+ class Nullary ( Generic [ _R ]) :
13451369 """Turn a callable into a nullary callable.
13461370
13471371 The advantage of this over ``lambda: f(*args, **kwargs)`` is that it
@@ -1358,7 +1382,7 @@ def __init__(
13581382 self ._args = args
13591383 self ._kwargs = kwargs
13601384
1361- def __call__ (self ) -> object :
1385+ def __call__ (self ) -> _R :
13621386 return self ._callable_object (* self ._args , ** self ._kwargs )
13631387
13641388 def __repr__ (self ) -> str :
0 commit comments