1212from numbers import Complex
1313import pprint
1414import re
15+ import sys
1516from types import TracebackType
1617from typing import Any
1718from typing import cast
1819from typing import final
20+ from typing import get_args
21+ from typing import get_origin
1922from typing import overload
2023from typing import TYPE_CHECKING
2124from typing import TypeVar
2427from _pytest .outcomes import fail
2528
2629
30+ if sys .version_info < (3 , 11 ):
31+ from exceptiongroup import BaseExceptionGroup
32+ from exceptiongroup import ExceptionGroup
33+
2734if TYPE_CHECKING :
2835 from numpy import ndarray
2936
@@ -954,15 +961,43 @@ def raises(
954961 f"Raising exceptions is already understood as failing the test, so you don't need "
955962 f"any special code to say 'this should never raise an exception'."
956963 )
964+
965+ expected_exceptions : tuple [type [E ], ...]
966+ origin_exc : type [E ] | None = get_origin (expected_exception )
957967 if isinstance (expected_exception , type ):
958- expected_exceptions : tuple [type [E ], ...] = (expected_exception ,)
968+ expected_exceptions = (expected_exception ,)
969+ elif origin_exc and issubclass (origin_exc , BaseExceptionGroup ):
970+ expected_exceptions = (cast (type [E ], expected_exception ),)
959971 else :
960972 expected_exceptions = expected_exception
961- for exc in expected_exceptions :
962- if not isinstance (exc , type ) or not issubclass (exc , BaseException ):
973+
974+ def validate_exc (exc : type [E ]) -> type [E ]:
975+ origin_exc : type [E ] | None = get_origin (exc )
976+ if origin_exc and issubclass (origin_exc , BaseExceptionGroup ):
977+ exc_type = get_args (exc )[0 ]
978+ if issubclass (origin_exc , ExceptionGroup ) and exc_type is Exception :
979+ return cast (type [E ], origin_exc )
980+ elif (
981+ issubclass (origin_exc , BaseExceptionGroup ) and exc_type is BaseException
982+ ):
983+ return cast (type [E ], origin_exc )
984+ else :
985+ raise ValueError (
986+ f"Only `ExceptionGroup[Exception]` or `BaseExceptionGroup[BaseExeption]` "
987+ f"are accepted as generic types but got `{ exc } `. "
988+ f"As `raises` will catch all instances of the specified group regardless of the "
989+ f"generic argument specific nested exceptions has to be checked "
990+ f"with `ExceptionInfo.group_contains()`"
991+ )
992+
993+ elif not isinstance (exc , type ) or not issubclass (exc , BaseException ):
963994 msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
964995 not_a = exc .__name__ if isinstance (exc , type ) else type (exc ).__name__
965996 raise TypeError (msg .format (not_a ))
997+ else :
998+ return exc
999+
1000+ expected_exceptions = tuple (validate_exc (exc ) for exc in expected_exceptions )
9661001
9671002 message = f"DID NOT RAISE { expected_exception } "
9681003
@@ -973,14 +1008,14 @@ def raises(
9731008 msg += ", " .join (sorted (kwargs ))
9741009 msg += "\n Use context-manager form instead?"
9751010 raise TypeError (msg )
976- return RaisesContext (expected_exception , message , match )
1011+ return RaisesContext (expected_exceptions , message , match )
9771012 else :
9781013 func = args [0 ]
9791014 if not callable (func ):
9801015 raise TypeError (f"{ func !r} object (type: { type (func )} ) must be callable" )
9811016 try :
9821017 func (* args [1 :], ** kwargs )
983- except expected_exception as e :
1018+ except expected_exceptions as e :
9841019 return _pytest ._code .ExceptionInfo .from_exception (e )
9851020 fail (message )
9861021
0 commit comments