Skip to content

Commit 573b8b9

Browse files
author
Paolo Tranquilli
committed
Merge branch 'rust-experiment' into redsun82/rust-ci
2 parents c979a94 + 56e1278 commit 573b8b9

File tree

26 files changed

+602
-455
lines changed

26 files changed

+602
-455
lines changed

misc/codegen/generators/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import dbschemegen, qlgen, trapgen, cppgen, rustgen
1+
from . import dbschemegen, trapgen, cppgen, rustgen, rusttestgen, qlgen
22

33

44
def generate(target, opts, renderer):

misc/codegen/generators/qlgen.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import subprocess
2727
import typing
2828
import itertools
29+
import os
2930

3031
import inflection
3132

@@ -287,7 +288,7 @@ def _is_under_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[
287288
_is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases)
288289

289290

290-
def _should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
291+
def should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
291292
return "qltest_skip" in cls.pragmas or not (
292293
cls.final or "qltest_collapse_hierarchy" in cls.pragmas) or _is_under_qltest_collapsed_hierarchy(
293294
cls, lookup)
@@ -370,8 +371,10 @@ def generate(opts, renderer):
370371

371372
imports = {}
372373
generated_import_prefix = get_import(out, opts.root_dir)
374+
registry = opts.generated_registry or pathlib.Path(
375+
os.path.commonpath((out, stub_out, test_out)), ".generated.list")
373376

374-
with renderer.manage(generated=generated, stubs=stubs, registry=opts.generated_registry,
377+
with renderer.manage(generated=generated, stubs=stubs, registry=registry,
375378
force=opts.force) as renderer:
376379

377380
db_classes = [cls for name, cls in classes.items() if not data.classes[name].synth]
@@ -413,7 +416,7 @@ def generate(opts, renderer):
413416

414417
if test_out:
415418
for c in data.classes.values():
416-
if _should_skip_qltest(c, data.classes):
419+
if should_skip_qltest(c, data.classes):
417420
continue
418421
test_with = data.classes[c.test_with] if c.test_with else c
419422
test_dir = test_out / test_with.group / test_with.name

misc/codegen/generators/rustgen.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
3636
else:
3737
table_name = inflection.tableize(table_name)
3838
args = dict(
39-
field_name=p.name + ("_" if p.name in rust.keywords else ""),
39+
field_name=rust.avoid_keywords(p.name),
4040
base_type=_get_type(p.type),
4141
is_optional=p.is_optional,
4242
is_repeated=p.is_repeated,
@@ -86,20 +86,24 @@ def generate(opts, renderer):
8686
processor = Processor(schemaloader.load_file(opts.schema))
8787
out = opts.rust_output
8888
groups = set()
89-
for group, classes in processor.get_classes().items():
90-
group = group or "top"
91-
groups.add(group)
89+
with renderer.manage(generated=out.rglob("*.rs"),
90+
stubs=(),
91+
registry=out / ".generated.list",
92+
force=opts.force) as renderer:
93+
for group, classes in processor.get_classes().items():
94+
group = group or "top"
95+
groups.add(group)
96+
renderer.render(
97+
rust.ClassList(
98+
classes,
99+
opts.schema,
100+
),
101+
out / f"{group}.rs",
102+
)
92103
renderer.render(
93-
rust.ClassList(
94-
classes,
104+
rust.ModuleList(
105+
groups,
95106
opts.schema,
96107
),
97-
out / f"{group}.rs",
108+
out / f"mod.rs",
98109
)
99-
renderer.render(
100-
rust.ModuleList(
101-
groups,
102-
opts.schema,
103-
),
104-
out / f"mod.rs",
105-
)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import dataclasses
2+
import typing
3+
import inflection
4+
5+
from misc.codegen.loaders import schemaloader
6+
from . import qlgen
7+
8+
9+
@dataclasses.dataclass
10+
class Param:
11+
name: str
12+
type: str
13+
first: bool = False
14+
15+
16+
@dataclasses.dataclass
17+
class Function:
18+
name: str
19+
signature: str
20+
21+
22+
@dataclasses.dataclass
23+
class TestCode:
24+
template: typing.ClassVar[str] = "rust_test_code"
25+
26+
code: str
27+
function: Function | None = None
28+
29+
30+
def generate(opts, renderer):
31+
assert opts.ql_test_output
32+
schema = schemaloader.load_file(opts.schema)
33+
with renderer.manage(generated=opts.ql_test_output.rglob("gen_*.rs"),
34+
stubs=(),
35+
registry=opts.ql_test_output / ".generated_tests.list",
36+
force=opts.force) as renderer:
37+
for cls in schema.classes.values():
38+
if (qlgen.should_skip_qltest(cls, schema.classes) or
39+
"rust_skip_test_from_doc" in cls.pragmas or
40+
not cls.doc):
41+
continue
42+
code = []
43+
adding_code = False
44+
has_code = False
45+
for line in cls.doc:
46+
match line, adding_code:
47+
case "```", _:
48+
adding_code = not adding_code
49+
has_code = True
50+
case _, False:
51+
code.append(f"// {line}")
52+
case _, True:
53+
code.append(line)
54+
if not has_code:
55+
continue
56+
test_name = inflection.underscore(cls.name)
57+
signature = cls.rust_doc_test_function
58+
fn = signature and Function(f"test_{test_name}", signature)
59+
if fn:
60+
indent = 4 * " "
61+
code = [indent + l for l in code]
62+
test_with = schema.classes[cls.test_with] if cls.test_with else cls
63+
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs"
64+
renderer.render(TestCode(code="\n".join(code), function=fn), test)

misc/codegen/lib/rust.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@
5757
"try",
5858
}
5959

60+
61+
def avoid_keywords(s: str) -> str:
62+
return s + "_" if s in keywords else s
63+
64+
6065
_field_overrides = [
6166
(re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
6267
]
@@ -82,8 +87,7 @@ class Field:
8287
first: bool = False
8388

8489
def __post_init__(self):
85-
if self.field_name in keywords:
86-
self.field_name += "_"
90+
self.field_name = avoid_keywords(self.field_name)
8791

8892
@property
8993
def type(self) -> str:

misc/codegen/lib/schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ class Class:
9494
default_doc_name: Optional[str] = None
9595
hideable: bool = False
9696
test_with: Optional[str] = None
97+
rust_doc_test_function: Optional["FunctionInfo"] = None # TODO: parametrized pragmas
9798

9899
@property
99100
def final(self):

misc/codegen/lib/schemadefs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def modify(self, prop: _schema.Property):
5252
qltest = _Namespace()
5353
ql = _Namespace()
5454
cpp = _Namespace()
55+
rust = _Namespace()
5556
synth = _SynthModifier()
5657

5758

@@ -156,6 +157,10 @@ def f(cls: type) -> type:
156157

157158
_Pragma("cpp_skip")
158159

160+
_Pragma("rust_skip_doc_test")
161+
162+
rust.doc_test_signature = lambda signature: _annotate(rust_doc_test_function=signature)
163+
159164

160165
def group(name: str = "") -> _ClassDecorator:
161166
return _annotate(group=name)

misc/codegen/loaders/schemaloader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def _get_class(cls: type) -> schema.Class:
5656
],
5757
doc=schema.split_doc(cls.__doc__),
5858
default_doc_name=cls.__dict__.get("_doc_name"),
59+
rust_doc_test_function=cls.__dict__.get("_rust_doc_test_function")
5960
)
6061

6162

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// generated by {{generator}}
2+
3+
{{#function}}
4+
fn {{name}}{{signature}} {
5+
{{/function}}
6+
{{code}}
7+
{{#function}}
8+
}
9+
{{/function}}

0 commit comments

Comments
 (0)