1212from numbers import Complex
1313import pprint
1414import re
15+ import sys
1516from types import TracebackType
1617from typing import Any
1718from typing import cast
1819from typing import final
20+ from typing import get_args
21+ from typing import get_origin
1922from typing import overload
2023from typing import TYPE_CHECKING
2124from typing import TypeVar
2427from _pytest .outcomes import fail
2528
2629
30+ if sys .version_info < (3 , 11 ):
31+ from exceptiongroup import BaseExceptionGroup
32+ from exceptiongroup import ExceptionGroup
33+
2734if TYPE_CHECKING :
2835 from numpy import ndarray
2936
@@ -954,15 +961,45 @@ def raises(
954961 f"Raising exceptions is already understood as failing the test, so you don't need "
955962 f"any special code to say 'this should never raise an exception'."
956963 )
964+
965+ expected_exceptions : tuple [type [E ], ...]
966+ origin_exc : type [E ] | None = get_origin (expected_exception )
957967 if isinstance (expected_exception , type ):
958- expected_exceptions : tuple [type [E ], ...] = (expected_exception ,)
968+ expected_exceptions = (expected_exception ,)
969+ elif origin_exc and issubclass (origin_exc , BaseExceptionGroup ):
970+ expected_exceptions = (cast (type [E ], expected_exception ),)
959971 else :
960972 expected_exceptions = expected_exception
961- for exc in expected_exceptions :
962- if not isinstance (exc , type ) or not issubclass (exc , BaseException ):
973+
974+ def validate_exc (exc : type [E ]) -> type [E ]:
975+ __tracebackhide__ = True
976+ origin_exc : type [E ] | None = get_origin (exc )
977+ if origin_exc and issubclass (origin_exc , BaseExceptionGroup ):
978+ exc_type = get_args (exc )[0 ]
979+ if (
980+ issubclass (origin_exc , ExceptionGroup ) and exc_type in (Exception , Any )
981+ ) or (
982+ issubclass (origin_exc , BaseExceptionGroup )
983+ and exc_type in (BaseException , Any )
984+ ):
985+ return cast (type [E ], origin_exc )
986+ else :
987+ raise ValueError (
988+ f"Only `ExceptionGroup[Exception]` or `BaseExceptionGroup[BaseExeption]` "
989+ f"are accepted as generic types but got `{ exc } `. "
990+ f"As `raises` will catch all instances of the specified group regardless of the "
991+ f"generic argument specific nested exceptions has to be checked "
992+ f"with `ExceptionInfo.group_contains()`"
993+ )
994+
995+ elif not isinstance (exc , type ) or not issubclass (exc , BaseException ):
963996 msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
964997 not_a = exc .__name__ if isinstance (exc , type ) else type (exc ).__name__
965998 raise TypeError (msg .format (not_a ))
999+ else :
1000+ return exc
1001+
1002+ expected_exceptions = tuple (validate_exc (exc ) for exc in expected_exceptions )
9661003
9671004 message = f"DID NOT RAISE { expected_exception } "
9681005
@@ -973,14 +1010,14 @@ def raises(
9731010 msg += ", " .join (sorted (kwargs ))
9741011 msg += "\n Use context-manager form instead?"
9751012 raise TypeError (msg )
976- return RaisesContext (expected_exception , message , match )
1013+ return RaisesContext (expected_exceptions , message , match )
9771014 else :
9781015 func = args [0 ]
9791016 if not callable (func ):
9801017 raise TypeError (f"{ func !r} object (type: { type (func )} ) must be callable" )
9811018 try :
9821019 func (* args [1 :], ** kwargs )
983- except expected_exception as e :
1020+ except expected_exceptions as e :
9841021 return _pytest ._code .ExceptionInfo .from_exception (e )
9851022 fail (message )
9861023
0 commit comments