Skip to content

Commit 11f1f79

Browse files
committed
Allow creating ExceptionInfo from existing exc_info for better typing
This way the ExceptionInfo generic parameter can be inferred from the passed-in exc_info. See for example the replaced cast().
1 parent 3f1fb62 commit 11f1f79

File tree

3 files changed

+46
-19
lines changed

3 files changed

+46
-19
lines changed

src/_pytest/_code/code.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -397,10 +397,12 @@ class ExceptionInfo(Generic[_E]):
397397
_traceback = attr.ib(type=Optional[Traceback], default=None)
398398

399399
@classmethod
400-
def from_current(
401-
cls, exprinfo: Optional[str] = None
402-
) -> "ExceptionInfo[BaseException]":
403-
"""returns an ExceptionInfo matching the current traceback
400+
def from_exc_info(
401+
cls,
402+
exc_info: Tuple["Type[_E]", "_E", TracebackType],
403+
exprinfo: Optional[str] = None,
404+
) -> "ExceptionInfo[_E]":
405+
"""returns an ExceptionInfo for an existing exc_info tuple.
404406
405407
.. warning::
406408
@@ -411,20 +413,37 @@ def from_current(
411413
strip ``AssertionError`` from the output, defaults
412414
to the exception message/``__str__()``
413415
"""
414-
tup_ = sys.exc_info()
415-
assert tup_[0] is not None, "no current exception"
416-
assert tup_[1] is not None, "no current exception"
417-
assert tup_[2] is not None, "no current exception"
418-
tup = (tup_[0], tup_[1], tup_[2])
419416
_striptext = ""
420-
if exprinfo is None and isinstance(tup[1], AssertionError):
421-
exprinfo = getattr(tup[1], "msg", None)
417+
if exprinfo is None and isinstance(exc_info[1], AssertionError):
418+
exprinfo = getattr(exc_info[1], "msg", None)
422419
if exprinfo is None:
423-
exprinfo = saferepr(tup[1])
420+
exprinfo = saferepr(exc_info[1])
424421
if exprinfo and exprinfo.startswith(cls._assert_start_repr):
425422
_striptext = "AssertionError: "
426423

427-
return cls(tup, _striptext)
424+
return cls(exc_info, _striptext)
425+
426+
@classmethod
427+
def from_current(
428+
cls, exprinfo: Optional[str] = None
429+
) -> "ExceptionInfo[BaseException]":
430+
"""returns an ExceptionInfo matching the current traceback
431+
432+
.. warning::
433+
434+
Experimental API
435+
436+
437+
:param exprinfo: a text string helping to determine if we should
438+
strip ``AssertionError`` from the output, defaults
439+
to the exception message/``__str__()``
440+
"""
441+
tup = sys.exc_info()
442+
assert tup[0] is not None, "no current exception"
443+
assert tup[1] is not None, "no current exception"
444+
assert tup[2] is not None, "no current exception"
445+
exc_info = (tup[0], tup[1], tup[2])
446+
return cls.from_exc_info(exc_info)
428447

429448
@classmethod
430449
def for_later(cls) -> "ExceptionInfo[_E]":

src/_pytest/python_api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -707,11 +707,11 @@ def raises(
707707
)
708708
try:
709709
func(*args[1:], **kwargs)
710-
except expected_exception:
711-
# Cast to narrow the type to expected_exception (_E).
712-
return cast(
713-
_pytest._code.ExceptionInfo[_E],
714-
_pytest._code.ExceptionInfo.from_current(),
710+
except expected_exception as e:
711+
# We just caught the exception - there is a traceback.
712+
assert e.__traceback__ is not None
713+
return _pytest._code.ExceptionInfo.from_exc_info(
714+
(type(e), e, e.__traceback__)
715715
)
716716
fail(message)
717717

testing/code/test_excinfo.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,22 @@ def get_write_msg(self, idx):
5858
fullwidth = 80
5959

6060

61-
def test_excinfo_simple():
61+
def test_excinfo_simple() -> None:
6262
try:
6363
raise ValueError
6464
except ValueError:
6565
info = _pytest._code.ExceptionInfo.from_current()
6666
assert info.type == ValueError
6767

6868

69+
def test_excinfo_from_exc_info_simple():
70+
try:
71+
raise ValueError
72+
except ValueError as e:
73+
info = _pytest._code.ExceptionInfo.from_exc_info((type(e), e, e.__traceback__))
74+
assert info.type == ValueError
75+
76+
6977
def test_excinfo_getstatement():
7078
def g():
7179
raise ValueError

0 commit comments

Comments
 (0)