Skip to content

Commit ec9bb1d

Browse files
author
Paolo Tranquilli
committed
Codegen: allow to include .py files in schema.py
1 parent 4a9e3ee commit ec9bb1d

File tree

3 files changed

+58
-5
lines changed

3 files changed

+58
-5
lines changed

misc/codegen/lib/schemadefs.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,16 @@ def modify(self, prop: _schema.Property):
3232

3333

3434
def include(source: str):
35-
# add to `includes` variable in calling context
36-
_inspect.currentframe().f_back.f_locals.setdefault(
37-
"__includes", []).append(source)
35+
scope = _inspect.currentframe().f_back.f_locals
36+
if source.endswith(".dbscheme"):
37+
# add to `includes` variable in calling context
38+
scope.setdefault("__includes", []).append(source)
39+
elif source.endswith(".py"):
40+
# just load the contents
41+
with open(source) as input:
42+
exec(input.read(), scope)
43+
else:
44+
raise _schema.Error(f"Unsupported file for inclusion: {source}")
3845

3946

4047
class _Namespace:

misc/codegen/loaders/schemaloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def _check_test_with(classes: typing.Dict[str, schema.Class]):
126126

127127

128128
def load(m: types.ModuleType) -> schema.Schema:
129-
includes = set()
129+
includes = []
130130
classes = {}
131131
known = {"int", "string", "boolean"}
132132
known.update(n for n in m.__dict__ if not n.startswith("__"))

misc/codegen/test/test_schemaloader.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class data:
1313
pass
1414

1515
assert data.classes == {}
16-
assert data.includes == set()
16+
assert data.includes == []
1717
assert data.null is None
1818
assert data.null_class is None
1919

@@ -805,5 +805,51 @@ class C(Root):
805805
pass
806806

807807

808+
def test_include_dbscheme():
809+
@load
810+
class data:
811+
defs.include("foo.dbscheme")
812+
defs.include("bar.dbscheme")
813+
814+
assert data.includes == ["foo.dbscheme", "bar.dbscheme"]
815+
816+
817+
def test_include_source(tmp_path):
818+
(tmp_path / "foo.py").write_text("""
819+
class A(Root):
820+
pass
821+
""")
822+
(tmp_path / "bar.py").write_text("""
823+
class C(Root):
824+
pass
825+
""")
826+
827+
@load
828+
class data:
829+
class Root:
830+
pass
831+
832+
defs.include(str(tmp_path / "foo.py"))
833+
834+
class B(Root):
835+
pass
836+
837+
defs.include(str(tmp_path / "bar.py"))
838+
839+
assert data.classes == {
840+
"Root": schema.Class("Root", derived=set("ABC")),
841+
"A": schema.Class("A", bases=["Root"]),
842+
"B": schema.Class("B", bases=["Root"]),
843+
"C": schema.Class("C", bases=["Root"]),
844+
}
845+
846+
847+
def test_include_not_supported(tmp_path):
848+
with pytest.raises(schema.Error):
849+
@load
850+
class data:
851+
defs.include("foo.bar")
852+
853+
808854
if __name__ == '__main__':
809855
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)