Skip to content

Commit 24c090d

Browse files
debugged topological sort
1 parent adfdc65 commit 24c090d

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

datajoint/dependencies.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
from collections import defaultdict
55
from .errors import DataJointError
66

7+
78
def extract_master(part_table):
89
"""
9-
given a part table name, return master part. None if not a part table
10+
given a part table name, return master part. None if not a part table
1011
"""
1112
match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", part_table)
12-
return match['master'] + '`' if match else None
13-
13+
return match["master"] + "`" if match else None
1414

1515

1616
def topo_sort(graph):
@@ -39,22 +39,19 @@ def topo_sort(graph):
3939
# to ensure correct topological ordering of the masters.
4040
for part in graph:
4141
# find the part's master
42-
master = extract_master(part)
43-
if master:
42+
if (master := extract_master(part)) in graph:
4443
for edge in graph.in_edges(part):
4544
parent = edge[0]
4645
if parent != master and extract_master(parent) != master:
4746
graph.add_edge(parent, master)
48-
4947
sorted_nodes = list(nx.topological_sort(graph))
5048

5149
# bring parts up to their masters
5250
pos = len(sorted_nodes) - 1
5351
placed = set()
5452
while pos > 1:
5553
part = sorted_nodes[pos]
56-
master = extract_master(part)
57-
if not master or part in placed:
54+
if not (master := extract_master) or part in placed:
5855
pos -= 1
5956
else:
6057
placed.add(part)
@@ -63,7 +60,7 @@ def topo_sort(graph):
6360
except ValueError:
6461
# master not found
6562
pass
66-
else:
63+
else:
6764
if pos > j + 1:
6865
# move the part to its master
6966
del sorted_nodes[pos]
@@ -214,8 +211,8 @@ def descendants(self, full_table_name):
214211
:return: all dependent tables sorted in topological order. Self is included.
215212
"""
216213
self.load(force=False)
217-
nodes = self.subgraph(nx.descendants(self, full_table_name))
218-
return [full_table_name] + nodes.topo_sort()
214+
nodes = self.subgraph(nx.descendants(self, full_table_name))
215+
return [full_table_name] + nodes.topo_sort()
219216

220217
def ancestors(self, full_table_name):
221218
"""

tests/test_schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def test_list_tables(schema_simp):
217217
actual = set(schema_simp.list_tables())
218218
assert actual == expected, f"Missing from list_tables(): {expected - actual}"
219219

220+
220221
def test_schema_save_any(schema_any):
221222
assert "class Experiment(dj.Imported)" in schema_any.code
222223

0 commit comments

Comments
 (0)