Skip to content

Commit b57a374

Browse files
author
Paolo Tranquilli
committed
Rust: make File usable in codegen
1 parent 7e0e5a3 commit b57a374

40 files changed

+363
-141
lines changed

misc/codegen/generators/dbschemegen.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a
110110

111111
def get_declarations(data: schema.Schema):
112112
add_or_none_except = data.root_class.name if data.null else None
113-
declarations = [d for cls in data.classes.values() for d in cls_to_dbscheme(cls, data.classes, add_or_none_except)]
113+
declarations = [d for cls in data.classes.values() if not cls.imported for d in cls_to_dbscheme(cls,
114+
data.classes, add_or_none_except)]
114115
if data.null:
115116
property_classes = {
116117
prop.type for cls in data.classes.values() for prop in cls.properties

misc/codegen/generators/qlgen.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,17 @@ def _get_doc(cls: schema.Class, prop: schema.Property, plural=None):
104104
return f"{prop_name} of this {class_name}"
105105

106106

107-
def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dict[str, schema.Class],
107+
def _type_is_hideable(t: str, lookup: typing.Dict[str, schema.ClassBase]) -> bool:
108+
if t in lookup:
109+
match lookup[t]:
110+
case schema.Class() as cls:
111+
return "ql_hideable" in cls.pragmas
112+
return False
113+
114+
115+
def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dict[str, schema.ClassBase],
108116
prev_child: str = "") -> ql.Property:
117+
109118
args = dict(
110119
type=prop.type if not prop.is_predicate else "predicate",
111120
qltest_skip="qltest_skip" in prop.pragmas,
@@ -115,7 +124,8 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
115124
is_unordered=prop.is_unordered,
116125
description=prop.description,
117126
synth=bool(cls.synth) or prop.synth,
118-
type_is_hideable="ql_hideable" in lookup[prop.type].pragmas if prop.type in lookup else False,
127+
type_is_hideable=_type_is_hideable(prop.type, lookup),
128+
type_is_codegen_class=prop.type in lookup and not lookup[prop.type].imported,
119129
internal="ql_internal" in prop.pragmas,
120130
)
121131
ql_name = prop.pragmas.get("ql_name", prop.name)
@@ -154,7 +164,7 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
154164
return ql.Property(**args)
155165

156166

157-
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.Class]) -> ql.Class:
167+
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.ClassBase]) -> ql.Class:
158168
if "ql_name" in cls.pragmas:
159169
raise Error("ql_name is not supported yet for classes, only for properties")
160170
prev_child = ""
@@ -391,14 +401,15 @@ def generate(opts, renderer):
391401

392402
data = schemaloader.load_file(input)
393403

394-
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items()}
404+
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items() if not cls.imported}
395405
if not classes:
396406
raise NoClasses
397407
root = next(iter(classes.values()))
398408
if root.has_children:
399409
raise RootElementHasChildren(root)
400410

401-
imports = {}
411+
pre_imports = {n: cls.module for n, cls in data.classes.items() if cls.imported}
412+
imports = dict(pre_imports)
402413
imports_impl = {}
403414
classes_used_by = {}
404415
cfg_classes = []
@@ -410,7 +421,7 @@ def generate(opts, renderer):
410421
force=opts.force) as renderer:
411422

412423
db_classes = [cls for name, cls in classes.items() if not data.classes[name].synth]
413-
renderer.render(ql.DbClasses(db_classes), out / "Raw.qll")
424+
renderer.render(ql.DbClasses(classes=db_classes, imports=sorted(set(pre_imports.values()))), out / "Raw.qll")
414425

415426
classes_by_dir_and_name = sorted(classes.values(), key=lambda cls: (cls.dir, cls.name))
416427
for c in classes_by_dir_and_name:
@@ -439,6 +450,8 @@ def generate(opts, renderer):
439450
renderer.render(cfg_classes_val, cfg_qll)
440451

441452
for c in data.classes.values():
453+
if c.imported:
454+
continue
442455
path = _get_path(c)
443456
path_impl = _get_path_impl(c)
444457
stub_file = stub_out / path_impl
@@ -457,20 +470,23 @@ def generate(opts, renderer):
457470
renderer.render(class_public, class_public_file)
458471

459472
# for example path/to/elements -> path/to/elements.qll
460-
renderer.render(ql.ImportList([i for name, i in imports.items() if not classes[name].internal]),
473+
renderer.render(ql.ImportList([i for name, i in imports.items() if name not in classes or not classes[name].internal]),
461474
include_file)
462475

463476
elements_module = get_import(include_file, opts.root_dir)
464477

465478
renderer.render(
466479
ql.GetParentImplementation(
467480
classes=list(classes.values()),
468-
imports=[elements_module] + [i for name, i in imports.items() if classes[name].internal],
481+
imports=[elements_module] + [i for name,
482+
i in imports.items() if name in classes and classes[name].internal],
469483
),
470484
out / 'ParentChild.qll')
471485

472486
if test_out:
473487
for c in data.classes.values():
488+
if c.imported:
489+
continue
474490
if should_skip_qltest(c, data.classes):
475491
continue
476492
test_with_name = c.pragmas.get("qltest_test_with")
@@ -500,7 +516,8 @@ def generate(opts, renderer):
500516
constructor_imports = []
501517
synth_constructor_imports = []
502518
stubs = {}
503-
for cls in sorted(data.classes.values(), key=lambda cls: (cls.group, cls.name)):
519+
for cls in sorted((cls for cls in data.classes.values() if not cls.imported),
520+
key=lambda cls: (cls.group, cls.name)):
504521
synth_type = get_ql_synth_class(cls)
505522
if synth_type.is_final:
506523
final_synth_types.append(synth_type)

misc/codegen/generators/rustgen.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
4949

5050

5151
def _get_properties(
52-
cls: schema.Class, lookup: dict[str, schema.Class],
52+
cls: schema.Class, lookup: dict[str, schema.ClassBase],
5353
) -> typing.Iterable[tuple[schema.Class, schema.Property]]:
5454
for b in cls.bases:
5555
yield from _get_properties(lookup[b], lookup)
@@ -58,20 +58,22 @@ def _get_properties(
5858

5959

6060
def _get_ancestors(
61-
cls: schema.Class, lookup: dict[str, schema.Class]
61+
cls: schema.Class, lookup: dict[str, schema.ClassBase]
6262
) -> typing.Iterable[schema.Class]:
6363
for b in cls.bases:
6464
base = lookup[b]
65-
yield base
66-
yield from _get_ancestors(base, lookup)
65+
if not base.imported:
66+
base = typing.cast(schema.Class, base)
67+
yield base
68+
yield from _get_ancestors(base, lookup)
6769

6870

6971
class Processor:
7072
def __init__(self, data: schema.Schema):
7173
self._classmap = data.classes
7274

7375
def _get_class(self, name: str) -> rust.Class:
74-
cls = self._classmap[name]
76+
cls = typing.cast(schema.Class, self._classmap[name])
7577
properties = [
7678
(c, p)
7779
for c, p in _get_properties(cls, self._classmap)
@@ -101,8 +103,10 @@ def _get_class(self, name: str) -> rust.Class:
101103
def get_classes(self):
102104
ret = {"": []}
103105
for k, cls in self._classmap.items():
104-
if not cls.synth:
106+
if not cls.imported and not cls.synth:
105107
ret.setdefault(cls.group, []).append(self._get_class(cls.name))
108+
elif cls.imported:
109+
ret[""].append(rust.Class(name=cls.name))
106110
return ret
107111

108112

misc/codegen/generators/rusttestgen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def generate(opts, renderer):
5656
registry=opts.ql_test_output / ".generated_tests.list",
5757
force=opts.force) as renderer:
5858
for cls in schema.classes.values():
59+
if cls.imported:
60+
continue
5961
if (qlgen.should_skip_qltest(cls, schema.classes) or
6062
"rust_skip_doc_test" in cls.pragmas):
6163
continue

misc/codegen/lib/ql.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class Property:
4444
doc_plural: Optional[str] = None
4545
synth: bool = False
4646
type_is_hideable: bool = False
47+
type_is_codegen_class: bool = False
4748
internal: bool = False
4849
cfg: bool = False
4950

@@ -66,10 +67,6 @@ def indefinite_getter(self):
6667
article = "An" if self.singular[0] in "AEIO" else "A"
6768
return f"get{article}{self.singular}"
6869

69-
@property
70-
def type_is_class(self):
71-
return bool(self.type) and self.type[0].isupper()
72-
7370
@property
7471
def is_repeated(self):
7572
return bool(self.plural)
@@ -191,6 +188,7 @@ class DbClasses:
191188
template: ClassVar = 'ql_db'
192189

193190
classes: List[Class] = field(default_factory=list)
191+
imports: List[str] = field(default_factory=list)
194192

195193

196194
@dataclass

misc/codegen/lib/schema.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import typing
44
from collections.abc import Iterable
55
from dataclasses import dataclass, field
6-
from typing import List, Set, Union, Dict, Optional
6+
from typing import List, Set, Union, Dict, Optional, FrozenSet
77
from enum import Enum, auto
88
import functools
99

@@ -87,8 +87,22 @@ class SynthInfo:
8787

8888

8989
@dataclass
90-
class Class:
90+
class ClassBase:
91+
imported: typing.ClassVar[bool]
9192
name: str
93+
94+
95+
@dataclass
96+
class ImportedClass(ClassBase):
97+
imported: typing.ClassVar[bool] = True
98+
99+
module: str
100+
101+
102+
@dataclass
103+
class Class(ClassBase):
104+
imported: typing.ClassVar[bool] = False
105+
92106
bases: List[str] = field(default_factory=list)
93107
derived: Set[str] = field(default_factory=set)
94108
properties: List[Property] = field(default_factory=list)
@@ -133,7 +147,7 @@ def group(self) -> str:
133147

134148
@dataclass
135149
class Schema:
136-
classes: Dict[str, Class] = field(default_factory=dict)
150+
classes: Dict[str, ClassBase] = field(default_factory=dict)
137151
includes: List[str] = field(default_factory=list)
138152
null: Optional[str] = None
139153

@@ -155,7 +169,7 @@ def iter_properties(self, cls: str) -> Iterable[Property]:
155169

156170
predicate_marker = object()
157171

158-
TypeRef = Union[type, str]
172+
TypeRef = type | str | ImportedClass
159173

160174

161175
def get_type_name(arg: TypeRef) -> str:
@@ -164,6 +178,8 @@ def get_type_name(arg: TypeRef) -> str:
164178
return arg.__name__
165179
case str():
166180
return arg
181+
case ImportedClass():
182+
return arg.name
167183
case _:
168184
raise Error(f"Not a schema type or string ({arg})")
169185

@@ -172,9 +188,9 @@ def _make_property(arg: object) -> Property:
172188
match arg:
173189
case _ if arg is predicate_marker:
174190
return PredicateProperty()
175-
case str() | type():
191+
case (str() | type() | ImportedClass()) as arg:
176192
return SingleProperty(type=get_type_name(arg))
177-
case Property():
193+
case Property() as arg:
178194
return arg
179195
case _:
180196
raise Error(f"Illegal property specifier {arg}")

misc/codegen/lib/schemadefs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import inspect as _inspect
99
from dataclasses import dataclass as _dataclass
1010

11-
from misc.codegen.lib.schema import Property
12-
1311
_set = set
1412

1513

@@ -69,6 +67,9 @@ def include(source: str):
6967
_inspect.currentframe().f_back.f_locals.setdefault("includes", []).append(source)
7068

7169

70+
imported = _schema.ImportedClass
71+
72+
7273
@_dataclass
7374
class _Namespace:
7475
""" simple namespacing mechanism """
@@ -264,7 +265,7 @@ class _PropertyModifierList(_schema.PropertyModifier):
264265
def __or__(self, other: _schema.PropertyModifier):
265266
return _PropertyModifierList(self._mods + (other,))
266267

267-
def modify(self, prop: Property):
268+
def modify(self, prop: _schema.Property):
268269
for m in self._mods:
269270
m.modify(prop)
270271

misc/codegen/loaders/schemaloader.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def _check_test_with(classes: typing.Dict[str, schema.Class]):
132132
def load(m: types.ModuleType) -> schema.Schema:
133133
includes = set()
134134
classes = {}
135+
imported_classes = {}
135136
known = {"int", "string", "boolean"}
136137
known.update(n for n in m.__dict__ if not n.startswith("__"))
137138
import misc.codegen.lib.schemadefs as defs
@@ -146,6 +147,9 @@ def load(m: types.ModuleType) -> schema.Schema:
146147
continue
147148
if isinstance(data, types.ModuleType):
148149
continue
150+
if isinstance(data, schema.ImportedClass):
151+
imported_classes[name] = data
152+
continue
149153
cls = _get_class(data)
150154
if classes and not cls.bases:
151155
raise schema.Error(
@@ -162,7 +166,7 @@ def load(m: types.ModuleType) -> schema.Schema:
162166
_fill_hideable_information(classes)
163167
_check_test_with(classes)
164168

165-
return schema.Schema(includes=includes, classes=_toposort_classes_by_group(classes), null=null)
169+
return schema.Schema(includes=includes, classes=imported_classes | _toposort_classes_by_group(classes), null=null)
166170

167171

168172
def load_file(path: pathlib.Path) -> schema.Schema:

misc/codegen/templates/ql_class.mustache

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ module Generated {
113113
*/
114114
{{type}} {{getter}}({{#is_indexed}}int index{{/is_indexed}}) {
115115
{{^synth}}
116-
{{^is_predicate}}result = {{/is_predicate}}{{#type_is_class}}Synth::convert{{type}}FromRaw({{/type_is_class}}Synth::convert{{name}}ToRaw(this){{^root}}.(Raw::{{name}}){{/root}}.{{getter}}({{#is_indexed}}index{{/is_indexed}}){{#type_is_class}}){{/type_is_class}}
116+
{{^is_predicate}}result = {{/is_predicate}}{{#type_is_codegen_class}}Synth::convert{{type}}FromRaw({{/type_is_codegen_class}}Synth::convert{{name}}ToRaw(this){{^root}}.(Raw::{{name}}){{/root}}.{{getter}}({{#is_indexed}}index{{/is_indexed}}){{#type_is_codegen_class}}){{/type_is_codegen_class}}
117117
{{/synth}}
118118
{{#synth}}
119119
none()

misc/codegen/templates/ql_db.mustache

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
* This module holds thin fully generated class definitions around DB entities.
44
*/
55
module Raw {
6+
{{#imports}}
7+
private import {{.}}
8+
{{/imports}}
9+
610
{{#classes}}
711
/**
812
* INTERNAL: Do not use.

0 commit comments

Comments
 (0)