Skip to content

Commit 125677e

Browse files
committed
Run pyright as part of @tb.typecheck, batch all typechecks
pyright doesn't have a cache, so I didn't want to rerun pyright fully for every test. So instead I have us generate one file per test case as part of test suite setup, typecheck all the files at once, and then check the messages for each file in the individual tests. This speeds things up for mypy too quite a bit, because loading all the caches is actually pretty expensive too. Fixes #878. @elprans Two tests didn't work for reasons involving splats in shapes.
1 parent e1e5ef3 commit 125677e

File tree

2 files changed

+174
-74
lines changed

2 files changed

+174
-74
lines changed

gel/_testbase.py

Lines changed: 164 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,7 +1191,86 @@ def gen_lock_key():
11911191
return os.getpid() * 1000 + _lock_cnt
11921192

11931193

1194-
def _typecheck(func, *, xfail=False):
1194+
def _typecheck(cls, func, *, xfail=False):
1195+
wrapped = inspect.unwrap(func)
1196+
is_async = inspect.iscoroutinefunction(wrapped)
1197+
1198+
def run(self):
1199+
# We already ran the typechecker on everything, so now inspect
1200+
# the results for this test.
1201+
1202+
mypy_output = cls._mypy_errors.get(func.__name__, '')
1203+
if 'error:' in mypy_output:
1204+
source_code = _get_file_code(func)
1205+
lines = source_code.split("\n")
1206+
pad_width = max(2, len(str(len(lines))))
1207+
source_code_numbered = "\n".join(
1208+
f"{i + 1:0{pad_width}d}: {line}"
1209+
for i, line in enumerate(lines)
1210+
)
1211+
1212+
raise RuntimeError(
1213+
f"mypy check failed for {func.__name__} "
1214+
f"\n\ntest code:\n{source_code_numbered}"
1215+
f"\n\nmypy stdout:\n{mypy_output}"
1216+
)
1217+
1218+
pyright_output = cls._pyright_errors.get(func.__name__, '')
1219+
if '- error:' in pyright_output:
1220+
source_code = _get_file_code(func)
1221+
lines = source_code.split("\n")
1222+
pad_width = max(2, len(str(len(lines))))
1223+
source_code_numbered = "\n".join(
1224+
f"{i + 1:0{pad_width}d}: {line}"
1225+
for i, line in enumerate(lines)
1226+
)
1227+
1228+
raise RuntimeError(
1229+
f"pyright check failed for {func.__name__} "
1230+
f"\n\ntest code:\n{source_code_numbered}"
1231+
f"\n\npyright stdout:\n{pyright_output}"
1232+
)
1233+
1234+
types = []
1235+
1236+
out = mypy_output.split("\n")
1237+
for line in out:
1238+
if m := re.match(r'.*Revealed type is "(?P<name>[^"]+)".*', line):
1239+
types.append(m.group("name"))
1240+
1241+
def reveal_type(_, *, ncalls=[0], types=types): # noqa: B006
1242+
ncalls[0] += 1
1243+
try:
1244+
return types[ncalls[0] - 1]
1245+
except IndexError:
1246+
return None
1247+
1248+
if is_async:
1249+
# `func` is the result of `TestCaseMeta.wrap()`, so we
1250+
# want to use the coroutine function that was there in
1251+
# the beginning, so let's use `wrapped`
1252+
wrapped.__globals__["reveal_type"] = reveal_type
1253+
return wrapped(self)
1254+
else:
1255+
func.__globals__["reveal_type"] = reveal_type
1256+
return func(self)
1257+
1258+
if is_async:
1259+
1260+
@functools.wraps(wrapped)
1261+
async def runner(self, run=run):
1262+
coro = run(self)
1263+
await coro
1264+
1265+
rewrapped = TestCaseMeta.wrap(runner)
1266+
return unittest.expectedFailure(rewrapped) if xfail else rewrapped
1267+
1268+
else:
1269+
run = functools.wraps(func)(run)
1270+
return unittest.expectedFailure(run) if xfail else run
1271+
1272+
1273+
def _get_file_code(func):
11951274
wrapped = inspect.unwrap(func)
11961275
is_async = inspect.iscoroutinefunction(wrapped)
11971276

@@ -1235,14 +1314,35 @@ class TestModel(_testbase.{base_class_name}):
12351314
{textwrap.indent(dedented_body, " " * 2)}
12361315
"""
12371316

1238-
def run(self):
1239-
d = type(self).tmp_model_dir.name
1317+
return source_code
1318+
1319+
1320+
def _typecheck_class(cls, funcs):
1321+
"""Extract all the typecheckable functions from a class and typecheck.
1322+
1323+
Run both mypy and pyright, then stash the results where the
1324+
individual functions will deal with them.
1325+
"""
1326+
1327+
contents = [(func.__name__, _get_file_code(func)) for func in funcs]
1328+
cls._mypy_errors = {}
1329+
cls._pyright_errors = {}
1330+
1331+
orig_setUpClass = cls.setUpClass
1332+
1333+
def _setUp(cls):
1334+
orig_setUpClass()
1335+
1336+
d = cls.tmp_model_dir.name
12401337

1241-
testfn = pathlib.Path(d) / "test.py"
12421338
inifn = pathlib.Path(d) / "mypy.ini"
1339+
tdir = pathlib.Path(d) / "tests"
1340+
os.mkdir(tdir)
12431341

1244-
with open(testfn, "wt") as f:
1245-
f.write(source_code)
1342+
for name, code in contents:
1343+
testfn = tdir / (name + ".py")
1344+
with open(testfn, "wt") as f:
1345+
f.write(code)
12461346

12471347
with open(inifn, "wt") as f:
12481348
f.write(
@@ -1272,71 +1372,63 @@ def run(self):
12721372
inifn,
12731373
"--cache-dir",
12741374
str(pathlib.Path(__file__).parent.parent / ".mypy_cache"),
1275-
testfn,
1375+
tdir,
12761376
]
1277-
12781377
res = subprocess.run(
12791378
cmd,
12801379
capture_output=True,
12811380
check=False,
12821381
cwd=inifn.parent,
12831382
)
1284-
finally:
1285-
inifn.unlink()
1286-
testfn.unlink()
12871383

1288-
if res.returncode != 0:
1289-
lines = source_code.split("\n")
1290-
pad_width = max(2, len(str(len(lines))))
1291-
source_code_numbered = "\n".join(
1292-
f"{i + 1:0{pad_width}d}: {line}"
1293-
for i, line in enumerate(lines)
1294-
)
1384+
cmd = [
1385+
sys.executable,
1386+
"-m",
1387+
"pyright",
1388+
tdir,
1389+
]
12951390

1296-
raise RuntimeError(
1297-
f"mypy check failed for {func.__name__} "
1298-
f"\n\ntest code:\n{source_code_numbered}"
1299-
f"\n\nmypy stdout:\n{res.stdout.decode()}"
1300-
f"\n\nmypy stderr:\n{res.stderr.decode()}"
1391+
pyright_res = subprocess.run(
1392+
cmd,
1393+
capture_output=True,
1394+
check=False,
1395+
cwd=inifn.parent,
13011396
)
1397+
finally:
1398+
inifn.unlink()
1399+
shutil.rmtree(tdir)
13021400

1303-
types = []
1304-
1305-
out = res.stdout.decode().split("\n")
1306-
for line in out:
1307-
if m := re.match(r'.*Revealed type is "(?P<name>[^"]+)".*', line):
1308-
types.append(m.group("name"))
1309-
1310-
def reveal_type(_, *, ncalls=[0], types=types): # noqa: B006
1311-
ncalls[0] += 1
1312-
try:
1313-
return types[ncalls[0] - 1]
1314-
except IndexError:
1315-
return None
1316-
1317-
if is_async:
1318-
# `func` is the result of `TestCaseMeta.wrap()`, so we
1319-
# want to use the coroutine function that was there in
1320-
# the beginning, so let's use `wrapped`
1321-
wrapped.__globals__["reveal_type"] = reveal_type
1322-
return wrapped(self)
1323-
else:
1324-
func.__globals__["reveal_type"] = reveal_type
1325-
return func(self)
1326-
1327-
if is_async:
1328-
1329-
@functools.wraps(wrapped)
1330-
async def runner(self, run=run):
1331-
coro = run(self)
1332-
await coro
1401+
# Parse out mypy errors and assign them to test cases.
1402+
# mypy lines are all prefixed with file name
1403+
for line in res.stdout.decode('utf-8').split('\n'):
1404+
start = 'tests' + os.sep
1405+
if not (line.startswith(start) and '.py' in line):
1406+
continue
1407+
name = line.removeprefix(start).split('.')[0]
1408+
cls._mypy_errors[name] = (
1409+
cls._mypy_errors.get(name, '') + line + '\n'
1410+
)
13331411

1334-
rewrapped = TestCaseMeta.wrap(runner)
1335-
return unittest.expectedFailure(rewrapped) if xfail else rewrapped
1412+
# Parse out mypy errors and assign them to test cases.
1413+
# Pyright lines have file name groups started by the name, and
1414+
# then subsequent lines are indented. Messages can be
1415+
# multiline. They have a --outputjson mode that oculd save us
1416+
# trouble here but would give us some more trouble on the
1417+
# formatting side so whatever.
1418+
name = None
1419+
cur_lines = ''
1420+
for line in pyright_res.stdout.decode('utf-8').split('\n'):
1421+
if line.startswith('/'):
1422+
if name:
1423+
cls._pyright_errors[name] = cur_lines
1424+
cur_lines = ''
1425+
name = line.split('tests' + os.sep)[1].split('.')[0]
1426+
else:
1427+
cur_lines += line + '\n'
1428+
if name:
1429+
cls._pyright_errors[name] = cur_lines
13361430

1337-
else:
1338-
run = functools.wraps(func)(run)
1339-
return unittest.expectedFailure(run) if xfail else run
1431+
cls.setUpClass = classmethod(_setUp)
13401432

13411433

13421434
def typecheck(arg):
@@ -1346,20 +1438,21 @@ def typecheck(arg):
13461438
schemas and the query builder APIs.
13471439
"""
13481440
# Please don't add arguments to this decorator, thank you.
1349-
if isinstance(arg, type):
1350-
for func in arg.__dict__.values():
1351-
if not isinstance(func, types.FunctionType):
1352-
continue
1353-
if not func.__name__.startswith("test_"):
1354-
continue
1355-
new_func = typecheck(func)
1356-
setattr(arg, func.__name__, new_func)
1357-
return arg
1358-
else:
1359-
assert isinstance(arg, types.FunctionType)
1441+
assert isinstance(arg, type)
1442+
all_checked = []
1443+
for func in arg.__dict__.values():
1444+
if not isinstance(func, types.FunctionType):
1445+
continue
1446+
if not func.__name__.startswith("test_"):
1447+
continue
13601448
if hasattr(arg, "_typecheck_skipped"):
1361-
return arg
1362-
return _typecheck(arg)
1449+
continue
1450+
all_checked.append(func)
1451+
new_func = _typecheck(arg, func)
1452+
setattr(arg, func.__name__, new_func)
1453+
1454+
_typecheck_class(arg, all_checked)
1455+
return arg
13631456

13641457

13651458
def skip_typecheck(arg):

tests/test_model_generator.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class MockPointer(typing.NamedTuple):
6565

6666

6767
@tb.typecheck
68-
class TestModelGenerator(tb.ModelTestCase):
68+
class TestModelGeneratorMain(tb.ModelTestCase):
6969
SCHEMA = os.path.join(os.path.dirname(__file__), "dbsetup", "orm.gel")
7070

7171
SETUP = os.path.join(os.path.dirname(__file__), "dbsetup", "orm.edgeql")
@@ -1575,7 +1575,8 @@ def test_modelgen_data_unpack_polymorphic(self):
15751575

15761576
q = default.Named.select(
15771577
"*",
1578-
*default.UserGroup,
1578+
# FIXME: pyright fails here
1579+
*default.UserGroup, # pyright: ignore
15791580
)
15801581

15811582
for item in self.client.query(q):
@@ -1823,7 +1824,8 @@ def test_modelgen_save_01(self):
18231824

18241825
pq = (
18251826
default.Post.select(
1826-
*default.Post,
1827+
# FIXME: pyright fails here
1828+
*default.Post, # pyright: ignore
18271829
author=True,
18281830
)
18291831
.filter(lambda p: p.body == "I'm Alice")
@@ -4140,9 +4142,12 @@ def test_modelgen_save_reload_links_08(self):
41404142

41414143
self.assertEqual({u.name for u in g.users}, {"0aaa", "1aaa", "2aaa"})
41424144

4145+
u = None
41434146
for u in g.users:
41444147
u.name += "bbb"
41454148

4149+
assert u
4150+
41464151
g.users.remove(u)
41474152
g.users.add(default.User(name="new"))
41484153

@@ -5088,13 +5093,15 @@ def test_modelgen_linkprops_09(self):
50885093
gs = self.client.get(
50895094
default.GameSession.select("*", players=True).filter(num=123)
50905095
)
5096+
alice = None
50915097
for p in gs.players:
50925098
if p.name == "Alice":
50935099
alice = p
50945100
self.assertEqual(alice.__linkprops__.is_tall_enough, False)
50955101
alice.__linkprops__.is_tall_enough = True
50965102

50975103
self.client.sync(gs)
5104+
assert alice
50985105
self.assertEqual(alice.__linkprops__.is_tall_enough, True)
50995106

51005107
def test_modelgen_globals_01(self):

0 commit comments

Comments
 (0)