Skip to content

Commit 3cd8aaf

Browse files
author
Paolo Tranquilli
committed
Rust: simplify rust doc test annotation
1 parent 928f3f1 commit 3cd8aaf

File tree

6 files changed

+20
-37
lines changed

6 files changed

+20
-37
lines changed

misc/codegen/generators/rustgen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
3636
else:
3737
table_name = inflection.tableize(table_name)
3838
args = dict(
39-
field_name=p.name + ("_" if p.name in rust.keywords else ""),
39+
field_name=rust.avoid_keywords(p.name),
4040
base_type=_get_type(p.type),
4141
is_optional=p.is_optional,
4242
is_repeated=p.is_repeated,
Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import dataclasses
22
import typing
3+
import inflection
34

45
from misc.codegen.loaders import schemaloader
56
from . import qlgen
@@ -15,19 +16,7 @@ class Param:
1516
@dataclasses.dataclass
1617
class Function:
1718
name: str
18-
generic_params: list[Param]
19-
params: list[Param]
20-
return_type: str
21-
22-
def __post_init__(self):
23-
if self.generic_params:
24-
self.generic_params[0].first = True
25-
if self.params:
26-
self.params[0].first = True
27-
28-
@property
29-
def has_generic_params(self) -> bool:
30-
return bool(self.generic_params)
19+
signature: str
3120

3221

3322
@dataclasses.dataclass
@@ -48,27 +37,28 @@ def generate(opts, renderer):
4837
for cls in schema.classes.values():
4938
if (qlgen.should_skip_qltest(cls, schema.classes) or
5039
"rust_skip_test_from_doc" in cls.pragmas or
51-
not cls.doc
52-
):
40+
not cls.doc):
5341
continue
54-
fn = cls.rust_doc_test_function
55-
if fn:
56-
generic_params = [Param(k, v) for k, v in fn.params.items() if k[0].isupper() or k[0] == "'"]
57-
params = [Param(k, v) for k, v in fn.params.items() if k[0].islower()]
58-
fn = Function(fn.name, generic_params, params, fn.return_type)
5942
code = []
6043
adding_code = False
44+
has_code = False
6145
for line in cls.doc:
6246
match line, adding_code:
6347
case "```", _:
6448
adding_code = not adding_code
49+
has_code = True
6550
case _, False:
6651
code.append(f"// {line}")
6752
case _, True:
6853
code.append(line)
54+
if not has_code:
55+
continue
56+
test_name = inflection.underscore(cls.name)
57+
signature = cls.rust_doc_test_function
58+
fn = signature and Function(f"test_{test_name}", signature)
6959
if fn:
7060
indent = 4 * " "
7161
code = [indent + l for l in code]
7262
test_with = schema.classes[cls.test_with] if cls.test_with else cls
73-
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{cls.name.lower()}.rs"
63+
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs"
7464
renderer.render(TestCode(code="\n".join(code), function=fn), test)

misc/codegen/lib/rust.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@
5757
"try",
5858
}
5959

60+
61+
def avoid_keywords(s: str) -> str:
62+
return s + "_" if s in keywords else s
63+
64+
6065
_field_overrides = [
6166
(re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
6267
]
@@ -82,8 +87,7 @@ class Field:
8287
first: bool = False
8388

8489
def __post_init__(self):
85-
if self.field_name in keywords:
86-
self.field_name += "_"
90+
self.field_name = avoid_keywords(self.field_name)
8791

8892
@property
8993
def type(self) -> str:

misc/codegen/lib/schema.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,3 @@ def split_doc(doc):
203203
while trimmed and not trimmed[0]:
204204
trimmed.pop(0)
205205
return trimmed
206-
207-
208-
@dataclass
209-
class FunctionInfo:
210-
name: str
211-
params: dict[str, str]
212-
return_type: str

misc/codegen/lib/schemadefs.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,7 @@ def f(cls: type) -> type:
159159

160160
_Pragma("rust_skip_doc_test")
161161

162-
rust.doc_test_function = lambda name, *, lifetimes=(), return_type="()", **kwargs: _annotate(
163-
rust_doc_test_function=_schema.FunctionInfo(name,
164-
params={f"'{lifetime}": "" for lifetime in lifetimes} | kwargs,
165-
return_type=return_type)
166-
)
162+
rust.doc_test_signature = lambda signature: _annotate(rust_doc_test_function=signature)
167163

168164

169165
def group(name: str = "") -> _ClassDecorator:

misc/codegen/templates/rust_test_code.mustache

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// generated by {{generator}}
22

33
{{#function}}
4-
fn {{name}}{{#has_generic_params}}<{{#generic_params}}{{^first}}, {{/first}}{{name}}{{#type}}: {{.}}{{/type}}{{/generic_params}}>{{/has_generic_params}}({{#params}}{{^first}}, {{/first}}{{name}}: {{type}}{{/params}}) -> {{return_type}} {
4+
fn {{name}}{{signature}} {
55
{{/function}}
66
{{code}}
77
{{#function}}

0 commit comments

Comments
 (0)