2020import unittest
2121import warnings
2222from collections .abc import Callable , Generator
23- from contextlib import AbstractContextManager
2423from logging import Logger
2524from pstats import Stats
2625from types import FrameType
27- from typing import Any , TypeVar , Union
26+ from typing import Any , cast , TypeVar , Union
2827
2928import numpy as np
3029import torch
@@ -60,9 +59,7 @@ def _get_tb_lines(tb: types.TracebackType) -> list[tuple[str, int, str]]:
6059 return res
6160
6261
63- # pyre-fixme[24]: Generic type `unittest.case._AssertRaisesContext` expects 1 type
64- # parameter.
65- class _AssertRaisesContextOn (unittest .case ._AssertRaisesContext ):
62+ class _AssertRaisesContextOn (unittest .case ._AssertRaisesContext [Exception ]):
6663 """
6764 Attributes:
6865 lineno: the line number on which the error occurred
@@ -89,12 +86,10 @@ def __init__(
8986 expected = expected , test_case = test_case , expected_regex = expected_regex
9087 )
9188
92- # pyre-fixme[14]: `__exit__` overrides method defined in `_AssertRaisesContext`
93- # inconsistently.
9489 def __exit__ (
9590 self ,
96- exc_type : type [Exception ] | None ,
97- exc_value : Exception | None ,
91+ exc_type : type [BaseException ] | None ,
92+ exc_value : BaseException | None ,
9893 tb : types .TracebackType | None ,
9994 ) -> bool :
10095 """This is called when the context closes. If an exception was raised
@@ -110,10 +105,8 @@ def __exit__(
110105 self .filename , self .lineno , _ = frames [0 ]
111106 lines = [line for _ , _ , line in frames ]
112107 if self ._expected_line is not None and self ._expected_line not in lines :
113- # pyre-ignore [16]: ... has no attribute `_raiseFailure`.
114- self ._raiseFailure (
115- f"{ self ._expected_line !r} was not found in the traceback: { lines !r} "
116- )
108+ msg = f"{ self ._expected_line !r} was not found in the traceback: { lines !r} "
109+ raise self .test_case .failureException (msg )
117110
118111 return True
119112
@@ -420,12 +413,10 @@ def assertRaisesOn(
420413 exc : type [Exception ],
421414 line : str | None = None ,
422415 regex : str | None = None ,
423- # pyre-ignore[24]: Generic type `AbstractContextManager`
424- # expects 2 type parameters, received 1.
425- ) -> AbstractContextManager [None ]:
416+ ) -> _AssertRaisesContextOn :
426417 """Assert that an exception is raised on a specific line."""
427418 context = _AssertRaisesContextOn (exc , self , line , regex )
428- return context .handle ("assertRaisesOn" , [], {})
419+ return cast ( _AssertRaisesContextOn , context .handle ("assertRaisesOn" , [], {}) )
429420
430421 def assertDictsAlmostEqual (
431422 self , a : dict [str , Any ], b : dict [str , Any ], consider_nans_equal : bool = False
@@ -532,30 +523,21 @@ def ax_long_test(cls, reason: str | None) -> Generator[None, None, None]:
532523 cls ._long_test_active_reason = None
533524
534525 # This list is taken from the python standard library
535- # pyre-fixme[4]: Attribute must be annotated.
536- failUnlessEqual = assertEquals = _deprecate (unittest .TestCase .assertEqual )
537- # pyre-fixme[4]: Attribute must be annotated.
538- failIfEqual = assertNotEquals = _deprecate (unittest .TestCase .assertNotEqual )
539- # pyre-fixme[4]: Attribute must be annotated.
540- failUnlessAlmostEqual = assertAlmostEquals = _deprecate (
541- unittest .TestCase .assertAlmostEqual
542- )
543- # pyre-fixme[4]: Attribute must be annotated.
544- failIfAlmostEqual = assertNotAlmostEquals = _deprecate (
545- unittest .TestCase .assertNotAlmostEqual
546- )
547- # pyre-fixme[4]: Attribute must be annotated.
548- failUnless = assert_ = _deprecate (unittest .TestCase .assertTrue )
549- # pyre-fixme[4]: Attribute must be annotated.
550- failUnlessRaises = _deprecate (unittest .TestCase .assertRaises )
551- # pyre-fixme[4]: Attribute must be annotated.
552- failIf = _deprecate (unittest .TestCase .assertFalse )
553- # pyre-fixme[4]: Attribute must be annotated.
554- assertRaisesRegexp = _deprecate (unittest .TestCase .assertRaisesRegex )
555- # pyre-fixme[4]: Attribute must be annotated.
556- assertRegexpMatches = _deprecate (unittest .TestCase .assertRegex )
557- # pyre-fixme[4]: Attribute must be annotated.
558- assertNotRegexpMatches = _deprecate (unittest .TestCase .assertNotRegex )
526+ failUnlessEqual : Callable = _deprecate (unittest .TestCase .assertEqual )
527+ assertEquals : Callable = failUnlessEqual
528+ failIfEqual : Callable = _deprecate (unittest .TestCase .assertNotEqual )
529+ assertNotEquals : Callable = failIfEqual
530+ failUnlessAlmostEqual : Callable = _deprecate (unittest .TestCase .assertAlmostEqual )
531+ assertAlmostEquals : Callable = failUnlessAlmostEqual
532+ failIfAlmostEqual : Callable = _deprecate (unittest .TestCase .assertNotAlmostEqual )
533+ assertNotAlmostEquals : Callable = failIfAlmostEqual
534+ failUnless : Callable = _deprecate (unittest .TestCase .assertTrue )
535+ assert_ : Callable = failUnless
536+ failUnlessRaises : Callable = _deprecate (unittest .TestCase .assertRaises )
537+ failIf : Callable = _deprecate (unittest .TestCase .assertFalse )
538+ assertRaisesRegexp : Callable = _deprecate (unittest .TestCase .assertRaisesRegex )
539+ assertRegexpMatches : Callable = _deprecate (unittest .TestCase .assertRegex )
540+ assertNotRegexpMatches : Callable = _deprecate (unittest .TestCase .assertNotRegex )
559541
560542 # Copied from BoTorch assertAllClose
561543 def assertAllClose (
0 commit comments