Skip to content

Commit 552c524

Browse files
authored
Merge pull request #11131 from github/redsun82/swift-incomplete-ast
Swift: deal with incomplete ASTs
2 parents f0554fc + 9731048 commit 552c524

36 files changed

+1025
-224
lines changed

misc/bazel/workspace.bzl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ def codeql_workspace(repository_name = "codeql"):
2222
_swift_prebuilt_version,
2323
repo_arch,
2424
),
25+
patches = [
26+
"@%s//swift/third_party/swift-llvm-support:patches/remove_getFallthrougDest_assert.patch" % repository_name,
27+
],
28+
patch_args = ["-p1"],
2529
build_file = "@%s//swift/third_party/swift-llvm-support:BUILD.swift-prebuilt.bazel" % repository_name,
2630
sha256 = sha256,
2731
)

swift/codegen/generators/cppgen.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@
1212
"""
1313

1414
import functools
15-
import pathlib
16-
from typing import Dict
15+
import typing
1716

1817
import inflection
1918

2019
from swift.codegen.lib import cpp, schema
2120

2221

23-
def _get_type(t: str) -> str:
22+
def _get_type(t: str, add_or_none_except: typing.Optional[str] = None) -> str:
2423
if t is None:
2524
# this is a predicate
2625
return "bool"
@@ -29,19 +28,23 @@ def _get_type(t: str) -> str:
2928
if t == "boolean":
3029
return "bool"
3130
if t[0].isupper():
32-
return f"TrapLabel<{t}Tag>"
31+
if add_or_none_except is not None and t != add_or_none_except:
32+
suffix = "OrNone"
33+
else:
34+
suffix = ""
35+
return f"TrapLabel<{t}{suffix}Tag>"
3336
return t
3437

3538

36-
def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
39+
def _get_field(cls: schema.Class, p: schema.Property, add_or_none_except: typing.Optional[str] = None) -> cpp.Field:
3740
trap_name = None
3841
if not p.is_single:
3942
trap_name = inflection.camelize(f"{cls.name}_{p.name}")
4043
if not p.is_predicate:
4144
trap_name = inflection.pluralize(trap_name)
4245
args = dict(
4346
field_name=p.name + ("_" if p.name in cpp.cpp_keywords else ""),
44-
type=_get_type(p.type),
47+
base_type=_get_type(p.type, add_or_none_except),
4548
is_optional=p.is_optional,
4649
is_repeated=p.is_repeated,
4750
is_predicate=p.is_predicate,
@@ -52,8 +55,13 @@ def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
5255

5356

5457
class Processor:
55-
def __init__(self, data: Dict[str, schema.Class]):
56-
self._classmap = data
58+
def __init__(self, data: schema.Schema):
59+
self._classmap = data.classes
60+
if data.null:
61+
root_type = next(iter(data.classes))
62+
self._add_or_none_except = root_type
63+
else:
64+
self._add_or_none_except = None
5765

5866
@functools.lru_cache(maxsize=None)
5967
def _get_class(self, name: str) -> cpp.Class:
@@ -64,7 +72,10 @@ def _get_class(self, name: str) -> cpp.Class:
6472
return cpp.Class(
6573
name=name,
6674
bases=[self._get_class(b) for b in cls.bases],
67-
fields=[_get_field(cls, p) for p in cls.properties if "cpp_skip" not in p.pragmas],
75+
fields=[
76+
_get_field(cls, p, self._add_or_none_except)
77+
for p in cls.properties if "cpp_skip" not in p.pragmas
78+
],
6879
final=not cls.derived,
6980
trap_name=trap_name,
7081
)
@@ -78,8 +89,8 @@ def get_classes(self):
7889

7990
def generate(opts, renderer):
8091
assert opts.cpp_output
81-
processor = Processor(schema.load_file(opts.schema).classes)
92+
processor = Processor(schema.load_file(opts.schema))
8293
out = opts.cpp_output
8394
for dir, classes in processor.get_classes().items():
8495
renderer.render(cpp.ClassList(classes, opts.schema,
85-
include_parent=bool(dir)), out / dir / "TrapClasses")
96+
include_parent=bool(dir)), out / dir / "TrapClasses")

swift/codegen/generators/dbschemegen.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
as columns
1414
The type hierarchy will be translated to corresponding `union` declarations.
1515
"""
16+
import typing
1617

1718
import inflection
1819

@@ -23,14 +24,21 @@
2324
log = logging.getLogger(__name__)
2425

2526

26-
def dbtype(typename):
27-
""" translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes """
27+
def dbtype(typename: str, add_or_none_except: typing.Optional[str] = None) -> str:
28+
""" translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes.
29+
For class types, appends an underscore followed by `null` if provided
30+
"""
2831
if typename[0].isupper():
29-
return "@" + inflection.underscore(typename)
32+
underscored = inflection.underscore(typename)
33+
if add_or_none_except is not None and typename != add_or_none_except:
34+
suffix = "_or_none"
35+
else:
36+
suffix = ""
37+
return f"@{underscored}{suffix}"
3038
return typename
3139

3240

33-
def cls_to_dbscheme(cls: schema.Class):
41+
def cls_to_dbscheme(cls: schema.Class, add_or_none_except: typing.Optional[str] = None):
3442
""" Yield all dbscheme entities needed to model class `cls` """
3543
if cls.derived:
3644
yield Union(dbtype(cls.name), (dbtype(c) for c in cls.derived))
@@ -48,7 +56,7 @@ def cls_to_dbscheme(cls: schema.Class):
4856
columns=[
4957
Column("id", type=dbtype(cls.name), binding=binding),
5058
] + [
51-
Column(f.name, dbtype(f.type)) for f in cls.properties if f.is_single
59+
Column(f.name, dbtype(f.type, add_or_none_except)) for f in cls.properties if f.is_single
5260
],
5361
dir=dir,
5462
)
@@ -61,7 +69,7 @@ def cls_to_dbscheme(cls: schema.Class):
6169
columns=[
6270
Column("id", type=dbtype(cls.name)),
6371
Column("index", type="int"),
64-
Column(inflection.singularize(f.name), dbtype(f.type)),
72+
Column(inflection.singularize(f.name), dbtype(f.type, add_or_none_except)),
6573
],
6674
dir=dir,
6775
)
@@ -71,7 +79,7 @@ def cls_to_dbscheme(cls: schema.Class):
7179
name=inflection.tableize(f"{cls.name}_{f.name}"),
7280
columns=[
7381
Column("id", type=dbtype(cls.name)),
74-
Column(f.name, dbtype(f.type)),
82+
Column(f.name, dbtype(f.type, add_or_none_except)),
7583
],
7684
dir=dir,
7785
)
@@ -87,7 +95,17 @@ def cls_to_dbscheme(cls: schema.Class):
8795

8896

8997
def get_declarations(data: schema.Schema):
90-
return [d for cls in data.classes.values() for d in cls_to_dbscheme(cls)]
98+
add_or_none_except = data.root_class.name if data.null else None
99+
declarations = [d for cls in data.classes.values() for d in cls_to_dbscheme(cls, add_or_none_except)]
100+
if data.null:
101+
property_classes = {
102+
prop.type for cls in data.classes.values() for prop in cls.properties
103+
if cls.name != data.null and prop.type and prop.type[0].isupper()
104+
}
105+
declarations += [
106+
Union(dbtype(t, data.null), [dbtype(t), dbtype(data.null)]) for t in sorted(property_classes)
107+
]
108+
return declarations
91109

92110

93111
def get_includes(data: schema.Schema, include_dir: pathlib.Path, swift_dir: pathlib.Path):

swift/codegen/generators/qlgen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, prev_child: str =
147147
return ql.Property(**args)
148148

149149

150-
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
150+
def get_ql_class(cls: schema.Class):
151151
pragmas = {k: True for k in cls.pragmas if k.startswith("ql")}
152152
prev_child = ""
153153
properties = []
@@ -314,7 +314,7 @@ def generate(opts, renderer):
314314

315315
data = schema.load_file(input)
316316

317-
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items()}
317+
classes = {name: get_ql_class(cls) for name, cls in data.classes.items()}
318318
if not classes:
319319
raise NoClasses
320320
root = next(iter(classes.values()))

swift/codegen/generators/trapgen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ def get_cpp_type(schema_type: str):
4141
def get_field(c: dbscheme.Column):
4242
args = {
4343
"field_name": c.schema_name,
44-
"type": c.type,
44+
"base_type": c.type,
4545
}
4646
args.update(cpp.get_field_override(c.schema_name))
47-
args["type"] = get_cpp_type(args["type"])
47+
args["base_type"] = get_cpp_type(args["base_type"])
4848
return cpp.Field(**args)
4949

5050

swift/codegen/lib/cpp.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"xor", "xor_eq"}
1717

1818
_field_overrides = [
19-
(re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"type": "unsigned"}),
19+
(re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"base_type": "unsigned"}),
2020
(re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
2121
]
2222

@@ -32,21 +32,26 @@ def get_field_override(field: str):
3232
@dataclass
3333
class Field:
3434
field_name: str
35-
type: str
35+
base_type: str
3636
is_optional: bool = False
3737
is_repeated: bool = False
3838
is_predicate: bool = False
3939
trap_name: str = None
4040
first: bool = False
4141

4242
def __post_init__(self):
43-
if self.is_optional:
44-
self.type = f"std::optional<{self.type}>"
45-
if self.is_repeated:
46-
self.type = f"std::vector<{self.type}>"
4743
if self.field_name in cpp_keywords:
4844
self.field_name += "_"
4945

46+
@property
47+
def type(self) -> str:
48+
type = self.base_type
49+
if self.is_optional:
50+
type = f"std::optional<{type}>"
51+
if self.is_repeated:
52+
type = f"std::vector<{type}>"
53+
return type
54+
5055
# using @property breaks pystache internals here
5156
def get_streamer(self):
5257
if self.type == "std::string":
@@ -60,6 +65,10 @@ def get_streamer(self):
6065
def is_single(self):
6166
return not (self.is_optional or self.is_repeated or self.is_predicate)
6267

68+
@property
69+
def is_label(self):
70+
return self.base_type.startswith("TrapLabel<")
71+
6372

6473
@dataclass
6574
class Trap:

swift/codegen/lib/schema/defs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ def f(cls: type) -> type:
115115
doc = _DocModifier
116116
desc = _DescModifier
117117

118+
use_for_null = _annotate(null=True)
119+
118120
qltest = _Namespace(
119121
skip=_Pragma("qltest_skip"),
120122
collapse_hierarchy=_Pragma("qltest_collapse_hierarchy"),

swift/codegen/lib/schema/schema.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def is_repeated(self) -> bool:
5555
def is_predicate(self) -> bool:
5656
return self.kind == self.Kind.PREDICATE
5757

58+
@property
59+
def has_class_type(self) -> bool:
60+
return bool(self.type) and self.type[0].isupper()
61+
62+
@property
63+
def has_builtin_type(self) -> bool:
64+
return bool(self.type) and self.type[0].islower()
65+
5866

5967
SingleProperty = functools.partial(Property, Property.Kind.SINGLE)
6068
OptionalProperty = functools.partial(Property, Property.Kind.OPTIONAL)
@@ -104,6 +112,16 @@ def check_types(self, known: typing.Iterable[str]):
104112
class Schema:
105113
classes: Dict[str, Class] = field(default_factory=dict)
106114
includes: Set[str] = field(default_factory=set)
115+
null: Optional[str] = None
116+
117+
@property
118+
def root_class(self):
119+
# always the first in the dictionary
120+
return next(iter(self.classes.values()))
121+
122+
@property
123+
def null_class(self):
124+
return self.classes[self.null] if self.null else None
107125

108126

109127
predicate_marker = object()
@@ -195,6 +213,8 @@ def _get_class(cls: type) -> Class:
195213
raise Error(f"Class name must be capitalized, found {cls.__name__}")
196214
if len({b._group for b in cls.__bases__ if hasattr(b, "_group")}) > 1:
197215
raise Error(f"Bases with mixed groups for {cls.__name__}")
216+
if any(getattr(b, "_null", False) for b in cls.__bases__):
217+
raise Error(f"Null class cannot be derived")
198218
return Class(name=cls.__name__,
199219
bases=[b.__name__ for b in cls.__bases__ if b is not object],
200220
derived={d.__name__ for d in cls.__subclasses__()},
@@ -233,6 +253,7 @@ def load(m: types.ModuleType) -> Schema:
233253
known = {"int", "string", "boolean"}
234254
known.update(n for n in m.__dict__ if not n.startswith("__"))
235255
import swift.codegen.lib.schema.defs as defs
256+
null = None
236257
for name, data in m.__dict__.items():
237258
if hasattr(defs, name):
238259
continue
@@ -247,8 +268,13 @@ def load(m: types.ModuleType) -> Schema:
247268
f"Only one root class allowed, found second root {name}")
248269
cls.check_types(known)
249270
classes[name] = cls
271+
if getattr(data, "_null", False):
272+
if null is not None:
273+
raise Error(f"Null class {null} already defined, second null class {name} not allowed")
274+
null = name
275+
cls.is_null_class = True
250276

251-
return Schema(includes=includes, classes=_toposort_classes_by_group(classes))
277+
return Schema(includes=includes, classes=_toposort_classes_by_group(classes), null=null)
252278

253279

254280
def load_file(path: pathlib.Path) -> Schema:

swift/codegen/templates/cpp_classes_h.mustache

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ namespace codeql {
1717
{{#classes}}
1818

1919
struct {{name}}{{#has_bases}} : {{#bases}}{{^first}}, {{/first}}{{ref.name}}{{/bases}}{{/has_bases}} {
20+
static constexpr const char* NAME = "{{name}}";
21+
2022
{{#final}}
2123
explicit {{name}}(TrapLabel<{{name}}Tag> id) : id{id} {}
2224

@@ -33,6 +35,41 @@ struct {{name}}{{#has_bases}} : {{#bases}}{{^first}}, {{/first}}{{ref.name}}{{/b
3335
}
3436
{{/final}}
3537

38+
{{^final}}
39+
protected:
40+
{{/final}}
41+
template <typename F>
42+
void forEachLabel(F f) {
43+
{{#final}}
44+
f("id", -1, id);
45+
{{/final}}
46+
{{#bases}}
47+
{{ref.name}}::forEachLabel(f);
48+
{{/bases}}
49+
{{#fields}}
50+
{{#is_label}}
51+
{{#is_repeated}}
52+
for (auto i = 0u; i < {{field_name}}.size(); ++i) {
53+
{{#is_optional}}
54+
if ({{field_name}}[i]) f("{{field_name}}", i, *{{field_name}}[i]);
55+
{{/is_optional}}
56+
{{^is_optional}}
57+
f("{{field_name}}", i, {{field_name}}[i]);
58+
{{/is_optional}}
59+
}
60+
{{/is_repeated}}
61+
{{^is_repeated}}
62+
{{#is_optional}}
63+
if ({{field_name}}) f("{{field_name}}", -1, *{{field_name}});
64+
{{/is_optional}}
65+
{{^is_optional}}
66+
f("{{field_name}}", -1, {{field_name}});
67+
{{/is_optional}}
68+
{{/is_repeated}}
69+
{{/is_label}}
70+
{{/fields}}
71+
}
72+
3673
protected:
3774
void emit({{^final}}TrapLabel<{{name}}Tag> id, {{/final}}std::ostream& out) const;
3875
};

0 commit comments

Comments
 (0)