@@ -1191,7 +1191,86 @@ def gen_lock_key():
11911191 return os .getpid () * 1000 + _lock_cnt
11921192
11931193
1194- def _typecheck (func , * , xfail = False ):
1194+ def _typecheck (cls , func , * , xfail = False ):
1195+ wrapped = inspect .unwrap (func )
1196+ is_async = inspect .iscoroutinefunction (wrapped )
1197+
1198+ def run (self ):
1199+ # We already ran the typechecker on everything, so now inspect
1200+ # the results for this test.
1201+
1202+ mypy_output = cls ._mypy_errors .get (func .__name__ , '' )
1203+ if 'error:' in mypy_output :
1204+ source_code = _get_file_code (func )
1205+ lines = source_code .split ("\n " )
1206+ pad_width = max (2 , len (str (len (lines ))))
1207+ source_code_numbered = "\n " .join (
1208+ f"{ i + 1 :0{pad_width }d} : { line } "
1209+ for i , line in enumerate (lines )
1210+ )
1211+
1212+ raise RuntimeError (
1213+ f"mypy check failed for { func .__name__ } "
1214+ f"\n \n test code:\n { source_code_numbered } "
1215+ f"\n \n mypy stdout:\n { mypy_output } "
1216+ )
1217+
1218+ pyright_output = cls ._pyright_errors .get (func .__name__ , '' )
1219+ if '- error:' in pyright_output :
1220+ source_code = _get_file_code (func )
1221+ lines = source_code .split ("\n " )
1222+ pad_width = max (2 , len (str (len (lines ))))
1223+ source_code_numbered = "\n " .join (
1224+ f"{ i + 1 :0{pad_width }d} : { line } "
1225+ for i , line in enumerate (lines )
1226+ )
1227+
1228+ raise RuntimeError (
1229+ f"pyright check failed for { func .__name__ } "
1230+ f"\n \n test code:\n { source_code_numbered } "
1231+ f"\n \n pyright stdout:\n { pyright_output } "
1232+ )
1233+
1234+ types = []
1235+
1236+ out = mypy_output .split ("\n " )
1237+ for line in out :
1238+ if m := re .match (r'.*Revealed type is "(?P<name>[^"]+)".*' , line ):
1239+ types .append (m .group ("name" ))
1240+
1241+ def reveal_type (_ , * , ncalls = [0 ], types = types ): # noqa: B006
1242+ ncalls [0 ] += 1
1243+ try :
1244+ return types [ncalls [0 ] - 1 ]
1245+ except IndexError :
1246+ return None
1247+
1248+ if is_async :
1249+ # `func` is the result of `TestCaseMeta.wrap()`, so we
1250+ # want to use the coroutine function that was there in
1251+ # the beginning, so let's use `wrapped`
1252+ wrapped .__globals__ ["reveal_type" ] = reveal_type
1253+ return wrapped (self )
1254+ else :
1255+ func .__globals__ ["reveal_type" ] = reveal_type
1256+ return func (self )
1257+
1258+ if is_async :
1259+
1260+ @functools .wraps (wrapped )
1261+ async def runner (self , run = run ):
1262+ coro = run (self )
1263+ await coro
1264+
1265+ rewrapped = TestCaseMeta .wrap (runner )
1266+ return unittest .expectedFailure (rewrapped ) if xfail else rewrapped
1267+
1268+ else :
1269+ run = functools .wraps (func )(run )
1270+ return unittest .expectedFailure (run ) if xfail else run
1271+
1272+
1273+ def _get_file_code (func ):
11951274 wrapped = inspect .unwrap (func )
11961275 is_async = inspect .iscoroutinefunction (wrapped )
11971276
@@ -1235,14 +1314,35 @@ class TestModel(_testbase.{base_class_name}):
12351314{ textwrap .indent (dedented_body , " " * 2 )}
12361315 """
12371316
1238- def run (self ):
1239- d = type (self ).tmp_model_dir .name
1317+ return source_code
1318+
1319+
1320+ def _typecheck_class (cls , funcs ):
1321+ """Extract all the typecheckable functions from a class and typecheck.
1322+
1323+ Run both mypy and pyright, then stash the results where the
1324+ individual functions will deal with them.
1325+ """
1326+
1327+ contents = [(func .__name__ , _get_file_code (func )) for func in funcs ]
1328+ cls ._mypy_errors = {}
1329+ cls ._pyright_errors = {}
1330+
1331+ orig_setUpClass = cls .setUpClass
1332+
1333+ def _setUp (cls ):
1334+ orig_setUpClass ()
1335+
1336+ d = cls .tmp_model_dir .name
12401337
1241- testfn = pathlib .Path (d ) / "test.py"
12421338 inifn = pathlib .Path (d ) / "mypy.ini"
1339+ tdir = pathlib .Path (d ) / "tests"
1340+ os .mkdir (tdir )
12431341
1244- with open (testfn , "wt" ) as f :
1245- f .write (source_code )
1342+ for name , code in contents :
1343+ testfn = tdir / (name + ".py" )
1344+ with open (testfn , "wt" ) as f :
1345+ f .write (code )
12461346
12471347 with open (inifn , "wt" ) as f :
12481348 f .write (
@@ -1272,71 +1372,63 @@ def run(self):
12721372 inifn ,
12731373 "--cache-dir" ,
12741374 str (pathlib .Path (__file__ ).parent .parent / ".mypy_cache" ),
1275- testfn ,
1375+ tdir ,
12761376 ]
1277-
12781377 res = subprocess .run (
12791378 cmd ,
12801379 capture_output = True ,
12811380 check = False ,
12821381 cwd = inifn .parent ,
12831382 )
1284- finally :
1285- inifn .unlink ()
1286- testfn .unlink ()
12871383
1288- if res .returncode != 0 :
1289- lines = source_code .split ("\n " )
1290- pad_width = max (2 , len (str (len (lines ))))
1291- source_code_numbered = "\n " .join (
1292- f"{ i + 1 :0{pad_width }d} : { line } "
1293- for i , line in enumerate (lines )
1294- )
1384+ cmd = [
1385+ sys .executable ,
1386+ "-m" ,
1387+ "pyright" ,
1388+ tdir ,
1389+ ]
12951390
1296- raise RuntimeError (
1297- f"mypy check failed for { func . __name__ } "
1298- f" \n \n test code: \n { source_code_numbered } "
1299- f" \n \n mypy stdout: \n { res . stdout . decode () } "
1300- f" \n \n mypy stderr: \n { res . stderr . decode () } "
1391+ pyright_res = subprocess . run (
1392+ cmd ,
1393+ capture_output = True ,
1394+ check = False ,
1395+ cwd = inifn . parent ,
13011396 )
1397+ finally :
1398+ inifn .unlink ()
1399+ shutil .rmtree (tdir )
13021400
1303- types = []
1304-
1305- out = res .stdout .decode ().split ("\n " )
1306- for line in out :
1307- if m := re .match (r'.*Revealed type is "(?P<name>[^"]+)".*' , line ):
1308- types .append (m .group ("name" ))
1309-
1310- def reveal_type (_ , * , ncalls = [0 ], types = types ): # noqa: B006
1311- ncalls [0 ] += 1
1312- try :
1313- return types [ncalls [0 ] - 1 ]
1314- except IndexError :
1315- return None
1316-
1317- if is_async :
1318- # `func` is the result of `TestCaseMeta.wrap()`, so we
1319- # want to use the coroutine function that was there in
1320- # the beginning, so let's use `wrapped`
1321- wrapped .__globals__ ["reveal_type" ] = reveal_type
1322- return wrapped (self )
1323- else :
1324- func .__globals__ ["reveal_type" ] = reveal_type
1325- return func (self )
1326-
1327- if is_async :
1328-
1329- @functools .wraps (wrapped )
1330- async def runner (self , run = run ):
1331- coro = run (self )
1332- await coro
1401+ # Parse out mypy errors and assign them to test cases.
1402+ # mypy lines are all prefixed with file name
1403+ for line in res .stdout .decode ('utf-8' ).split ('\n ' ):
1404+ start = 'tests' + os .sep
1405+ if not (line .startswith (start ) and '.py' in line ):
1406+ continue
1407+ name = line .removeprefix (start ).split ('.' )[0 ]
1408+ cls ._mypy_errors [name ] = (
1409+ cls ._mypy_errors .get (name , '' ) + line + '\n '
1410+ )
13331411
1334- rewrapped = TestCaseMeta .wrap (runner )
1335- return unittest .expectedFailure (rewrapped ) if xfail else rewrapped
1412+ # Parse out mypy errors and assign them to test cases.
1413+ # Pyright lines have file name groups started by the name, and
1414+ # then subsequent lines are indented. Messages can be
1415+ # multiline. They have a --outputjson mode that oculd save us
1416+ # trouble here but would give us some more trouble on the
1417+ # formatting side so whatever.
1418+ name = None
1419+ cur_lines = ''
1420+ for line in pyright_res .stdout .decode ('utf-8' ).split ('\n ' ):
1421+ if line .startswith ('/' ):
1422+ if name :
1423+ cls ._pyright_errors [name ] = cur_lines
1424+ cur_lines = ''
1425+ name = line .split ('tests' + os .sep )[1 ].split ('.' )[0 ]
1426+ else :
1427+ cur_lines += line + '\n '
1428+ if name :
1429+ cls ._pyright_errors [name ] = cur_lines
13361430
1337- else :
1338- run = functools .wraps (func )(run )
1339- return unittest .expectedFailure (run ) if xfail else run
1431+ cls .setUpClass = classmethod (_setUp )
13401432
13411433
13421434def typecheck (arg ):
@@ -1346,20 +1438,21 @@ def typecheck(arg):
13461438 schemas and the query builder APIs.
13471439 """
13481440 # Please don't add arguments to this decorator, thank you.
1349- if isinstance (arg , type ):
1350- for func in arg .__dict__ .values ():
1351- if not isinstance (func , types .FunctionType ):
1352- continue
1353- if not func .__name__ .startswith ("test_" ):
1354- continue
1355- new_func = typecheck (func )
1356- setattr (arg , func .__name__ , new_func )
1357- return arg
1358- else :
1359- assert isinstance (arg , types .FunctionType )
1441+ assert isinstance (arg , type )
1442+ all_checked = []
1443+ for func in arg .__dict__ .values ():
1444+ if not isinstance (func , types .FunctionType ):
1445+ continue
1446+ if not func .__name__ .startswith ("test_" ):
1447+ continue
13601448 if hasattr (arg , "_typecheck_skipped" ):
1361- return arg
1362- return _typecheck (arg )
1449+ continue
1450+ all_checked .append (func )
1451+ new_func = _typecheck (arg , func )
1452+ setattr (arg , func .__name__ , new_func )
1453+
1454+ _typecheck_class (arg , all_checked )
1455+ return arg
13631456
13641457
13651458def skip_typecheck (arg ):
0 commit comments