@@ -3802,25 +3802,44 @@ def test_traceback_header(self):
38023802 exc = traceback .TracebackException (Exception , Exception ("haven" ), None )
38033803 self .assertEqual (list (exc .format ()), ["Exception: haven\n " ])
38043804
3805- def test_exception_punctuation_handling_with_suggestions (self ):
3805+ def test_name_error_punctuation_with_suggestions (self ):
38063806 def raise_mssage (message , name ):
38073807 try :
38083808 raise NameError (message , name = name )
3809- except Exception as e :
3810- return traceback .TracebackException .from_exception (e )._str
3809+ except NameError as e :
3810+ exc = traceback .TracebackException .from_exception (e )
3811+ return list (exc .format ())[- 1 ]
38113812
38123813 test_cases = [
3813- ("Error ." , "time" , "Error . Did you forget to import 'time'?" ),
3814- ("Error ?" , "time" , "Error ? Did you forget to import 'time'?" ),
3815- ("Error !" , "time" , "Error ! Did you forget to import 'time'?" ),
3816- ("Error " , "time" , "Error . Did you forget to import 'time'?" ),
3817- ("Error " , "foo123" , "Error " ),
3814+ ("a ." , "time" , "NameError: a . Did you forget to import 'time'?\n " ),
3815+ ("b ?" , "time" , "NameError: b ? Did you forget to import 'time'?\n " ),
3816+ ("c !" , "time" , "NameError: c ! Did you forget to import 'time'?\n " ),
3817+ ("d " , "time" , "NameError: d . Did you forget to import 'time'?\n " ),
3818+ ("e " , "foo123" , "NameError: e \n " ),
38183819 ]
3819- for puctuation , name , expected in test_cases :
3820- with self .subTest (puctuation = puctuation ):
3821- messsage = raise_mssage (puctuation , name )
3820+ for message , name , expected in test_cases :
3821+ with self .subTest (message = message ):
3822+ messsage = raise_mssage (message , name )
38223823 self .assertEqual (messsage , expected )
38233824
3825+ def test_import_error_punctuation_handling_with_suggestions (self ):
3826+ def raise_mssage (message ):
3827+ try :
3828+ raise ImportError (message , name = "math" , name_from = "sinq" )
3829+ except ImportError as e :
3830+ exc = traceback .TracebackException .from_exception (e )
3831+ return list (exc .format ())[- 1 ]
3832+
3833+ test_cases = [
3834+ ("a." , "ImportError: a. Did you mean: 'sin'?\n " ),
3835+ ("b?" , "ImportError: b? Did you mean: 'sin'?\n " ),
3836+ ("c!" , "ImportError: c! Did you mean: 'sin'?\n " ),
3837+ ("d" , "ImportError: d. Did you mean: 'sin'?\n " ),
3838+ ]
3839+ for message , expected in test_cases :
3840+ with self .subTest (message = message ):
3841+ messsage = raise_mssage (message )
3842+ self .assertEqual (messsage , expected )
38243843
38253844 @requires_debug_ranges ()
38263845 def test_print (self ):
0 commit comments