Skip to content

Commit 80c7dc8

Browse files
committed
use topological sort for node stream update
1 parent f77da4b commit 80c7dc8

File tree

10 files changed

+397
-231
lines changed

10 files changed

+397
-231
lines changed

NodeGraphQt/base/graph.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(self, parent=None):
148148

149149
self._wire_signals()
150150
self._node_space_bar = node_space_bar(self)
151+
self._auto_update = True
151152

152153
def __repr__(self):
153154
return '<{} object at {}>'.format(self.__class__.__name__, hex(id(self)))
@@ -376,6 +377,15 @@ def widget(self):
376377
layout.addWidget(self._viewer)
377378
return self._widget
378379

380+
@property
381+
def auto_update(self):
382+
"""
383+
384+
Returns:
385+
if the graph can run node automatically.
386+
"""
387+
return self._auto_update
388+
379389
def show(self):
380390
"""
381391
Show node graph widget this is just a convenience
@@ -1251,7 +1261,8 @@ def import_session(self, file_path):
12511261
Args:
12521262
file_path (str): path to the serialized layout file.
12531263
"""
1254-
1264+
_temp_auto_update = self._auto_update
1265+
self._auto_update = False
12551266
file_path = file_path.strip()
12561267
if not os.path.isfile(file_path):
12571268
raise IOError('file does not exist.')
@@ -1277,6 +1288,7 @@ def import_session(self, file_path):
12771288
self._undo_stack.clear()
12781289
self._model.session = file_path
12791290
self.session_changed.emit(file_path)
1291+
self._auto_update = _temp_auto_update
12801292

12811293
def copy_nodes(self, nodes=None):
12821294
"""

NodeGraphQt/base/node.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,13 +1183,13 @@ class SubGraph(object):
11831183
"""
11841184

11851185
def __init__(self):
1186-
self._children = []
1186+
self._children = set()
11871187

11881188
def children(self):
11891189
"""
11901190
Returns the children of the sub graph.
11911191
"""
1192-
return self._children
1192+
return list(self._children)
11931193

11941194
def create_from_nodes(self, nodes):
11951195
"""
@@ -1207,8 +1207,7 @@ def add_child(self, node):
12071207
Args:
12081208
node(NodeGraphQt.BaseNode).
12091209
"""
1210-
if node not in self._children:
1211-
self._children.append(node)
1210+
self._children.add(node)
12121211

12131212
def remove_child(self, node):
12141213
"""

NodeGraphQt/base/utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
PIPE_LAYOUT_ANGLE)
99

1010

11+
# menu
1112
def setup_context_menu(graph):
1213
"""
1314
populate the specified graph's context menu with essential menus commands.
@@ -275,3 +276,87 @@ def _angle_pipe(graph):
275276

276277
def _toggle_grid(graph):
277278
graph.display_grid(not graph.scene().grid)
279+
280+
281+
# topological_sort
282+
283+
def get_input_nodes(node):
284+
nodes = {}
285+
for p in node.input_ports():
286+
for cp in p.connected_ports():
287+
n = cp.node()
288+
nodes[n.id] = n
289+
return list(nodes.values())
290+
291+
292+
def get_output_nodes(node):
293+
nodes = {}
294+
for p in node.output_ports():
295+
for cp in p.connected_ports():
296+
n = cp.node()
297+
nodes[n.id] = n
298+
return list(nodes.values())
299+
300+
301+
def _has_input_node(node):
302+
for p in node.input_ports():
303+
if p.view.connected_pipes:
304+
return True
305+
return False
306+
307+
308+
def _has_output_node(node):
309+
for p in node.output_ports():
310+
if p.view.connected_pipes:
311+
return True
312+
return False
313+
314+
315+
def _build_graph(start_nodes):
316+
graph = {}
317+
for node in start_nodes:
318+
output_nodes = get_output_nodes(node)
319+
graph[node] = output_nodes
320+
while output_nodes:
321+
_output_nodes = []
322+
for n in output_nodes:
323+
if n not in graph:
324+
nodes = get_output_nodes(n)
325+
graph[n] = nodes
326+
_output_nodes.extend(nodes)
327+
output_nodes = _output_nodes
328+
329+
return graph
330+
331+
332+
def topological_sort(start_nodes=[], all_nodes=[]):
333+
if not start_nodes:
334+
start_nodes = [n for n in all_nodes if not _has_input_node(n)]
335+
if not start_nodes:
336+
return []
337+
if not [n for n in start_nodes if _has_output_node(n)]:
338+
return start_nodes
339+
340+
graph = _build_graph(start_nodes)
341+
if not graph:
342+
return []
343+
344+
visit = dict((node, False) for node in graph.keys())
345+
346+
sorted_nodes = []
347+
348+
def dfs(graph, start_node):
349+
for end_node in graph[start_node]:
350+
if not visit[end_node]:
351+
visit[end_node] = True
352+
dfs(graph, end_node)
353+
sorted_nodes.append(start_node)
354+
355+
for start_node in start_nodes:
356+
if not visit[start_node]:
357+
visit[start_node] = True
358+
dfs(graph, start_node)
359+
360+
sorted_nodes.reverse()
361+
362+
return sorted_nodes

NodeGraphQt/widgets/viewer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,8 @@ def establish_connection(self, start_port, end_port):
672672
if not start_port.node.visible or not end_port.node.visible:
673673
pipe.hide()
674674

675-
def acyclic_check(self, start_port, end_port):
675+
@staticmethod
676+
def acyclic_check(start_port, end_port):
676677
"""
677678
validate the connection so it doesn't loop itself.
678679
@@ -724,11 +725,13 @@ def context_menus(self):
724725
return {'graph': self._ctx_menu,
725726
'nodes': self._ctx_node_menu}
726727

727-
def question_dialog(self, text, title='Node Graph'):
728+
@staticmethod
729+
def question_dialog(text, title='Node Graph'):
728730
dlg = messageBox(text, title, QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No)
729731
return dlg == QtWidgets.QMessageBox.Yes
730732

731-
def message_dialog(self, text, title='Node Graph'):
733+
@staticmethod
734+
def message_dialog(text, title='Node Graph'):
732735
messageBox(text, title, QtWidgets.QMessageBox.Ok)
733736

734737
def load_dialog(self, current_dir=None, ext=None):
@@ -798,7 +801,8 @@ def add_node(self, node, pos=None):
798801
self.scene().addItem(node)
799802
node.post_init(self, pos)
800803

801-
def remove_node(self, node):
804+
@staticmethod
805+
def remove_node(node):
802806
if isinstance(node, AbstractNodeItem):
803807
node.delete()
804808

example_auto_nodes.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from NodeGraphQt import NodeGraph, setup_context_menu
66
from NodeGraphQt import QtWidgets, QtCore, PropertiesBinWidget, \
77
NodeTreeWidget, BackdropNode, NodePublishWidget
8+
from NodeGraphQt.base.utils import topological_sort
89
import os
910
import sys
1011
import inspect
@@ -51,7 +52,7 @@ def get_published_nodes_from_folder(folder_path):
5152

5253

5354
def cook_node(graph, node):
54-
node.cook()
55+
node.cook(forceCook=True)
5556

5657

5758
def print_functions(graph, node):
@@ -85,6 +86,14 @@ def publish_node(graph, node):
8586
wid.show()
8687

8788

89+
def cook_nodes(nodes):
90+
nodes = topological_sort(all_nodes=nodes)
91+
for node in nodes:
92+
node.cook(stream=True)
93+
if node.error():
94+
break
95+
96+
8897
if __name__ == '__main__':
8998
app = QtWidgets.QApplication([])
9099

@@ -136,9 +145,7 @@ def show_nodes_list(node):
136145

137146
# create test nodes
138147
graph.load_session(r'example_auto_nodes/networks/example_SubGraph.json')
139-
graph.get_node_by_path('/root/Distance').cook()
140-
graph.get_node_by_path('/root/Cross Product').cook()
141-
graph.get_node_by_path('/root/Dot Product').cook()
148+
cook_nodes(graph.root_node().children())
142149

143150
# widget used for the node graph.
144151
graph_widget = graph.widget

0 commit comments

Comments
 (0)