1717from NodeGraphQt .base .node import NodeObject
1818from NodeGraphQt .base .port import Port
1919from NodeGraphQt .constants import (
20- NODE_LAYOUT_DIRECTION ,
21- NODE_LAYOUT_HORIZONTAL ,
22- NODE_LAYOUT_VERTICAL ,
2320 URI_SCHEME ,
2421 URN_SCHEME ,
22+ LayoutDirectionEnum ,
2523 PipeLayoutEnum ,
2624 PortTypeEnum ,
2725 ViewerEnum
@@ -789,9 +787,20 @@ def set_pipe_style(self, style=PipeLayoutEnum.CURVED.value):
789787 style = style if 0 <= style <= pipe_max else PipeLayoutEnum .CURVED .value
790788 self ._viewer .set_pipe_layout (style )
791789
790+ def layout_direction (self ):
791+ """
792+ Return the current node graph layout direction.
793+
794+ Returns:
795+ int: layout direction.
796+ """
797+ return self .model .layout_direction
798+
792799 def set_layout_direction (self , direction ):
793800 """
794801 Sets the node graph layout direction to horizontal or vertical.
802+ This function will also override the layout direction on all
803+ nodes in the current node graph.
795804
796805 Note:
797806 By default node graph direction is set to "NODE_LAYOUT_HORIZONTAL".
@@ -804,9 +813,12 @@ def set_layout_direction(self, direction):
804813 Args:
805814 direction (int): layout direction.
806815 """
807- direction_types = [NODE_LAYOUT_HORIZONTAL , NODE_LAYOUT_VERTICAL ]
816+ direction_types = [e . value for e in LayoutDirectionEnum ]
808817 if direction not in direction_types :
809- direction = NODE_LAYOUT_HORIZONTAL
818+ direction = LayoutDirectionEnum .HORIZONTAL .value
819+ self ._model .layout_direction = direction
820+ for node in self .all_nodes ():
821+ node .set_layout_direction (direction )
810822 self ._viewer .set_layout_direction (direction )
811823
812824 def fit_to_selection (self ):
@@ -945,6 +957,9 @@ def format_color(clr):
945957 if pos :
946958 node .model .pos = [float (pos [0 ]), float (pos [1 ])]
947959
960+ # initial node direction layout.
961+ node .model .layout_direction = self .layout_direction ()
962+
948963 node .update ()
949964
950965 if push_undo :
@@ -987,6 +1002,11 @@ def add_node(self, node, pos=None, selected=True, push_undo=True):
9871002 node .NODE_NAME = self .get_unique_name (node .NODE_NAME )
9881003 node .model ._graph_model = self .model
9891004 node .model .name = node .NODE_NAME
1005+
1006+ # initial node direction layout.
1007+ node .model .layout_direction = self .layout_direction ()
1008+
1009+ # update method must be called before it's been added to the viewer.
9901010 node .update ()
9911011
9921012 if push_undo :
@@ -1697,7 +1717,7 @@ def auto_layout_nodes(self, nodes=None, down_stream=True, start_nodes=None):
16971717
16981718 node_layout_direction = self ._viewer .get_layout_direction ()
16991719
1700- if node_layout_direction is NODE_LAYOUT_HORIZONTAL :
1720+ if node_layout_direction is LayoutDirectionEnum . HORIZONTAL . value :
17011721 current_x = 0
17021722 node_height = 120
17031723 for rank in sorted (range (len (rank_map )), reverse = not down_stream ):
@@ -1712,7 +1732,7 @@ def auto_layout_nodes(self, nodes=None, down_stream=True, start_nodes=None):
17121732 current_y += dy * 0.5 + 10
17131733
17141734 current_x += max_width * 0.5 + 100
1715- elif node_layout_direction is NODE_LAYOUT_VERTICAL :
1735+ elif node_layout_direction is LayoutDirectionEnum . VERTICAL . value :
17161736 current_y = 0
17171737 node_width = 250
17181738 for rank in sorted (range (len (rank_map )), reverse = not down_stream ):
@@ -1992,9 +2012,9 @@ def _build_port_nodes(self):
19922012 input_nodes [port .name ()] = input_node
19932013 self .add_node (input_node , selected = False , push_undo = False )
19942014 x , y = input_node .pos ()
1995- if node_layout_direction is NODE_LAYOUT_HORIZONTAL :
2015+ if node_layout_direction is LayoutDirectionEnum . HORIZONTAL . value :
19962016 x -= 100
1997- elif node_layout_direction is NODE_LAYOUT_VERTICAL :
2017+ elif node_layout_direction is LayoutDirectionEnum . VERTICAL . value :
19982018 y -= 100
19992019 input_node .set_property ('pos' , [x , y ], push_undo = False )
20002020
@@ -2010,9 +2030,9 @@ def _build_port_nodes(self):
20102030 output_nodes [port .name ()] = output_node
20112031 self .add_node (output_node , selected = False , push_undo = False )
20122032 x , y = output_node .pos ()
2013- if node_layout_direction is NODE_LAYOUT_HORIZONTAL :
2033+ if node_layout_direction is LayoutDirectionEnum . HORIZONTAL . value :
20142034 x += 100
2015- elif node_layout_direction is NODE_LAYOUT_VERTICAL :
2035+ elif node_layout_direction is LayoutDirectionEnum . VERTICAL . value :
20162036 y += 100
20172037 output_node .set_property ('pos' , [x , y ], push_undo = False )
20182038
0 commit comments