Skip to content

Commit fe72dfe

Browse files
authored
Merge pull request github#9028 from redsun82/swift-trapgen
Swift: add `trapgen` unit tests
2 parents 9b855c3 + 10c5c8e commit fe72dfe

File tree

11 files changed

+372
-52
lines changed

11 files changed

+372
-52
lines changed

swift/codegen/BUILD.bazel

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
load("@swift_codegen_deps//:requirements.bzl", "requirement")
2+
13
py_binary(
24
name = "codegen",
3-
srcs = glob(["*.py"]),
5+
srcs = glob(
6+
["*.py"],
7+
exclude = ["trapgen.py"],
8+
),
49
visibility = ["//swift/codegen/test:__pkg__"],
510
deps = ["//swift/codegen/lib"],
611
)
@@ -12,5 +17,8 @@ py_binary(
1217
srcs = ["trapgen.py"],
1318
data = ["//swift/codegen/templates:cpp"],
1419
visibility = ["//swift:__subpackages__"],
15-
deps = ["//swift/codegen/lib"],
20+
deps = [
21+
"//swift/codegen/lib",
22+
requirement("toposort"),
23+
],
1624
)

swift/codegen/lib/cpp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ class Field:
2121
type: str
2222
first: bool = False
2323

24+
@property
2425
def cpp_name(self):
2526
if self.name in cpp_keywords:
2627
return self.name + "_"
2728
return self.name
2829

29-
def stream(self):
30+
# using @property breaks pystache internals here
31+
def get_streamer(self):
3032
if self.type == "std::string":
3133
return lambda x: f"trapQuoted({x})"
3234
elif self.type == "bool":
@@ -65,6 +67,7 @@ def __post_init__(self):
6567
self.bases = [TagBase(b) for b in self.bases]
6668
self.bases[0].first = True
6769

70+
@property
6871
def has_bases(self):
6972
return bool(self.bases)
7073

swift/codegen/lib/dbscheme.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,10 @@ def get_union(match):
144144

145145

146146
def iterload(file):
147-
data = Re.comment.sub("", file.read())
147+
with open(file) as file:
148+
data = Re.comment.sub("", file.read())
148149
for e in Re.entity.finditer(data):
149150
if e["table"]:
150151
yield get_table(e)
151152
elif e["union"]:
152153
yield get_union(e)
153-
154-
155-
def load(file):
156-
return list(iterload(file))

swift/codegen/requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
pystache
2-
pyyaml
31
inflection
2+
pystache
43
pytest
4+
pyyaml
5+
toposort

swift/codegen/templates/cpp_traps.mustache

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ struct {{name}}Trap {
2525

2626
inline std::ostream &operator<<(std::ostream &out, const {{name}}Trap &e) {
2727
out << "{{table_name}}("{{#fields}}{{^first}} << ", "{{/first}}
28-
<< {{#stream}}e.{{cpp_name}}{{/stream}}{{/fields}} << ")";
28+
<< {{#get_streamer}}e.{{cpp_name}}{{/get_streamer}}{{/fields}} << ")";
2929
return out;
3030
}
3131
{{/traps}}

swift/codegen/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ py_library(
1818
deps = [
1919
":utils",
2020
"//swift/codegen",
21+
"//swift/codegen:trapgen",
2122
],
2223
)
2324
for src in glob(["test_*.py"])

swift/codegen/test/test_cpp.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import sys
2+
from copy import deepcopy
3+
4+
import pytest
5+
6+
from swift.codegen.lib import cpp
7+
8+
9+
@pytest.mark.parametrize("keyword", cpp.cpp_keywords)
10+
def test_field_keyword_cpp_name(keyword):
11+
f = cpp.Field(keyword, "int")
12+
assert f.cpp_name == keyword + "_"
13+
14+
15+
def test_field_cpp_name():
16+
f = cpp.Field("foo", "int")
17+
assert f.cpp_name == "foo"
18+
19+
20+
@pytest.mark.parametrize("type,expected", [
21+
("std::string", "trapQuoted(value)"),
22+
("bool", '(value ? "true" : "false")'),
23+
("something_else", "value"),
24+
])
25+
def test_field_get_streamer(type, expected):
26+
f = cpp.Field("name", type)
27+
assert f.get_streamer()("value") == expected
28+
29+
30+
def test_trap_has_first_field_marked():
31+
fields = [
32+
cpp.Field("a", "x"),
33+
cpp.Field("b", "y"),
34+
cpp.Field("c", "z"),
35+
]
36+
expected = deepcopy(fields)
37+
expected[0].first = True
38+
t = cpp.Trap("table_name", "name", fields)
39+
assert t.fields == expected
40+
41+
42+
def test_tag_has_first_base_marked():
43+
bases = ["a", "b", "c"]
44+
expected = [cpp.TagBase("a", first=True), cpp.TagBase("b"), cpp.TagBase("c")]
45+
t = cpp.Tag("name", bases, 0, "id")
46+
assert t.bases == expected
47+
48+
49+
@pytest.mark.parametrize("bases,expected", [
50+
([], False),
51+
(["a"], True),
52+
(["a", "b"], True)
53+
])
54+
def test_tag_has_bases(bases, expected):
55+
t = cpp.Tag("name", bases, 0, "id")
56+
assert t.has_bases is expected
57+
58+
59+
if __name__ == '__main__':
60+
sys.exit(pytest.main())

swift/codegen/test/test_dbscheme.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,107 @@ def test_union_has_first_case_marked():
4848
assert [c.type for c in u.rhs] == rhs
4949

5050

51+
# load tests
52+
@pytest.fixture
53+
def load(tmp_path):
54+
file = tmp_path / "test.dbscheme"
55+
56+
def ret(yml):
57+
write(file, yml)
58+
return list(dbscheme.iterload(file))
59+
60+
return ret
61+
62+
63+
def test_load_empty(load):
64+
assert load("") == []
65+
66+
67+
def test_load_one_empty_table(load):
68+
assert load("""
69+
test_foos();
70+
""") == [
71+
dbscheme.Table(name="test_foos", columns=[])
72+
]
73+
74+
75+
def test_load_table_with_keyset(load):
76+
assert load("""
77+
#keyset[x, y,z]
78+
test_foos();
79+
""") == [
80+
dbscheme.Table(name="test_foos", columns=[], keyset=dbscheme.KeySet(["x", "y", "z"]))
81+
]
82+
83+
84+
expected_columns = [
85+
("int foo: int ref", dbscheme.Column(schema_name="foo", type="int", binding=False)),
86+
(" int bar : int ref", dbscheme.Column(schema_name="bar", type="int", binding=False)),
87+
("str baz_: str ref", dbscheme.Column(schema_name="baz", type="str", binding=False)),
88+
("int x: @foo ref", dbscheme.Column(schema_name="x", type="@foo", binding=False)),
89+
("int y: @foo", dbscheme.Column(schema_name="y", type="@foo", binding=True)),
90+
("unique int z: @foo", dbscheme.Column(schema_name="z", type="@foo", binding=True)),
91+
]
92+
93+
94+
@pytest.mark.parametrize("column,expected", expected_columns)
95+
def test_load_table_with_column(load, column, expected):
96+
assert load(f"""
97+
foos(
98+
{column}
99+
);
100+
""") == [
101+
dbscheme.Table(name="foos", columns=[deepcopy(expected)])
102+
]
103+
104+
105+
def test_load_table_with_multiple_columns(load):
106+
columns = ",\n".join(c for c, _ in expected_columns)
107+
expected = [deepcopy(e) for _, e in expected_columns]
108+
assert load(f"""
109+
foos(
110+
{columns}
111+
);
112+
""") == [
113+
dbscheme.Table(name="foos", columns=expected)
114+
]
115+
116+
117+
def test_load_multiple_table_with_columns(load):
118+
tables = [f"table{i}({col});" for i, (col, _) in enumerate(expected_columns)]
119+
expected = [dbscheme.Table(name=f"table{i}", columns=[deepcopy(e)]) for i, (_, e) in enumerate(expected_columns)]
120+
assert load("\n".join(tables)) == expected
121+
122+
123+
def test_union(load):
124+
assert load("@foo = @bar | @baz | @bla;") == [
125+
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
126+
]
127+
128+
129+
def test_table_and_union(load):
130+
assert load("""
131+
foos();
132+
133+
@foo = @bar | @baz | @bla;""") == [
134+
dbscheme.Table(name="foos", columns=[]),
135+
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
136+
]
137+
138+
139+
def test_comments_ignored(load):
140+
assert load("""
141+
// fake_table();
142+
foos(/* x */unique /*y*/int/*
143+
z
144+
*/ id/* */: /* * */ @bar/*,
145+
int ignored: int ref*/);
146+
147+
@foo = @bar | @baz | @bla; // | @xxx""") == [
148+
dbscheme.Table(name="foos", columns=[dbscheme.Column(schema_name="id", type="@bar", binding=True)]),
149+
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
150+
]
151+
152+
51153
if __name__ == '__main__':
52154
sys.exit(pytest.main())

0 commit comments

Comments
 (0)