1717from NodeGraphQt .base .node import NodeObject
1818from NodeGraphQt .base .port import Port
1919from NodeGraphQt .constants import (
20- NODE_LAYOUT_DIRECTION , NODE_LAYOUT_HORIZONTAL , NODE_LAYOUT_VERTICAL ,
20+ URI_SCHEME ,
21+ URN_SCHEME ,
22+ LayoutDirectionEnum ,
2123 PipeLayoutEnum ,
22- URI_SCHEME , URN_SCHEME ,
2324 PortTypeEnum ,
2425 ViewerEnum
2526)
@@ -125,6 +126,15 @@ def __init__(self, parent=None, **kwargs):
125126 self .setObjectName ('NodeGraph' )
126127 self ._model = (
127128 kwargs .get ('model' ) or NodeGraphModel ())
129+
130+ layout_direction = kwargs .get ('layout_direction' )
131+ if layout_direction :
132+ if layout_direction not in [e .value for e in LayoutDirectionEnum ]:
133+ layout_direction = LayoutDirectionEnum .HORIZONTAL .value
134+ self ._model .layout_direction = layout_direction
135+ else :
136+ layout_direction = self ._model .layout_direction
137+
128138 self ._node_factory = (
129139 kwargs .get ('node_factory' ) or NodeFactory ())
130140
@@ -138,6 +148,7 @@ def __init__(self, parent=None, **kwargs):
138148
139149 self ._viewer = (
140150 kwargs .get ('viewer' ) or NodeViewer (undo_stack = self ._undo_stack ))
151+ self ._viewer .set_layout_direction (layout_direction )
141152
142153 self ._build_context_menu ()
143154 self ._register_builtin_nodes ()
@@ -786,6 +797,52 @@ def set_pipe_style(self, style=PipeLayoutEnum.CURVED.value):
786797 style = style if 0 <= style <= pipe_max else PipeLayoutEnum .CURVED .value
787798 self ._viewer .set_pipe_layout (style )
788799
800+ def layout_direction (self ):
801+ """
802+ Return the current node graph layout direction.
803+
804+ `Implemented in` ``v0.3.0``
805+
806+ See Also:
807+ :meth:`NodeGraph.set_layout_direction`
808+
809+ Returns:
810+ int: layout direction.
811+ """
812+ return self .model .layout_direction
813+
814+ def set_layout_direction (self , direction ):
815+ """
816+ Sets the node graph layout direction to horizontal or vertical.
817+ This function will also override the layout direction on all
818+ nodes in the current node graph.
819+
820+ `Implemented in` ``v0.3.0``
821+
822+ See Also:
823+ :meth:`NodeGraph.layout_direction`,
824+ :meth:`NodeObject.set_layout_direction`
825+
826+ Note:
827+ Node Graph Layout Types:
828+
829+ * :attr:`NodeGraphQt.constants.LayoutDirectionEnum.HORIZONTAL`
830+ * :attr:`NodeGraphQt.constants.LayoutDirectionEnum.VERTICAL`
831+
832+ Warnings:
833+ This function does not register to the undo stack.
834+
835+ Args:
836+ direction (int): layout direction.
837+ """
838+ direction_types = [e .value for e in LayoutDirectionEnum ]
839+ if direction not in direction_types :
840+ direction = LayoutDirectionEnum .HORIZONTAL .value
841+ self ._model .layout_direction = direction
842+ for node in self .all_nodes ():
843+ node .set_layout_direction (direction )
844+ self ._viewer .set_layout_direction (direction )
845+
789846 def fit_to_selection (self ):
790847 """
791848 Sets the zoom level to fit selected nodes.
@@ -853,7 +910,7 @@ def register_node(self, node, alias=None):
853910 Register the node to the :meth:`NodeGraph.node_factory`
854911
855912 Args:
856- node (_NodeGraphQt .NodeObject): node object.
913+ node (NodeGraphQt .NodeObject): node object.
857914 alias (str): custom alias name for the node type.
858915 """
859916 self ._node_factory .register_node (node , alias )
@@ -922,6 +979,9 @@ def format_color(clr):
922979 if pos :
923980 node .model .pos = [float (pos [0 ]), float (pos [1 ])]
924981
982+ # initial node direction layout.
983+ node .model .layout_direction = self .layout_direction ()
984+
925985 node .update ()
926986
927987 if push_undo :
@@ -964,6 +1024,11 @@ def add_node(self, node, pos=None, selected=True, push_undo=True):
9641024 node .NODE_NAME = self .get_unique_name (node .NODE_NAME )
9651025 node .model ._graph_model = self .model
9661026 node .model .name = node .NODE_NAME
1027+
1028+ # initial node direction layout.
1029+ node .model .layout_direction = self .layout_direction ()
1030+
1031+ # update method must be called before it's been added to the viewer.
9671032 node .update ()
9681033
9691034 if push_undo :
@@ -1672,7 +1737,9 @@ def auto_layout_nodes(self, nodes=None, down_stream=True, start_nodes=None):
16721737 else :
16731738 rank_map [rank ] = [node ]
16741739
1675- if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL :
1740+ node_layout_direction = self ._viewer .get_layout_direction ()
1741+
1742+ if node_layout_direction is LayoutDirectionEnum .HORIZONTAL .value :
16761743 current_x = 0
16771744 node_height = 120
16781745 for rank in sorted (range (len (rank_map )), reverse = not down_stream ):
@@ -1687,7 +1754,7 @@ def auto_layout_nodes(self, nodes=None, down_stream=True, start_nodes=None):
16871754 current_y += dy * 0.5 + 10
16881755
16891756 current_x += max_width * 0.5 + 100
1690- elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL :
1757+ elif node_layout_direction is LayoutDirectionEnum . VERTICAL . value :
16911758 current_y = 0
16921759 node_width = 250
16931760 for rank in sorted (range (len (rank_map )), reverse = not down_stream ):
@@ -1861,7 +1928,11 @@ def expand_group_node(self, node):
18611928
18621929 # build new sub graph.
18631930 node_factory = copy .deepcopy (self .node_factory )
1864- sub_graph = SubGraph (self , node = node , node_factory = node_factory )
1931+ layout_direction = self .layout_direction ()
1932+ sub_graph = SubGraph (self ,
1933+ node = node ,
1934+ node_factory = node_factory ,
1935+ layout_direction = layout_direction )
18651936
18661937 # populate the sub graph.
18671938 session = node .get_sub_graph_session ()
@@ -1913,14 +1984,17 @@ class SubGraph(NodeGraph):
19131984 -
19141985 """
19151986
1916- def __init__ (self , parent = None , node = None , node_factory = None ):
1987+ def __init__ (self , parent = None , node = None , node_factory = None , ** kwargs ):
19171988 """
19181989 Args:
19191990 parent (object): object parent.
19201991 node (GroupNode): group node related to this sub graph.
19211992 node_factory (NodeFactory): override node factory.
1993+ **kwargs (dict): additional kwargs.
19221994 """
1923- super (SubGraph , self ).__init__ (parent , node_factory = node_factory )
1995+ super (SubGraph , self ).__init__ (
1996+ parent , node_factory = node_factory , ** kwargs
1997+ )
19241998
19251999 # sub graph attributes.
19262000 self ._node = node
@@ -1953,6 +2027,8 @@ def _build_port_nodes(self):
19532027 Returns:
19542028 tuple(dict, dict): input nodes, output nodes.
19552029 """
2030+ node_layout_direction = self ._viewer .get_layout_direction ()
2031+
19562032 # build the parent input port nodes.
19572033 input_nodes = {n .name (): n for n in
19582034 self .get_nodes_by_type (PortInputNode .type_ )}
@@ -1965,9 +2041,9 @@ def _build_port_nodes(self):
19652041 input_nodes [port .name ()] = input_node
19662042 self .add_node (input_node , selected = False , push_undo = False )
19672043 x , y = input_node .pos ()
1968- if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL :
2044+ if node_layout_direction is LayoutDirectionEnum . HORIZONTAL . value :
19692045 x -= 100
1970- elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL :
2046+ elif node_layout_direction is LayoutDirectionEnum . VERTICAL . value :
19712047 y -= 100
19722048 input_node .set_property ('pos' , [x , y ], push_undo = False )
19732049
@@ -1983,9 +2059,9 @@ def _build_port_nodes(self):
19832059 output_nodes [port .name ()] = output_node
19842060 self .add_node (output_node , selected = False , push_undo = False )
19852061 x , y = output_node .pos ()
1986- if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL :
2062+ if node_layout_direction is LayoutDirectionEnum . HORIZONTAL . value :
19872063 x += 100
1988- elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL :
2064+ elif node_layout_direction is LayoutDirectionEnum . VERTICAL . value :
19892065 y += 100
19902066 output_node .set_property ('pos' , [x , y ], push_undo = False )
19912067
@@ -2282,7 +2358,10 @@ def expand_group_node(self, node):
22822358
22832359 # build new sub graph.
22842360 node_factory = copy .deepcopy (self .node_factory )
2285- sub_graph = SubGraph (self , node = node , node_factory = node_factory )
2361+ sub_graph = SubGraph (self ,
2362+ node = node ,
2363+ node_factory = node_factory ,
2364+ layout_direction = self .layout_direction ())
22862365
22872366 # populate the sub graph.
22882367 serialized_session = node .get_sub_graph_session ()
0 commit comments