Skip to content

Commit 4be8e39

Browse files
fix topological sort
1 parent a470d66 commit 4be8e39

File tree

2 files changed

+75
-7
lines changed

2 files changed

+75
-7
lines changed

datajoint/dependencies.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,65 @@
11
import networkx as nx
22
import itertools
3+
import re
34
from collections import defaultdict
45
from .errors import DataJointError
56

67

8+
def topo_sort(graph):
9+
"""
10+
topological sort of a dependency graph that keeps part tables together with their masters
11+
:return: list of table names in topological order
12+
"""
13+
graph = nx.DiGraph(graph) # make a copy
14+
15+
# collapse alias nodes
16+
alias_nodes = [node for node in graph if node.isdigit()]
17+
for node in alias_nodes:
18+
direct_edge = (
19+
next(x for x in graph.in_edges(node))[0],
20+
next(x for x in graph.out_edges(node))[1],
21+
)
22+
graph.add_edge(*direct_edge)
23+
graph.remove_nodes_from(alias_nodes)
24+
25+
# Add parts' dependencies to their masters' dependencies
26+
# to ensure correct topological ordering of the masters.
27+
part_pattern = re.compile(r"(?P<master>`\w+`.`#?\w+)__\w+`")
28+
for part in graph:
29+
# print part tables and their master
30+
match = part_pattern.match(part)
31+
if match:
32+
master = match["master"] + "`"
33+
for edge in graph.in_edges(part):
34+
if edge[0] != master:
35+
graph.add_edge(edge[0], master)
36+
37+
sorted_nodes = list(nx.algorithms.topological_sort(graph))
38+
39+
# bring parts up to their masters
40+
pos = len(sorted_nodes)
41+
while pos > 0:
42+
pos -= 1
43+
part = sorted_nodes[pos]
44+
match = part_pattern.match(part)
45+
if match:
46+
master = match["master"] + "`"
47+
print(part, master)
48+
try:
49+
j = sorted_nodes.index(master)
50+
except ValueError:
51+
# master not found
52+
continue
53+
if pos > j + 1:
54+
print(pos, j)
55+
# move the part to its master
56+
del sorted_nodes[pos]
57+
sorted_nodes.insert(j + 1, part)
58+
pos += 1
59+
60+
return sorted_nodes
61+
62+
763
class Dependencies(nx.DiGraph):
864
"""
965
The graph of dependencies (foreign keys) between loaded tables.
@@ -107,8 +163,8 @@ def load(self, force=True):
107163
self._loaded = True
108164

109165
def topo_sort(self):
110-
""":return: list of nodes in lexcigraphical topological order"""
111-
return list(nx.algorithms.dag.lexicographical_topological_sort(self))
166+
""":return: list of tables names in topological order"""
167+
return topo_sort(self)
112168

113169
def parents(self, table_name, primary=None):
114170
"""
@@ -146,7 +202,9 @@ def descendants(self, full_table_name):
146202
:return: all dependent tables sorted in topological order. Self is included.
147203
"""
148204
self.load(force=False)
149-
nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name)).copy()
205+
nodes = self.subgraph(
206+
nx.algorithms.dag.descendants(self, full_table_name)
207+
).copy()
150208
return [full_table_name] + nodes.topo_sort()
151209

152210
def ancestors(self, full_table_name):

datajoint/diagram.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import inspect
77
from .table import Table
8+
from .dependencies import topo_sort
89
from .user_tables import Manual, Imported, Computed, Lookup, Part
910
from .errors import DataJointError
1011
from .table import lookup_class_name
@@ -38,6 +39,7 @@ class _AliasNode:
3839

3940

4041
def _get_tier(table_name):
42+
"""given the table name, return"""
4143
if not table_name.startswith("`"):
4244
return _AliasNode
4345
else:
@@ -70,19 +72,22 @@ def __init__(self, *args, **kwargs):
7072

7173
class Diagram(nx.DiGraph):
7274
"""
73-
Entity relationship diagram.
75+
Schema diagram showing tables and foreign keys between in the form of a directed
76+
acyclic graph (DAG). The diagram is derived from the connection.dependencies object.
7477
7578
Usage:
7679
7780
>>> diag = Diagram(source)
7881
79-
source can be a base table object, a base table class, a schema, or a module that has a schema.
82+
source can be a table object, a table class, a schema, or a module that has a schema.
8083
8184
>>> diag.draw()
8285
8386
draws the diagram using pyplot
8487
8588
diag1 + diag2 - combines the two diagrams.
89+
diag1 - diag2 - differente between diagrams
90+
diag1 * diag2 - intersction of diagrams
8691
diag + n - expands n levels of successors
8792
diag - n - expands n levels of predecessors
8893
Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table
@@ -91,7 +96,8 @@ class Diagram(nx.DiGraph):
9196
Only those tables that are loaded in the connection object are displayed
9297
"""
9398

94-
def __init__(self, source, context=None):
99+
def __init__(self, source=None, context=None):
100+
95101
if isinstance(source, Diagram):
96102
# copy constructor
97103
self.nodes_to_show = set(source.nodes_to_show)
@@ -152,7 +158,7 @@ def from_sequence(cls, sequence):
152158

153159
def add_parts(self):
154160
"""
155-
Adds to the diagram the part tables of tables already included in the diagram
161+
Adds to the diagram the part tables of all master tables already in the diagram
156162
:return:
157163
"""
158164

@@ -244,6 +250,10 @@ def __mul__(self, arg):
244250
self.nodes_to_show.intersection_update(arg.nodes_to_show)
245251
return self
246252

253+
def topo_sort(self):
254+
"""return nodes in lexicographical topological order"""
255+
return topo_sort(self)
256+
247257
def _make_graph(self):
248258
"""
249259
Make the self.graph - a graph object ready for drawing

0 commit comments

Comments
 (0)