@@ -76,11 +76,16 @@ def setup_context_menu(graph):
7676 edit_menu .add_command ('Deselect all' , _clear_node_selection , 'Ctrl+Shift+A' )
7777 edit_menu .add_command ('Enable/Disable' , _disable_nodes , 'D' )
7878
79- edit_menu .add_command ('Duplicate' , _duplicate_nodes , 'Alt+c ' )
79+ edit_menu .add_command ('Duplicate' , _duplicate_nodes , 'Alt+C ' )
8080 edit_menu .add_command ('Center Selection' , _fit_to_selection , 'F' )
8181
8282 edit_menu .add_separator ()
8383
84+ edit_menu .add_command ('Layout Graph Down Stream' , _layout_graph_down , 'L' )
85+ edit_menu .add_command ('Layout Graph Up Stream' , _layout_graph_up , 'Ctrl+L' )
86+
87+ edit_menu .add_separator ()
88+
8489 edit_menu .add_command ('Jump In' , _jump_in , 'I' )
8590 edit_menu .add_command ('Jump Out' , _jump_out , 'O' )
8691
@@ -275,11 +280,38 @@ def _toggle_grid(graph):
275280 graph .display_grid (not graph .scene ().grid )
276281
277282
283+ def __layout_graph (graph , down_stream = True ):
284+ node_space = graph .get_node_space ()
285+ if node_space is not None :
286+ nodes = node_space .children ()
287+ else :
288+ nodes = graph .all_nodes ()
289+ if not nodes :
290+ return
291+ node_views = [n .view for n in nodes ]
292+ nodes_center0 = graph .viewer ().nodes_rect_center (node_views )
293+ if down_stream :
294+ auto_layout_down (all_nodes = nodes )
295+ else :
296+ auto_layout_up (all_nodes = nodes )
297+ nodes_center1 = graph .viewer ().nodes_rect_center (node_views )
298+ dx = nodes_center0 [0 ] - nodes_center1 [0 ]
299+ dy = nodes_center0 [1 ] - nodes_center1 [1 ]
300+ [n .set_pos (n .x_pos () + dx , n .y_pos ()+ dy ) for n in nodes ]
301+
302+
303+ def _layout_graph_down (graph ):
304+ __layout_graph (graph , True )
305+
306+
307+ def _layout_graph_up (graph ):
308+ __layout_graph (graph , False )
309+
278310# topological_sort
279311
280312def get_input_nodes (node ):
281313 """
282- Get input nodes of a node.
314+ Get input nodes of node.
283315
284316 Args:
285317 node (NodeGraphQt.BaseNode).
@@ -297,7 +329,7 @@ def get_input_nodes(node):
297329
298330def get_output_nodes (node ):
299331 """
300- Get output nodes of a node.
332+ Get output nodes of node.
301333
302334 Args:
303335 node (NodeGraphQt.BaseNode).
@@ -399,7 +431,7 @@ def _build_up_stream_graph(start_nodes):
399431
400432def _sort_nodes (graph , start_nodes , reverse = True ):
401433 """
402- Sort nodes by graph.
434+ Sort nodes in graph.
403435
404436 Args:
405437 graph (dict): generate from '_build_up_stream_graph' or '_build_down_stream_graph'.
@@ -416,17 +448,17 @@ def _sort_nodes(graph, start_nodes, reverse=True):
416448
417449 sorted_nodes = []
418450
419- def dfs (graph , start_node ):
451+ def dfs (start_node ):
420452 for end_node in graph [start_node ]:
421453 if not visit [end_node ]:
422454 visit [end_node ] = True
423- dfs (graph , end_node )
455+ dfs (end_node )
424456 sorted_nodes .append (start_node )
425457
426458 for start_node in start_nodes :
427459 if not visit [start_node ]:
428460 visit [start_node ] = True
429- dfs (graph , start_node )
461+ dfs (start_node )
430462
431463 if reverse :
432464 sorted_nodes .reverse ()
@@ -440,7 +472,7 @@ def topological_sort_by_down(start_nodes=[], all_nodes=[]):
440472 'start_nodes' and 'all_nodes' only one needs to be given.
441473
442474 Args:
443- start_nodes (list[NodeGraphQt.BaseNode])(Optional): the start update node of the graph.
475+ start_nodes (list[NodeGraphQt.BaseNode])(Optional): the start update nodes of the graph.
444476 all_nodes (list[NodeGraphQt.BaseNode])(Optional): if 'start_nodes' is None the function can calculate start nodes from 'all_nodes'.
445477 Returns:
446478 list[NodeGraphQt.BaseNode]: sorted nodes.
@@ -464,7 +496,7 @@ def topological_sort_by_up(start_nodes=[], all_nodes=[]):
464496 'start_nodes' and 'all_nodes' only one needs to be given.
465497
466498 Args:
467- start_nodes (list[NodeGraphQt.BaseNode])(Optional): the end update node of the graph.
499+ start_nodes (list[NodeGraphQt.BaseNode])(Optional): the end update nodes of the graph.
468500 all_nodes (list[NodeGraphQt.BaseNode])(Optional): if 'start_nodes' is None the function can calculate start nodes from 'all_nodes'.
469501 Returns:
470502 list[NodeGraphQt.BaseNode]: sorted nodes.
@@ -540,3 +572,139 @@ def update_nodes_by_up(nodes):
540572 _update_nodes (topological_sort_by_up (all_nodes = nodes ))
541573
542574# auto layout
575+
576+
577+ def _update_node_rank_down (node , nodes_rank ):
578+ rank = nodes_rank [node ] + 1
579+ for n in get_output_nodes (node ):
580+ if n in nodes_rank :
581+ nodes_rank [n ] = max (nodes_rank [n ], rank )
582+ else :
583+ nodes_rank [n ] = rank
584+ _update_node_rank_down (n , nodes_rank )
585+
586+
587+ def _compute_rank_down (start_nodes ):
588+ """
589+ Compute the rank of the down stream nodes.
590+
591+ Args:
592+ start_nodes (list[NodeGraphQt.BaseNode])(Optional): the start nodes of the graph.
593+ Returns:
594+ dict{NodeGraphQt.BaseNode: node_rank, ...}
595+ """
596+
597+ nodes_rank = {}
598+ for node in start_nodes :
599+ nodes_rank [node ] = 0
600+ _update_node_rank_down (node , nodes_rank )
601+ return nodes_rank
602+
603+
604+ def _update_node_rank_up (node , nodes_rank ):
605+ rank = nodes_rank [node ] + 1
606+ for n in get_input_nodes (node ):
607+ if n in nodes_rank :
608+ nodes_rank [n ] = max (nodes_rank [n ], rank )
609+ else :
610+ nodes_rank [n ] = rank
611+ _update_node_rank_up (n , nodes_rank )
612+
613+
614+ def _compute_rank_up (start_nodes ):
615+ """
616+ Compute the rank of the up stream nodes.
617+
618+ Args:
619+ start_nodes (list[NodeGraphQt.BaseNode])(Optional): the end nodes of the graph.
620+ Returns:
621+ dict{NodeGraphQt.BaseNode: node_rank, ...}
622+ """
623+
624+ nodes_rank = {}
625+ for node in start_nodes :
626+ nodes_rank [node ] = 0
627+ _update_node_rank_up (node , nodes_rank )
628+ return nodes_rank
629+
630+
631+ def auto_layout_up (start_nodes = [], all_nodes = []):
632+ """
633+ Auto layout the nodes by up stream direction.
634+
635+ Args:
636+ start_nodes (list[NodeGraphQt.BaseNode])(Optional): the end nodes of the graph.
637+ all_nodes (list[NodeGraphQt.BaseNode])(Optional): if 'start_nodes' is None the function can calculate start nodes from 'all_nodes'.
638+ """
639+
640+ if not start_nodes :
641+ start_nodes = [n for n in all_nodes if not _has_output_node (n )]
642+ if not start_nodes :
643+ return []
644+ if not [n for n in start_nodes if _has_input_node (n )]:
645+ return start_nodes
646+
647+ nodes_rank = _compute_rank_up (start_nodes )
648+
649+ rank_map = {}
650+ for node , rank in nodes_rank .items ():
651+ if rank in rank_map :
652+ rank_map [rank ].append (node )
653+ else :
654+ rank_map [rank ] = [node ]
655+
656+ current_x = 0
657+ node_height = 50
658+ for rank in reversed (range (len (rank_map ))):
659+ nodes = rank_map [rank ]
660+ max_width = max ([node .view .width for node in nodes ])
661+ current_x += max_width
662+ current_y = 0
663+ for idx , node in enumerate (nodes ):
664+ dy = max (node_height , node .view .height )
665+ current_y += 0 if idx == 0 else dy
666+ node .set_pos (current_x , current_y )
667+ current_y += dy * 0.5 + 10
668+
669+ current_x += max_width * 0.5 + 100
670+
671+
672+ def auto_layout_down (start_nodes = [], all_nodes = []):
673+ """
674+ Auto layout the nodes by down stream direction.
675+
676+ Args:
677+ start_nodes (list[NodeGraphQt.BaseNode])(Optional): the start update nodes of the graph.
678+ all_nodes (list[NodeGraphQt.BaseNode])(Optional): if 'start_nodes' is None the function can calculate start nodes from 'all_nodes'.
679+ """
680+
681+ if not start_nodes :
682+ start_nodes = [n for n in all_nodes if not _has_input_node (n )]
683+ if not start_nodes :
684+ return []
685+ if not [n for n in start_nodes if _has_output_node (n )]:
686+ return start_nodes
687+
688+ nodes_rank = _compute_rank_down (start_nodes )
689+
690+ rank_map = {}
691+ for node , rank in nodes_rank .items ():
692+ if rank in rank_map :
693+ rank_map [rank ].append (node )
694+ else :
695+ rank_map [rank ] = [node ]
696+
697+ current_x = 0
698+ node_height = 50
699+ for rank in range (len (rank_map )):
700+ nodes = rank_map [rank ]
701+ max_width = max ([node .view .width for node in nodes ])
702+ current_x += max_width
703+ current_y = 0
704+ for idx , node in enumerate (nodes ):
705+ dy = max (node_height , node .view .height )
706+ current_y += 0 if idx == 0 else dy
707+ node .set_pos (current_x , current_y )
708+ current_y += dy * 0.5 + 10
709+
710+ current_x += max_width * 0.5 + 100
0 commit comments