Skip to content

Commit dac9d09

Browse files
committed
Run pyright as part of @tb.typecheck, batch all typechecks
pyright doesn't have a cache, so I didn't want to rerun pyright fully for every test. So instead I have us generate one file per test case as part of test suite setup, typecheck all the files at once, and then check the messages for each file in the individual tests. This speeds things up for mypy too quite a bit, because loading all the caches is actually pretty expensive too. Fixes #878. @elprans Two tests didn't work for reasons involving splats in shapes.
1 parent 8178169 commit dac9d09

File tree

2 files changed

+169
-74
lines changed

2 files changed

+169
-74
lines changed

gel/_testbase.py

Lines changed: 159 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,7 +1202,81 @@ def gen_lock_key():
12021202
return os.getpid() * 1000 + _lock_cnt
12031203

12041204

1205-
def _typecheck(func, *, xfail=False):
1205+
def _typecheck(cls, func, *, xfail=False):
1206+
wrapped = inspect.unwrap(func)
1207+
is_async = inspect.iscoroutinefunction(wrapped)
1208+
1209+
def run(self):
1210+
# We already ran the typechecker on everything, so now inspect
1211+
# the results for this test.
1212+
1213+
mypy_output = cls._mypy_errors.get(func.__name__, '')
1214+
mypy_error = 'error:' in mypy_output
1215+
1216+
pyright_output = cls._pyright_errors.get(func.__name__, '')
1217+
pyright_error = '- error:' in pyright_output
1218+
1219+
if mypy_error or pyright_error:
1220+
source_code = _get_file_code(func)
1221+
lines = source_code.split("\n")
1222+
pad_width = max(2, len(str(len(lines))))
1223+
source_code_numbered = "\n".join(
1224+
f"{i + 1:0{pad_width}d}: {line}"
1225+
for i, line in enumerate(lines)
1226+
)
1227+
1228+
if mypy_error:
1229+
raise RuntimeError(
1230+
f"mypy check failed for {func.__name__} "
1231+
f"\n\ntest code:\n{source_code_numbered}"
1232+
f"\n\nmypy stdout:\n{mypy_output}"
1233+
)
1234+
1235+
if pyright_error:
1236+
raise RuntimeError(
1237+
f"pyright check failed for {func.__name__} "
1238+
f"\n\ntest code:\n{source_code_numbered}"
1239+
f"\n\npyright stdout:\n{pyright_output}"
1240+
)
1241+
1242+
types = []
1243+
for line in mypy_output.split("\n"):
1244+
if m := re.match(r'.*Revealed type is "(?P<name>[^"]+)".*', line):
1245+
types.append(m.group("name"))
1246+
1247+
def reveal_type(_, *, ncalls=[0], types=types): # noqa: B006
1248+
ncalls[0] += 1
1249+
try:
1250+
return types[ncalls[0] - 1]
1251+
except IndexError:
1252+
return None
1253+
1254+
if is_async:
1255+
# `func` is the result of `TestCaseMeta.wrap()`, so we
1256+
# want to use the coroutine function that was there in
1257+
# the beginning, so let's use `wrapped`
1258+
wrapped.__globals__["reveal_type"] = reveal_type
1259+
return wrapped(self)
1260+
else:
1261+
func.__globals__["reveal_type"] = reveal_type
1262+
return func(self)
1263+
1264+
if is_async:
1265+
1266+
@functools.wraps(wrapped)
1267+
async def runner(self, run=run):
1268+
coro = run(self)
1269+
await coro
1270+
1271+
rewrapped = TestCaseMeta.wrap(runner)
1272+
return unittest.expectedFailure(rewrapped) if xfail else rewrapped
1273+
1274+
else:
1275+
run = functools.wraps(func)(run)
1276+
return unittest.expectedFailure(run) if xfail else run
1277+
1278+
1279+
def _get_file_code(func):
12061280
wrapped = inspect.unwrap(func)
12071281
is_async = inspect.iscoroutinefunction(wrapped)
12081282

@@ -1246,14 +1320,35 @@ class TestModel(_testbase.{base_class_name}):
12461320
{textwrap.indent(dedented_body, " " * 2)}
12471321
"""
12481322

1249-
def run(self):
1250-
d = type(self).tmp_model_dir.name
1323+
return source_code
1324+
1325+
1326+
def _typecheck_class(cls, funcs):
1327+
"""Extract all the typecheckable functions from a class and typecheck.
1328+
1329+
Run both mypy and pyright, then stash the results where the
1330+
individual functions will deal with them.
1331+
"""
1332+
1333+
contents = [(func.__name__, _get_file_code(func)) for func in funcs]
1334+
cls._mypy_errors = {}
1335+
cls._pyright_errors = {}
1336+
1337+
orig_setUpClass = cls.setUpClass
1338+
1339+
def _setUp(cls):
1340+
orig_setUpClass()
1341+
1342+
d = cls.tmp_model_dir.name
12511343

1252-
testfn = pathlib.Path(d) / "test.py"
12531344
inifn = pathlib.Path(d) / "mypy.ini"
1345+
tdir = pathlib.Path(d) / "tests"
1346+
os.mkdir(tdir)
12541347

1255-
with open(testfn, "wt") as f:
1256-
f.write(source_code)
1348+
for name, code in contents:
1349+
testfn = tdir / (name + ".py")
1350+
with open(testfn, "wt") as f:
1351+
f.write(code)
12571352

12581353
with open(inifn, "wt") as f:
12591354
f.write(
@@ -1283,71 +1378,63 @@ def run(self):
12831378
inifn,
12841379
"--cache-dir",
12851380
str(pathlib.Path(__file__).parent.parent / ".mypy_cache"),
1286-
testfn,
1381+
tdir,
12871382
]
1288-
12891383
res = subprocess.run(
12901384
cmd,
12911385
capture_output=True,
12921386
check=False,
12931387
cwd=inifn.parent,
12941388
)
1295-
finally:
1296-
inifn.unlink()
1297-
testfn.unlink()
12981389

1299-
if res.returncode != 0:
1300-
lines = source_code.split("\n")
1301-
pad_width = max(2, len(str(len(lines))))
1302-
source_code_numbered = "\n".join(
1303-
f"{i + 1:0{pad_width}d}: {line}"
1304-
for i, line in enumerate(lines)
1305-
)
1390+
cmd = [
1391+
sys.executable,
1392+
"-m",
1393+
"pyright",
1394+
tdir,
1395+
]
13061396

1307-
raise RuntimeError(
1308-
f"mypy check failed for {func.__name__} "
1309-
f"\n\ntest code:\n{source_code_numbered}"
1310-
f"\n\nmypy stdout:\n{res.stdout.decode()}"
1311-
f"\n\nmypy stderr:\n{res.stderr.decode()}"
1397+
pyright_res = subprocess.run(
1398+
cmd,
1399+
capture_output=True,
1400+
check=False,
1401+
cwd=inifn.parent,
13121402
)
1403+
finally:
1404+
inifn.unlink()
1405+
shutil.rmtree(tdir)
13131406

1314-
types = []
1315-
1316-
out = res.stdout.decode().split("\n")
1317-
for line in out:
1318-
if m := re.match(r'.*Revealed type is "(?P<name>[^"]+)".*', line):
1319-
types.append(m.group("name"))
1320-
1321-
def reveal_type(_, *, ncalls=[0], types=types): # noqa: B006
1322-
ncalls[0] += 1
1323-
try:
1324-
return types[ncalls[0] - 1]
1325-
except IndexError:
1326-
return None
1327-
1328-
if is_async:
1329-
# `func` is the result of `TestCaseMeta.wrap()`, so we
1330-
# want to use the coroutine function that was there in
1331-
# the beginning, so let's use `wrapped`
1332-
wrapped.__globals__["reveal_type"] = reveal_type
1333-
return wrapped(self)
1334-
else:
1335-
func.__globals__["reveal_type"] = reveal_type
1336-
return func(self)
1337-
1338-
if is_async:
1339-
1340-
@functools.wraps(wrapped)
1341-
async def runner(self, run=run):
1342-
coro = run(self)
1343-
await coro
1407+
# Parse out mypy errors and assign them to test cases.
1408+
# mypy lines are all prefixed with file name
1409+
start = 'tests' + os.sep
1410+
for line in res.stdout.decode('utf-8').split('\n'):
1411+
if not (start in line and '.py' in line):
1412+
continue
1413+
name = line.split(start)[1].split('.')[0]
1414+
cls._mypy_errors[name] = (
1415+
cls._mypy_errors.get(name, '') + line + '\n'
1416+
)
13441417

1345-
rewrapped = TestCaseMeta.wrap(runner)
1346-
return unittest.expectedFailure(rewrapped) if xfail else rewrapped
1418+
# Parse out mypy errors and assign them to test cases.
1419+
# Pyright lines have file name groups started by the name, and
1420+
# then subsequent lines are indented. Messages can be
1421+
# multiline. They have a --outputjson mode that oculd save us
1422+
# trouble here but would give us some more trouble on the
1423+
# formatting side so whatever.
1424+
name = None
1425+
cur_lines = ''
1426+
for line in pyright_res.stdout.decode('utf-8').split('\n'):
1427+
if line.startswith('/'):
1428+
if name:
1429+
cls._pyright_errors[name] = cur_lines
1430+
cur_lines = ''
1431+
name = line.split(start)[1].split('.')[0]
1432+
else:
1433+
cur_lines += line + '\n'
1434+
if name:
1435+
cls._pyright_errors[name] = cur_lines
13471436

1348-
else:
1349-
run = functools.wraps(func)(run)
1350-
return unittest.expectedFailure(run) if xfail else run
1437+
cls.setUpClass = classmethod(_setUp)
13511438

13521439

13531440
def typecheck(arg):
@@ -1357,20 +1444,21 @@ def typecheck(arg):
13571444
schemas and the query builder APIs.
13581445
"""
13591446
# Please don't add arguments to this decorator, thank you.
1360-
if isinstance(arg, type):
1361-
for func in arg.__dict__.values():
1362-
if not isinstance(func, types.FunctionType):
1363-
continue
1364-
if not func.__name__.startswith("test_"):
1365-
continue
1366-
new_func = typecheck(func)
1367-
setattr(arg, func.__name__, new_func)
1368-
return arg
1369-
else:
1370-
assert isinstance(arg, types.FunctionType)
1447+
assert isinstance(arg, type)
1448+
all_checked = []
1449+
for func in arg.__dict__.values():
1450+
if not isinstance(func, types.FunctionType):
1451+
continue
1452+
if not func.__name__.startswith("test_"):
1453+
continue
13711454
if hasattr(arg, "_typecheck_skipped"):
1372-
return arg
1373-
return _typecheck(arg)
1455+
continue
1456+
all_checked.append(func)
1457+
new_func = _typecheck(arg, func)
1458+
setattr(arg, func.__name__, new_func)
1459+
1460+
_typecheck_class(arg, all_checked)
1461+
return arg
13741462

13751463

13761464
def skip_typecheck(arg):

tests/test_model_generator.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class MockPointer(typing.NamedTuple):
6565

6666

6767
@tb.typecheck
68-
class TestModelGenerator(tb.ModelTestCase):
68+
class TestModelGeneratorMain(tb.ModelTestCase):
6969
SCHEMA = os.path.join(os.path.dirname(__file__), "dbsetup", "orm.gel")
7070

7171
SETUP = os.path.join(os.path.dirname(__file__), "dbsetup", "orm.edgeql")
@@ -1575,7 +1575,8 @@ def test_modelgen_data_unpack_polymorphic(self):
15751575

15761576
q = default.Named.select(
15771577
"*",
1578-
*default.UserGroup,
1578+
# FIXME: pyright fails here
1579+
*default.UserGroup, # pyright: ignore
15791580
)
15801581

15811582
for item in self.client.query(q):
@@ -1823,7 +1824,8 @@ def test_modelgen_save_01(self):
18231824

18241825
pq = (
18251826
default.Post.select(
1826-
*default.Post,
1827+
# FIXME: pyright fails here
1828+
*default.Post, # pyright: ignore
18271829
author=True,
18281830
)
18291831
.filter(lambda p: p.body == "I'm Alice")
@@ -4133,6 +4135,7 @@ def test_modelgen_save_reload_links_08(self):
41334135

41344136
self.assertEqual({u.name for u in g.users}, {"0", "1", "2"})
41354137

4138+
u = None
41364139
for u in g.users:
41374140
u.name += "aaa"
41384141

@@ -4143,6 +4146,8 @@ def test_modelgen_save_reload_links_08(self):
41434146
for u in g.users:
41444147
u.name += "bbb"
41454148

4149+
assert u
4150+
41464151
g.users.remove(u)
41474152
g.users.add(default.User(name="new"))
41484153

@@ -5088,13 +5093,15 @@ def test_modelgen_linkprops_09(self):
50885093
gs = self.client.get(
50895094
default.GameSession.select("*", players=True).filter(num=123)
50905095
)
5096+
alice = None
50915097
for p in gs.players:
50925098
if p.name == "Alice":
50935099
alice = p
50945100
self.assertEqual(alice.__linkprops__.is_tall_enough, False)
50955101
alice.__linkprops__.is_tall_enough = True
50965102

50975103
self.client.sync(gs)
5104+
assert alice
50985105
self.assertEqual(alice.__linkprops__.is_tall_enough, True)
50995106

51005107
def test_modelgen_globals_01(self):

0 commit comments

Comments
 (0)