@@ -1202,7 +1202,81 @@ def gen_lock_key():
12021202 return os .getpid () * 1000 + _lock_cnt
12031203
12041204
1205- def _typecheck (func , * , xfail = False ):
1205+ def _typecheck (cls , func , * , xfail = False ):
1206+ wrapped = inspect .unwrap (func )
1207+ is_async = inspect .iscoroutinefunction (wrapped )
1208+
1209+ def run (self ):
1210+ # We already ran the typechecker on everything, so now inspect
1211+ # the results for this test.
1212+
1213+ mypy_output = cls ._mypy_errors .get (func .__name__ , '' )
1214+ mypy_error = 'error:' in mypy_output
1215+
1216+ pyright_output = cls ._pyright_errors .get (func .__name__ , '' )
1217+ pyright_error = '- error:' in pyright_output
1218+
1219+ if mypy_error or pyright_error :
1220+ source_code = _get_file_code (func )
1221+ lines = source_code .split ("\n " )
1222+ pad_width = max (2 , len (str (len (lines ))))
1223+ source_code_numbered = "\n " .join (
1224+ f"{ i + 1 :0{pad_width }d} : { line } "
1225+ for i , line in enumerate (lines )
1226+ )
1227+
1228+ if mypy_error :
1229+ raise RuntimeError (
1230+ f"mypy check failed for { func .__name__ } "
1231+ f"\n \n test code:\n { source_code_numbered } "
1232+ f"\n \n mypy stdout:\n { mypy_output } "
1233+ )
1234+
1235+ if pyright_error :
1236+ raise RuntimeError (
1237+ f"pyright check failed for { func .__name__ } "
1238+ f"\n \n test code:\n { source_code_numbered } "
1239+ f"\n \n pyright stdout:\n { pyright_output } "
1240+ )
1241+
1242+ types = []
1243+ for line in mypy_output .split ("\n " ):
1244+ if m := re .match (r'.*Revealed type is "(?P<name>[^"]+)".*' , line ):
1245+ types .append (m .group ("name" ))
1246+
1247+ def reveal_type (_ , * , ncalls = [0 ], types = types ): # noqa: B006
1248+ ncalls [0 ] += 1
1249+ try :
1250+ return types [ncalls [0 ] - 1 ]
1251+ except IndexError :
1252+ return None
1253+
1254+ if is_async :
1255+ # `func` is the result of `TestCaseMeta.wrap()`, so we
1256+ # want to use the coroutine function that was there in
1257+ # the beginning, so let's use `wrapped`
1258+ wrapped .__globals__ ["reveal_type" ] = reveal_type
1259+ return wrapped (self )
1260+ else :
1261+ func .__globals__ ["reveal_type" ] = reveal_type
1262+ return func (self )
1263+
1264+ if is_async :
1265+
1266+ @functools .wraps (wrapped )
1267+ async def runner (self , run = run ):
1268+ coro = run (self )
1269+ await coro
1270+
1271+ rewrapped = TestCaseMeta .wrap (runner )
1272+ return unittest .expectedFailure (rewrapped ) if xfail else rewrapped
1273+
1274+ else :
1275+ run = functools .wraps (func )(run )
1276+ return unittest .expectedFailure (run ) if xfail else run
1277+
1278+
1279+ def _get_file_code (func ):
12061280 wrapped = inspect .unwrap (func )
12071281 is_async = inspect .iscoroutinefunction (wrapped )
12081282
@@ -1246,14 +1320,35 @@ class TestModel(_testbase.{base_class_name}):
12461320{ textwrap .indent (dedented_body , " " * 2 )}
12471321 """
12481322
1249- def run (self ):
1250- d = type (self ).tmp_model_dir .name
1323+ return source_code
1324+
1325+
1326+ def _typecheck_class (cls , funcs ):
1327+ """Extract all the typecheckable functions from a class and typecheck.
1328+
1329+ Run both mypy and pyright, then stash the results where the
1330+ individual functions will deal with them.
1331+ """
1332+
1333+ contents = [(func .__name__ , _get_file_code (func )) for func in funcs ]
1334+ cls ._mypy_errors = {}
1335+ cls ._pyright_errors = {}
1336+
1337+ orig_setUpClass = cls .setUpClass
1338+
1339+ def _setUp (cls ):
1340+ orig_setUpClass ()
1341+
1342+ d = cls .tmp_model_dir .name
12511343
1252- testfn = pathlib .Path (d ) / "test.py"
12531344 inifn = pathlib .Path (d ) / "mypy.ini"
1345+ tdir = pathlib .Path (d ) / "tests"
1346+ os .mkdir (tdir )
12541347
1255- with open (testfn , "wt" ) as f :
1256- f .write (source_code )
1348+ for name , code in contents :
1349+ testfn = tdir / (name + ".py" )
1350+ with open (testfn , "wt" ) as f :
1351+ f .write (code )
12571352
12581353 with open (inifn , "wt" ) as f :
12591354 f .write (
@@ -1283,71 +1378,63 @@ def run(self):
12831378 inifn ,
12841379 "--cache-dir" ,
12851380 str (pathlib .Path (__file__ ).parent .parent / ".mypy_cache" ),
1286- testfn ,
1381+ tdir ,
12871382 ]
1288-
12891383 res = subprocess .run (
12901384 cmd ,
12911385 capture_output = True ,
12921386 check = False ,
12931387 cwd = inifn .parent ,
12941388 )
1295- finally :
1296- inifn .unlink ()
1297- testfn .unlink ()
12981389
1299- if res .returncode != 0 :
1300- lines = source_code .split ("\n " )
1301- pad_width = max (2 , len (str (len (lines ))))
1302- source_code_numbered = "\n " .join (
1303- f"{ i + 1 :0{pad_width }d} : { line } "
1304- for i , line in enumerate (lines )
1305- )
1390+ cmd = [
1391+ sys .executable ,
1392+ "-m" ,
1393+ "pyright" ,
1394+ tdir ,
1395+ ]
13061396
1307- raise RuntimeError (
1308- f"mypy check failed for { func . __name__ } "
1309- f" \n \n test code: \n { source_code_numbered } "
1310- f" \n \n mypy stdout: \n { res . stdout . decode () } "
1311- f" \n \n mypy stderr: \n { res . stderr . decode () } "
1397+ pyright_res = subprocess . run (
1398+ cmd ,
1399+ capture_output = True ,
1400+ check = False ,
1401+ cwd = inifn . parent ,
13121402 )
1403+ finally :
1404+ inifn .unlink ()
1405+ shutil .rmtree (tdir )
13131406
1314- types = []
1315-
1316- out = res .stdout .decode ().split ("\n " )
1317- for line in out :
1318- if m := re .match (r'.*Revealed type is "(?P<name>[^"]+)".*' , line ):
1319- types .append (m .group ("name" ))
1320-
1321- def reveal_type (_ , * , ncalls = [0 ], types = types ): # noqa: B006
1322- ncalls [0 ] += 1
1323- try :
1324- return types [ncalls [0 ] - 1 ]
1325- except IndexError :
1326- return None
1327-
1328- if is_async :
1329- # `func` is the result of `TestCaseMeta.wrap()`, so we
1330- # want to use the coroutine function that was there in
1331- # the beginning, so let's use `wrapped`
1332- wrapped .__globals__ ["reveal_type" ] = reveal_type
1333- return wrapped (self )
1334- else :
1335- func .__globals__ ["reveal_type" ] = reveal_type
1336- return func (self )
1337-
1338- if is_async :
1339-
1340- @functools .wraps (wrapped )
1341- async def runner (self , run = run ):
1342- coro = run (self )
1343- await coro
1407+ # Parse out mypy errors and assign them to test cases.
1408+ # mypy lines are all prefixed with file name
1409+ start = 'tests' + os .sep
1410+ for line in res .stdout .decode ('utf-8' ).split ('\n ' ):
1411+ if not (start in line and '.py' in line ):
1412+ continue
1413+ name = line .split (start )[1 ].split ('.' )[0 ]
1414+ cls ._mypy_errors [name ] = (
1415+ cls ._mypy_errors .get (name , '' ) + line + '\n '
1416+ )
13441417
1345- rewrapped = TestCaseMeta .wrap (runner )
1346- return unittest .expectedFailure (rewrapped ) if xfail else rewrapped
1418+ # Parse out mypy errors and assign them to test cases.
1419+ # Pyright lines have file name groups started by the name, and
1420+ # then subsequent lines are indented. Messages can be
1421+ # multiline. They have a --outputjson mode that oculd save us
1422+ # trouble here but would give us some more trouble on the
1423+ # formatting side so whatever.
1424+ name = None
1425+ cur_lines = ''
1426+ for line in pyright_res .stdout .decode ('utf-8' ).split ('\n ' ):
1427+ if line .startswith ('/' ):
1428+ if name :
1429+ cls ._pyright_errors [name ] = cur_lines
1430+ cur_lines = ''
1431+ name = line .split (start )[1 ].split ('.' )[0 ]
1432+ else :
1433+ cur_lines += line + '\n '
1434+ if name :
1435+ cls ._pyright_errors [name ] = cur_lines
13471436
1348- else :
1349- run = functools .wraps (func )(run )
1350- return unittest .expectedFailure (run ) if xfail else run
1437+ cls .setUpClass = classmethod (_setUp )
13511438
13521439
13531440def typecheck (arg ):
@@ -1357,20 +1444,21 @@ def typecheck(arg):
13571444 schemas and the query builder APIs.
13581445 """
13591446 # Please don't add arguments to this decorator, thank you.
1360- if isinstance (arg , type ):
1361- for func in arg .__dict__ .values ():
1362- if not isinstance (func , types .FunctionType ):
1363- continue
1364- if not func .__name__ .startswith ("test_" ):
1365- continue
1366- new_func = typecheck (func )
1367- setattr (arg , func .__name__ , new_func )
1368- return arg
1369- else :
1370- assert isinstance (arg , types .FunctionType )
1447+ assert isinstance (arg , type )
1448+ all_checked = []
1449+ for func in arg .__dict__ .values ():
1450+ if not isinstance (func , types .FunctionType ):
1451+ continue
1452+ if not func .__name__ .startswith ("test_" ):
1453+ continue
13711454 if hasattr (arg , "_typecheck_skipped" ):
1372- return arg
1373- return _typecheck (arg )
1455+ continue
1456+ all_checked .append (func )
1457+ new_func = _typecheck (arg , func )
1458+ setattr (arg , func .__name__ , new_func )
1459+
1460+ _typecheck_class (arg , all_checked )
1461+ return arg
13741462
13751463
13761464def skip_typecheck (arg ):
0 commit comments