Skip to content

Commit a470d66

Browse files
optimize, fix topological sort.
1 parent 38d8813 commit a470d66

File tree

6 files changed

+17
-20
lines changed

6 files changed

+17
-20
lines changed

datajoint/dependencies.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ def load(self, force=True):
106106
raise DataJointError("DataJoint can only work with acyclic dependencies")
107107
self._loaded = True
108108

109+
def topo_sort(self):
110+
""":return: list of nodes in lexcigraphical topological order"""
111+
return list(nx.algorithms.dag.lexicographical_topological_sort(self))
112+
109113
def parents(self, table_name, primary=None):
110114
"""
111115
:param table_name: `schema`.`table`
@@ -142,18 +146,14 @@ def descendants(self, full_table_name):
142146
:return: all dependent tables sorted in topological order. Self is included.
143147
"""
144148
self.load(force=False)
145-
nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name))
146-
return [full_table_name] + list(nx.algorithms.dag.topological_sort(nodes))
149+
nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name)).copy()
150+
return [full_table_name] + nodes.topo_sort()
147151

148152
def ancestors(self, full_table_name):
149153
"""
150154
:param full_table_name: In form `schema`.`table_name`
151155
:return: all dependent tables sorted in topological order. Self is included.
152156
"""
153157
self.load(force=False)
154-
nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name))
155-
return list(
156-
reversed(
157-
list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name]
158-
)
159-
)
158+
nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name)).copy()
159+
return reversed(nodes.topo_sort() + [full_table_name])

datajoint/diagram.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,6 @@ def is_part(part, master):
177177
)
178178
return self
179179

180-
def topological_sort(self):
181-
""":return: list of nodes in lexcigraphical topological order"""
182-
return list(
183-
nx.algorithms.dag.lexicographical_topological_sort(
184-
nx.DiGraph(self).subgraph(self.nodes_to_show)
185-
)
186-
)
187-
188180
def __add__(self, arg):
189181
"""
190182
:param arg: either another Diagram or a positive integer.

datajoint/schemas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def replace(s):
453453

454454
diagram = Diagram(self)
455455
body = "\n\n".join(
456-
make_class_definition(table) for table in diagram.topological_sort()
456+
make_class_definition(table) for table in diagram.topo_sort()
457457
)
458458
python_code = "\n\n".join(
459459
(
@@ -484,7 +484,7 @@ def list_tables(self):
484484
t
485485
for d, t in (
486486
full_t.replace("`", "").split(".")
487-
for full_t in Diagram(self).topological_sort()
487+
for full_t in Diagram(self).topo_sort()
488488
)
489489
if d == self.database
490490
]

datajoint/table.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ def children(self, primary=None, as_objects=False, foreign_key_info=False):
217217

218218
def descendants(self, as_objects=False):
219219
"""
220-
221220
:param as_objects: False - a list of table names; True - a list of table objects.
222221
:return: list of tables descendants in topological order.
223222
"""

tests/test_cli.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44

55
import json
6-
import ast
76
import subprocess
87
import pytest
98
import datajoint as dj

tests/test_schema.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,13 @@ def test_list_tables(schema_simp):
217217
actual = set(schema_simp.list_tables())
218218
assert actual == expected, f"Missing from list_tables(): {expected - actual}"
219219

220+
def test_schema_save_any(schema_any):
221+
assert "class Experiment(dj.Imported)" in schema_any.code
222+
223+
224+
def test_schema_save_empty(schema_empty):
225+
assert "class Experiment(dj.Imported)" in schema_empty.code
226+
220227

221228
def test_uppercase_schema(db_creds_root):
222229
"""

0 commit comments

Comments
 (0)