Skip to content

Commit f73bb59

Browse files
Merge branch 'master' into key_populate
2 parents fbeaec9 + 24c090d commit f73bb59

File tree

7 files changed

+197
-50
lines changed

7 files changed

+197
-50
lines changed

datajoint/dependencies.py

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,74 @@
11
import networkx as nx
22
import itertools
3+
import re
34
from collections import defaultdict
45
from .errors import DataJointError
56

67

8+
def extract_master(part_table):
    """
    Given a part table name, return the name of its master table.

    :param part_table: full table name, e.g. `schema`.`master__part`
    :return: the master's full table name, e.g. `schema`.`master`,
        or None if part_table is not a part table name
    """
    # The schema and table names are separated by a literal dot, so it must be escaped;
    # an unescaped "." would accept any character in that position.
    match = re.match(r"(?P<master>`\w+`\.`#?\w+)__\w+`", part_table)
    return match["master"] + "`" if match else None
14+
15+
16+
def topo_sort(graph):
    """
    Topological sort of a dependency graph that keeps part tables together with their masters.

    :param graph: a directed acyclic graph (anything accepted by nx.DiGraph) whose nodes
        are strings; nodes whose names consist only of digits are treated as alias nodes
    :return: list of table names in topological order
    """

    graph = nx.DiGraph(graph)  # make a copy

    # collapse alias nodes: connect the alias's parent directly to its child
    alias_nodes = [node for node in graph if node.isdigit()]
    for node in alias_nodes:
        try:
            direct_edge = (
                next(iter(graph.in_edges(node)))[0],
                next(iter(graph.out_edges(node)))[1],
            )
        except StopIteration:
            pass  # a disconnected alias node
        else:
            graph.add_edge(*direct_edge)
    graph.remove_nodes_from(alias_nodes)

    # Add parts' dependencies to their masters' dependencies
    # to ensure correct topological ordering of the masters.
    # Iterate over snapshots because edges are added inside the loops.
    for part in list(graph):
        # find the part's master (None for non-part tables, which is never in the graph)
        master = extract_master(part)
        if master in graph:
            for parent in list(graph.predecessors(part)):
                if parent != master and extract_master(parent) != master:
                    graph.add_edge(parent, master)
    sorted_nodes = list(nx.topological_sort(graph))

    # bring parts up to their masters, scanning from the end of the list
    pos = len(sorted_nodes) - 1
    placed = set()
    while pos > 1:
        part = sorted_nodes[pos]
        # extract_master must be called on the node: binding the function object itself
        # is always truthy and is never found in sorted_nodes, so no part would ever move
        if not (master := extract_master(part)) or part in placed:
            pos -= 1
        else:
            placed.add(part)
            try:
                j = sorted_nodes.index(master)
            except ValueError:
                # master not found
                pass
            else:
                if pos > j + 1:
                    # move the part to its master; pos is left unchanged because
                    # a different node now occupies that position
                    del sorted_nodes[pos]
                    sorted_nodes.insert(j + 1, part)

    return sorted_nodes
70+
71+
772
class Dependencies(nx.DiGraph):
873
"""
974
The graph of dependencies (foreign keys) between loaded tables.
@@ -106,6 +171,10 @@ def load(self, force=True):
106171
raise DataJointError("DataJoint can only work with acyclic dependencies")
107172
self._loaded = True
108173

174+
def topo_sort(self):
    """
    Sort this dependency graph by passing it to the module-level topo_sort function.

    :return: list of table names in topological order
    """
    return topo_sort(self)
177+
109178
def parents(self, table_name, primary=None):
110179
"""
111180
:param table_name: `schema`.`table`
@@ -142,18 +211,14 @@ def descendants(self, full_table_name):
142211
:return: all dependent tables sorted in topological order. Self is included.
143212
"""
144213
self.load(force=False)
145-
nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name))
146-
return [full_table_name] + list(nx.algorithms.dag.topological_sort(nodes))
214+
nodes = self.subgraph(nx.descendants(self, full_table_name))
215+
return [full_table_name] + nodes.topo_sort()
147216

148217
def ancestors(self, full_table_name):
    """
    :param full_table_name: In form `schema`.`table_name`
    :return: list of all ancestor tables in reverse topological order.
        Self is included as the first element.
    """
    self.load(force=False)
    nodes = self.subgraph(nx.ancestors(self, full_table_name))
    # return a list, consistent with descendants(), rather than a one-shot
    # reversed iterator that cannot be indexed, measured or iterated twice
    return list(reversed(nodes.topo_sort() + [full_table_name]))

datajoint/diagram.py

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import logging
66
import inspect
77
from .table import Table
8-
from .user_tables import Manual, Imported, Computed, Lookup, Part
8+
from .dependencies import topo_sort
9+
from .user_tables import Manual, Imported, Computed, Lookup, Part, _get_tier, _AliasNode
910
from .errors import DataJointError
1011
from .table import lookup_class_name
1112

@@ -26,29 +27,6 @@
2627

2728

2829
logger = logging.getLogger(__name__.split(".")[0])
29-
user_table_classes = (Manual, Lookup, Computed, Imported, Part)
30-
31-
32-
class _AliasNode:
33-
"""
34-
special class to indicate aliased foreign keys
35-
"""
36-
37-
pass
38-
39-
40-
def _get_tier(table_name):
41-
if not table_name.startswith("`"):
42-
return _AliasNode
43-
else:
44-
try:
45-
return next(
46-
tier
47-
for tier in user_table_classes
48-
if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2])
49-
)
50-
except StopIteration:
51-
return None
5230

5331

5432
if not diagram_active:
@@ -70,19 +48,22 @@ def __init__(self, *args, **kwargs):
7048

7149
class Diagram(nx.DiGraph):
7250
"""
73-
Entity relationship diagram.
51+
Schema diagram showing tables and foreign keys between them in the form of a directed
52+
acyclic graph (DAG). The diagram is derived from the connection.dependencies object.
7453
7554
Usage:
7655
7756
>>> diag = Diagram(source)
7857
79-
source can be a base table object, a base table class, a schema, or a module that has a schema.
58+
source can be a table object, a table class, a schema, or a module that has a schema.
8059
8160
>>> diag.draw()
8261
8362
draws the diagram using pyplot
8463
8564
diag1 + diag2 - combines the two diagrams.
65+
diag1 - diag2 - difference between diagrams
66+
diag1 * diag2 - intersection of diagrams
8667
diag + n - expands n levels of successors
8768
diag - n - expands n levels of predecessors
8869
Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table
@@ -91,7 +72,8 @@ class Diagram(nx.DiGraph):
9172
Only those tables that are loaded in the connection object are displayed
9273
"""
9374

94-
def __init__(self, source, context=None):
75+
def __init__(self, source=None, context=None):
76+
9577
if isinstance(source, Diagram):
9678
# copy constructor
9779
self.nodes_to_show = set(source.nodes_to_show)
@@ -152,7 +134,7 @@ def from_sequence(cls, sequence):
152134

153135
def add_parts(self):
154136
"""
155-
Adds to the diagram the part tables of tables already included in the diagram
137+
Adds to the diagram the part tables of all master tables already in the diagram
156138
:return:
157139
"""
158140

@@ -177,14 +159,6 @@ def is_part(part, master):
177159
)
178160
return self
179161

180-
def topological_sort(self):
181-
""":return: list of nodes in topological order"""
182-
return list(
183-
nx.algorithms.dag.topological_sort(
184-
nx.DiGraph(self).subgraph(self.nodes_to_show)
185-
)
186-
)
187-
188162
def __add__(self, arg):
189163
"""
190164
:param arg: either another Diagram or a positive integer.
@@ -252,6 +226,10 @@ def __mul__(self, arg):
252226
self.nodes_to_show.intersection_update(arg.nodes_to_show)
253227
return self
254228

229+
def topo_sort(self):
    """
    Sort the diagram by passing it to topo_sort imported from .dependencies.

    :return: list of nodes in topological order
    """
    return topo_sort(self)
232+
255233
def _make_graph(self):
256234
"""
257235
Make the self.graph - a graph object ready for drawing

datajoint/schemas.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@
22
import logging
33
import inspect
44
import re
5+
import collections
6+
import itertools
57
from .connection import conn
6-
from .diagram import Diagram
78
from .settings import config
89
from .errors import DataJointError, AccessError
910
from .jobs import JobTable
1011
from .external import ExternalMapping
1112
from .heading import Heading
1213
from .utils import user_choice, to_camel_case
13-
from .user_tables import Part, Computed, Imported, Manual, Lookup
14-
from .table import lookup_class_name, Log
14+
from .user_tables import Part, Computed, Imported, Manual, Lookup, _get_tier
15+
from .table import lookup_class_name, Log, FreeTable
1516
import types
1617

1718
logger = logging.getLogger(__name__.split(".")[0])
@@ -399,6 +400,76 @@ def jobs(self):
399400
self._jobs = JobTable(self.connection, self.database)
400401
return self._jobs
401402

403+
@property
def code(self):
    """
    :return: the string produced by save() when called without a file name
    """
    self._assert_exists()
    return self.save()
407+
408+
def save(self, python_filename=None):
    """
    Generate the code for a module that recreates the schema.
    This method is in preparation for a future release and is not officially supported.

    :param python_filename: optional file name; when given, the generated code is written
        to that file instead of being returned
    :return: a string containing the body of a complete Python module defining this schema
        when python_filename is None; otherwise None
    """
    self._assert_exists()
    module_count = itertools.count()
    # add virtual modules for referenced modules with names vmod0, vmod1, ...
    module_lookup = collections.defaultdict(
        lambda: "vmod" + str(next(module_count))
    )
    db = self.database

    def make_class_definition(table):
        # build the class source text for one table given its full name
        tier = _get_tier(table).__name__
        class_name = table.split(".")[1].strip("`")
        indent = ""
        if tier == "Part":
            # a part class is named after the last "__"-separated segment and is indented
            class_name = class_name.split("__")[-1]
            indent += " "
        class_name = to_camel_case(class_name)

        def replace(s):
            # rewrite a `schema`.`table` reference as a dotted CamelCase path, prefixed
            # with a virtual module name when the schema is not this one
            d, tabs = s.group(1), s.group(2)
            return ("" if d == db else (module_lookup[d] + ".")) + ".".join(
                to_camel_case(tab) for tab in tabs.lstrip("__").split("__")
            )

        # part classes are emitted without the @schema decorator
        return ("" if tier == "Part" else "\n@schema\n") + (
            "{indent}class {class_name}(dj.{tier}):\n"
            '{indent} definition = """\n'
            '{indent} {defi}"""'
        ).format(
            class_name=class_name,
            indent=indent,
            tier=tier,
            defi=re.sub(
                r"`([^`]+)`.`([^`]+)`",
                replace,
                FreeTable(self.connection, table).describe(),
            ).replace("\n", "\n " + indent),
        )

    tables = self.connection.dependencies.topo_sort()
    body = "\n\n".join(make_class_definition(table) for table in tables)
    # module docstring, schema declaration, virtual module declarations, class definitions
    python_code = "\n\n".join(
        (
            '"""This module was auto-generated by datajoint from an existing schema"""',
            "import datajoint as dj\n\nschema = dj.Schema('{db}')".format(db=db),
            "\n".join(
                "{module} = dj.VirtualModule('{module}', '{schema_name}')".format(
                    module=v, schema_name=k
                )
                for k, v in module_lookup.items()
            ),
            body,
        )
    )
    if python_filename is None:
        return python_code
    with open(python_filename, "wt") as f:
        f.write(python_code)
472+
402473
def list_tables(self):
403474
"""
404475
Return a list of all tables in the schema except tables with ~ in first character such
@@ -410,7 +481,7 @@ def list_tables(self):
410481
t
411482
for d, t in (
412483
full_t.replace("`", "").split(".")
413-
for full_t in Diagram(self).topological_sort()
484+
for full_t in Diagram(self).topo_sort()
414485
)
415486
if d == self.database
416487
]

datajoint/table.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ def children(self, primary=None, as_objects=False, foreign_key_info=False):
217217

218218
def descendants(self, as_objects=False):
219219
"""
220-
221220
:param as_objects: False - a list of table names; True - a list of table objects.
222221
:return: list of tables descendants in topological order.
223222
"""

datajoint/user_tables.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Hosts the table tiers, user tables should be derived from.
33
"""
44

5+
import re
56
from .table import Table
67
from .autopopulate import AutoPopulate
78
from .utils import from_camel_case, ClassProperty
@@ -242,3 +243,29 @@ def drop(self, force=False):
242243
def alter(self, prompt=True, context=None):
    """
    Delegate to the parent class's alter.

    :param prompt: passed through to the parent's alter
    :param context: passed through to the parent's alter; when falsy,
        self.declaration_context is used instead
    """
    # without context, use declaration context which maps master keyword to master table
    super().alter(prompt=prompt, context=context or self.declaration_context)
246+
247+
248+
user_table_classes = (Manual, Lookup, Computed, Imported, Part)
249+
250+
251+
class _AliasNode:
    """
    Special marker class to indicate aliased foreign keys.
    It defines no attributes or methods of its own.
    """

    pass
257+
258+
259+
def _get_tier(table_name):
    """
    Given a table name, return the tier class it belongs to.

    :param table_name: a node name; names that do not start with a backquote
        are treated as alias nodes
    :return: _AliasNode for a name that does not start with a backquote; otherwise the
        first class in user_table_classes whose tier_regexp fully matches the
        second-to-last backquote-delimited segment of the name, or None if none matches
    """
    if not table_name.startswith("`"):
        return _AliasNode
    base_name = table_name.split("`")[-2]
    matching_tiers = (
        candidate
        for candidate in user_table_classes
        if re.fullmatch(candidate.tier_regexp, base_name)
    )
    return next(matching_tiers, None)

tests/test_cli.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44

55
import json
6-
import ast
76
import subprocess
87
import pytest
98
import datajoint as dj

tests/test_schema.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,14 @@ def test_list_tables(schema_simp):
218218
assert actual == expected, f"Missing from list_tables(): {expected - actual}"
219219

220220

221+
def test_schema_save_any(schema_any):
    """The code generated for schema_any must declare Experiment as a dj.Imported class."""
    assert "class Experiment(dj.Imported)" in schema_any.code
223+
224+
225+
def test_schema_save_empty(schema_empty):
    """The code generated for schema_empty must declare Experiment as a dj.Imported class."""
    assert "class Experiment(dj.Imported)" in schema_empty.code
227+
228+
221229
def test_uppercase_schema(db_creds_root):
222230
"""
223231
https://github.com/datajoint/datajoint-python/issues/564

0 commit comments

Comments
 (0)