Skip to content

Commit adfdc65

Browse files
fix topological sort
1 parent 4be8e39 commit adfdc65

File tree

4 files changed

+73
-63
lines changed

4 files changed

+73
-63
lines changed

datajoint/dependencies.py

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,58 +4,70 @@
44
from collections import defaultdict
55
from .errors import DataJointError
66

7+
def extract_master(part_table):
    """
    Given a part table's full name, return the full name of its master table.

    A part table is named `schema`.`master__part`; the master name may carry a
    leading ``#`` or underscores as its tier prefix.

    :param part_table: full table name in the form `schema`.`table_name`
    :return: the master's full table name, e.g. `schema`.`master`,
        or None if the name does not have the form of a part table
    """
    # The dot separating schema from table is escaped so that it matches only a
    # literal "." rather than any character.
    match = re.match(r"(?P<master>`\w+`\.`#?\w+)__\w+`", part_table)
    if match is None:
        return None
    # the captured group stops before the closing backquote of the master name
    return match.group("master") + "`"
13+
14+
715

816
def topo_sort(graph):
    """
    Topological sort of a dependency graph that keeps part tables together with their masters.

    :param graph: dependency graph, in any form accepted by ``nx.DiGraph``. Nodes are
        strings: table names, or all-digit names for alias nodes. The input graph is
        not modified; all work is done on a copy.
    :return: list of table names in topological order, with alias nodes removed and
        each part table moved up to directly follow its master when the master is present
    """
    graph = nx.DiGraph(graph)  # make a copy

    # Collapse alias nodes: parent -> alias -> child becomes parent -> child.
    alias_nodes = [node for node in graph if node.isdigit()]
    for node in alias_nodes:
        parent = next(iter(graph.predecessors(node)), None)
        child = next(iter(graph.successors(node)), None)
        # a disconnected alias node has no edge to collapse
        if parent is not None and child is not None:
            graph.add_edge(parent, child)
    graph.remove_nodes_from(alias_nodes)

    # Add parts' dependencies to their masters' dependencies
    # to ensure correct topological ordering of the masters.
    # Iterate over a snapshot, and skip parts whose master is not in the graph:
    # add_edge() would otherwise insert the missing master as a new node while
    # the graph is being iterated, which raises RuntimeError.
    for part in list(graph):
        master = extract_master(part)
        if master is None or master not in graph:
            continue
        for parent in list(graph.predecessors(part)):
            # ignore the master itself and sibling parts of the same master
            if parent != master and extract_master(parent) != master:
                graph.add_edge(parent, master)

    sorted_nodes = list(nx.topological_sort(graph))

    # Bring parts up to their masters, scanning from the end of the list.
    pos = len(sorted_nodes) - 1
    placed = set()  # parts already handled, so each one is moved at most once
    while pos > 1:
        part = sorted_nodes[pos]
        master = extract_master(part)
        if master is None or master not in graph or part in placed:
            pos -= 1
            continue
        placed.add(part)
        j = sorted_nodes.index(master)
        if pos > j + 1:
            # Move the part to directly follow its master. `pos` is left unchanged
            # because a different node has shifted into this position.
            del sorted_nodes[pos]
            sorted_nodes.insert(j + 1, part)

    return sorted_nodes
6173

@@ -202,16 +214,14 @@ def descendants(self, full_table_name):
202214
:return: all dependent tables sorted in topological order. Self is included.
203215
"""
204216
self.load(force=False)
205-
nodes = self.subgraph(
206-
nx.algorithms.dag.descendants(self, full_table_name)
207-
).copy()
208-
return [full_table_name] + nodes.topo_sort()
217+
nodes = self.subgraph(nx.descendants(self, full_table_name))
218+
return [full_table_name] + nodes.topo_sort()
209219

210220
def ancestors(self, full_table_name):
    """
    :param full_table_name: In form `schema`.`table_name`
    :return: iterator over the table and all of its ancestors, in reverse
        topological order. Self is included and comes first.
    """
    self.load(force=False)
    upstream = self.subgraph(nx.ancestors(self, full_table_name))
    ordered = upstream.topo_sort() + [full_table_name]
    return reversed(ordered)

datajoint/diagram.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import inspect
77
from .table import Table
88
from .dependencies import topo_sort
9-
from .user_tables import Manual, Imported, Computed, Lookup, Part
9+
from .user_tables import Manual, Imported, Computed, Lookup, Part, _get_tier, _AliasNode
1010
from .errors import DataJointError
1111
from .table import lookup_class_name
1212

@@ -27,30 +27,6 @@
2727

2828

2929
logger = logging.getLogger(__name__.split(".")[0])
30-
user_table_classes = (Manual, Lookup, Computed, Imported, Part)
31-
32-
33-
class _AliasNode:
34-
"""
35-
special class to indicate aliased foreign keys
36-
"""
37-
38-
pass
39-
40-
41-
def _get_tier(table_name):
42-
"""given the table name, return"""
43-
if not table_name.startswith("`"):
44-
return _AliasNode
45-
else:
46-
try:
47-
return next(
48-
tier
49-
for tier in user_table_classes
50-
if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2])
51-
)
52-
except StopIteration:
53-
return None
5430

5531

5632
if not diagram_active:

datajoint/schemas.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
import collections
66
import itertools
77
from .connection import conn
8-
from .diagram import Diagram, _get_tier
98
from .settings import config
109
from .errors import DataJointError, AccessError
1110
from .jobs import JobTable
1211
from .external import ExternalMapping
1312
from .heading import Heading
1413
from .utils import user_choice, to_camel_case
15-
from .user_tables import Part, Computed, Imported, Manual, Lookup
14+
from .user_tables import Part, Computed, Imported, Manual, Lookup, _get_tier
1615
from .table import lookup_class_name, Log, FreeTable
1716
import types
1817

@@ -451,10 +450,8 @@ def replace(s):
451450
).replace("\n", "\n " + indent),
452451
)
453452

454-
diagram = Diagram(self)
455-
body = "\n\n".join(
456-
make_class_definition(table) for table in diagram.topo_sort()
457-
)
453+
tables = self.connection.dependencies.topo_sort()
454+
body = "\n\n".join(make_class_definition(table) for table in tables)
458455
python_code = "\n\n".join(
459456
(
460457
'"""This module was auto-generated by datajoint from an existing schema"""',

datajoint/user_tables.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Hosts the table tiers, user tables should be derived from.
33
"""
44

5+
import re
56
from .table import Table
67
from .autopopulate import AutoPopulate
78
from .utils import from_camel_case, ClassProperty
@@ -242,3 +243,29 @@ def drop(self, force=False):
242243
def alter(self, prompt=True, context=None):
    """
    Alter the table, defaulting to the declaration context.

    :param prompt: passed through unchanged to the parent class's ``alter``
    :param context: passed through to the parent class's ``alter``; when falsy,
        ``self.declaration_context`` is used instead
    """
    if not context:
        # the declaration context maps the master keyword to the master table
        context = self.declaration_context
    super().alter(prompt=prompt, context=context)
246+
247+
248+
# Tier classes that _get_tier tries against a table name, in this order (first match wins).
user_table_classes = (Manual, Lookup, Computed, Imported, Part)
249+
250+
251+
class _AliasNode:
    """Marker class standing in for an aliased foreign key; it carries no behavior."""
257+
258+
259+
def _get_tier(table_name):
    """
    Given a table name, return the class that represents its tier.

    :param table_name: full table name such as `schema`.`table_name`; a name that
        does not start with a backquote is treated as an alias node
    :return: _AliasNode for an alias node; otherwise the first class in
        user_table_classes whose tier_regexp fully matches the bare table name,
        or None when no tier matches
    """
    if not table_name.startswith("`"):
        return _AliasNode
    # the bare table name is the text between the last pair of backquotes
    bare_name = table_name.split("`")[-2]
    for tier in user_table_classes:
        if re.fullmatch(tier.tier_regexp, bare_name):
            return tier
    return None

0 commit comments

Comments
 (0)