Skip to content

Commit 389571e

Browse files
committed
add node graph auto layout
1 parent f2867a4 commit 389571e

File tree

2 files changed

+185
-9
lines changed

2 files changed

+185
-9
lines changed

NodeGraphQt/base/utils.py

Lines changed: 177 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,16 @@ def setup_context_menu(graph):
7676
edit_menu.add_command('Deselect all', _clear_node_selection, 'Ctrl+Shift+A')
7777
edit_menu.add_command('Enable/Disable', _disable_nodes, 'D')
7878

79-
edit_menu.add_command('Duplicate', _duplicate_nodes, 'Alt+c')
79+
edit_menu.add_command('Duplicate', _duplicate_nodes, 'Alt+C')
8080
edit_menu.add_command('Center Selection', _fit_to_selection, 'F')
8181

8282
edit_menu.add_separator()
8383

84+
edit_menu.add_command('Layout Graph Down Stream', _layout_graph_down, 'L')
85+
edit_menu.add_command('Layout Graph Up Stream', _layout_graph_up, 'Ctrl+L')
86+
87+
edit_menu.add_separator()
88+
8489
edit_menu.add_command('Jump In', _jump_in, 'I')
8590
edit_menu.add_command('Jump Out', _jump_out, 'O')
8691

@@ -275,11 +280,38 @@ def _toggle_grid(graph):
275280
graph.display_grid(not graph.scene().grid)
276281

277282

283+
def __layout_graph(graph, down_stream=True):
284+
node_space = graph.get_node_space()
285+
if node_space is not None:
286+
nodes = node_space.children()
287+
else:
288+
nodes = graph.all_nodes()
289+
if not nodes:
290+
return
291+
node_views = [n.view for n in nodes]
292+
nodes_center0 = graph.viewer().nodes_rect_center(node_views)
293+
if down_stream:
294+
auto_layout_down(all_nodes=nodes)
295+
else:
296+
auto_layout_up(all_nodes=nodes)
297+
nodes_center1 = graph.viewer().nodes_rect_center(node_views)
298+
dx = nodes_center0[0] - nodes_center1[0]
299+
dy = nodes_center0[1] - nodes_center1[1]
300+
[n.set_pos(n.x_pos() + dx, n.y_pos()+dy) for n in nodes]
301+
302+
303+
def _layout_graph_down(graph):
304+
__layout_graph(graph, True)
305+
306+
307+
def _layout_graph_up(graph):
308+
__layout_graph(graph, False)
309+
278310
# topological_sort
279311

280312
def get_input_nodes(node):
281313
"""
282-
Get input nodes of a node.
314+
Get input nodes of node.
283315
284316
Args:
285317
node (NodeGraphQt.BaseNode).
@@ -297,7 +329,7 @@ def get_input_nodes(node):
297329

298330
def get_output_nodes(node):
299331
"""
300-
Get output nodes of a node.
332+
Get output nodes of node.
301333
302334
Args:
303335
node (NodeGraphQt.BaseNode).
@@ -399,7 +431,7 @@ def _build_up_stream_graph(start_nodes):
399431

400432
def _sort_nodes(graph, start_nodes, reverse=True):
401433
"""
402-
Sort nodes by graph.
434+
Sort nodes in graph.
403435
404436
Args:
405437
graph (dict): generate from '_build_up_stream_graph' or '_build_down_stream_graph'.
@@ -416,17 +448,17 @@ def _sort_nodes(graph, start_nodes, reverse=True):
416448

417449
sorted_nodes = []
418450

419-
def dfs(graph, start_node):
451+
def dfs(start_node):
420452
for end_node in graph[start_node]:
421453
if not visit[end_node]:
422454
visit[end_node] = True
423-
dfs(graph, end_node)
455+
dfs(end_node)
424456
sorted_nodes.append(start_node)
425457

426458
for start_node in start_nodes:
427459
if not visit[start_node]:
428460
visit[start_node] = True
429-
dfs(graph, start_node)
461+
dfs(start_node)
430462

431463
if reverse:
432464
sorted_nodes.reverse()
@@ -440,7 +472,7 @@ def topological_sort_by_down(start_nodes=[], all_nodes=[]):
440472
'start_nodes' and 'all_nodes' only one needs to be given.
441473
442474
Args:
443-
start_nodes (list[NodeGraphQt.BaseNode])(Optional): the start update node of the graph.
475+
start_nodes (list[NodeGraphQt.BaseNode])(Optional): the start update nodes of the graph.
444476
all_nodes (list[NodeGraphQt.BaseNode])(Optional): if 'start_nodes' is None the function can calculate start nodes from 'all_nodes'.
445477
Returns:
446478
list[NodeGraphQt.BaseNode]: sorted nodes.
@@ -464,7 +496,7 @@ def topological_sort_by_up(start_nodes=[], all_nodes=[]):
464496
'start_nodes' and 'all_nodes' only one needs to be given.
465497
466498
Args:
467-
start_nodes (list[NodeGraphQt.BaseNode])(Optional): the end update node of the graph.
499+
start_nodes (list[NodeGraphQt.BaseNode])(Optional): the end update nodes of the graph.
468500
all_nodes (list[NodeGraphQt.BaseNode])(Optional): if 'start_nodes' is None the function can calculate start nodes from 'all_nodes'.
469501
Returns:
470502
list[NodeGraphQt.BaseNode]: sorted nodes.
@@ -540,3 +572,139 @@ def update_nodes_by_up(nodes):
540572
_update_nodes(topological_sort_by_up(all_nodes=nodes))
541573

542574
# auto layout
575+
576+
577+
def _update_node_rank_down(node, nodes_rank):
578+
rank = nodes_rank[node] + 1
579+
for n in get_output_nodes(node):
580+
if n in nodes_rank:
581+
nodes_rank[n] = max(nodes_rank[n], rank)
582+
else:
583+
nodes_rank[n] = rank
584+
_update_node_rank_down(n, nodes_rank)
585+
586+
587+
def _compute_rank_down(start_nodes):
588+
"""
589+
Compute the rank of the down stream nodes.
590+
591+
Args:
592+
start_nodes (list[NodeGraphQt.BaseNode])(Optional): the start nodes of the graph.
593+
Returns:
594+
dict{NodeGraphQt.BaseNode: node_rank, ...}
595+
"""
596+
597+
nodes_rank = {}
598+
for node in start_nodes:
599+
nodes_rank[node] = 0
600+
_update_node_rank_down(node, nodes_rank)
601+
return nodes_rank
602+
603+
604+
def _update_node_rank_up(node, nodes_rank):
605+
rank = nodes_rank[node] + 1
606+
for n in get_input_nodes(node):
607+
if n in nodes_rank:
608+
nodes_rank[n] = max(nodes_rank[n], rank)
609+
else:
610+
nodes_rank[n] = rank
611+
_update_node_rank_up(n, nodes_rank)
612+
613+
614+
def _compute_rank_up(start_nodes):
615+
"""
616+
Compute the rank of the up stream nodes.
617+
618+
Args:
619+
start_nodes (list[NodeGraphQt.BaseNode])(Optional): the end nodes of the graph.
620+
Returns:
621+
dict{NodeGraphQt.BaseNode: node_rank, ...}
622+
"""
623+
624+
nodes_rank = {}
625+
for node in start_nodes:
626+
nodes_rank[node] = 0
627+
_update_node_rank_up(node, nodes_rank)
628+
return nodes_rank
629+
630+
631+
def auto_layout_up(start_nodes=[], all_nodes=[]):
632+
"""
633+
Auto layout the nodes by up stream direction.
634+
635+
Args:
636+
start_nodes (list[NodeGraphQt.BaseNode])(Optional): the end nodes of the graph.
637+
all_nodes (list[NodeGraphQt.BaseNode])(Optional): if 'start_nodes' is None the function can calculate start nodes from 'all_nodes'.
638+
"""
639+
640+
if not start_nodes:
641+
start_nodes = [n for n in all_nodes if not _has_output_node(n)]
642+
if not start_nodes:
643+
return []
644+
if not [n for n in start_nodes if _has_input_node(n)]:
645+
return start_nodes
646+
647+
nodes_rank = _compute_rank_up(start_nodes)
648+
649+
rank_map = {}
650+
for node, rank in nodes_rank.items():
651+
if rank in rank_map:
652+
rank_map[rank].append(node)
653+
else:
654+
rank_map[rank] = [node]
655+
656+
current_x = 0
657+
node_height = 50
658+
for rank in reversed(range(len(rank_map))):
659+
nodes = rank_map[rank]
660+
max_width = max([node.view.width for node in nodes])
661+
current_x += max_width
662+
current_y = 0
663+
for idx, node in enumerate(nodes):
664+
dy = max(node_height, node.view.height)
665+
current_y += 0 if idx == 0 else dy
666+
node.set_pos(current_x, current_y)
667+
current_y += dy * 0.5 + 10
668+
669+
current_x += max_width * 0.5 + 100
670+
671+
672+
def auto_layout_down(start_nodes=[], all_nodes=[]):
673+
"""
674+
Auto layout the nodes by down stream direction.
675+
676+
Args:
677+
start_nodes (list[NodeGraphQt.BaseNode])(Optional): the start update nodes of the graph.
678+
all_nodes (list[NodeGraphQt.BaseNode])(Optional): if 'start_nodes' is None the function can calculate start nodes from 'all_nodes'.
679+
"""
680+
681+
if not start_nodes:
682+
start_nodes = [n for n in all_nodes if not _has_input_node(n)]
683+
if not start_nodes:
684+
return []
685+
if not [n for n in start_nodes if _has_output_node(n)]:
686+
return start_nodes
687+
688+
nodes_rank = _compute_rank_down(start_nodes)
689+
690+
rank_map = {}
691+
for node, rank in nodes_rank.items():
692+
if rank in rank_map:
693+
rank_map[rank].append(node)
694+
else:
695+
rank_map[rank] = [node]
696+
697+
current_x = 0
698+
node_height = 50
699+
for rank in range(len(rank_map)):
700+
nodes = rank_map[rank]
701+
max_width = max([node.view.width for node in nodes])
702+
current_x += max_width
703+
current_y = 0
704+
for idx, node in enumerate(nodes):
705+
dy = max(node_height, node.view.height)
706+
current_y += 0 if idx == 0 else dy
707+
node.set_pos(current_x, current_y)
708+
current_y += dy * 0.5 + 10
709+
710+
current_x += max_width * 0.5 + 100

NodeGraphQt/widgets/viewer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,14 @@ def force_update(self):
900900
def scene_rect(self):
901901
return [self._scene_range.x(), self._scene_range.y(), self._scene_range.width(), self._scene_range.height()]
902902

903+
def scene_center(self):
904+
cent = self._scene_range.center()
905+
return [cent.x(), cent.y()]
906+
907+
def nodes_rect_center(self, nodes):
908+
cent = self._combined_rect(nodes).center()
909+
return [cent.x(), cent.y()]
910+
903911
def set_scene_rect(self, rect):
904912
self._scene_range = QtCore.QRectF(*rect)
905913
self._update_scene()

0 commit comments

Comments
 (0)