Skip to content

Commit 1ed95ea

Browse files
committed
add sub graph serialize/deserialize
1 parent 8b78bbf commit 1ed95ea

File tree

5 files changed

+2342
-42
lines changed

5 files changed

+2342
-42
lines changed

NodeGraphQt/base/graph.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -791,13 +791,14 @@ def format_color(clr):
791791
return node
792792
raise Exception('\n\n>> Cannot find node:\t"{}"\n'.format(node_type))
793793

794-
def add_node(self, node, pos=None):
794+
def add_node(self, node, pos=None, unique_name=True):
795795
"""
796796
Add a node into the node graph.
797797
798798
Args:
799799
node (NodeGraphQt.BaseNode): node object.
800800
pos (list[float]): node x,y position. (optional)
801+
unique_name (bool): make node name unique
801802
"""
802803
assert isinstance(node, NodeObject), 'node must be a Node instance.'
803804

@@ -813,7 +814,8 @@ def add_node(self, node, pos=None):
813814
self.model.set_node_common_properties(node_attrs)
814815

815816
node.set_graph(self)
816-
node.NODE_NAME = self.get_unique_name(node.NODE_NAME)
817+
if unique_name:
818+
node.NODE_NAME = self.get_unique_name(node.NODE_NAME)
817819
node.model._graph_model = self.model
818820
node.model.name = node.NODE_NAME
819821
node.update()
@@ -1009,17 +1011,18 @@ def get_unique_name(self, name):
10091011
regex = re.compile('[\w ]+(?: )*(\d+)')
10101012
search = regex.search(name)
10111013
if not search:
1012-
for x in range(1, len(node_names) + 1):
1014+
for x in range(1, len(node_names) + 2):
10131015
new_name = '{} {}'.format(name, x)
10141016
if new_name not in node_names:
10151017
return new_name
10161018

10171019
version = search.group(1)
10181020
name = name[:len(version) * -1].strip()
1019-
for x in range(1, len(node_names) + 1):
1021+
for x in range(1, len(node_names) + 2):
10201022
new_name = '{} {}'.format(name, x)
10211023
if new_name not in node_names:
10221024
return new_name
1025+
return name + "_"
10231026

10241027
def current_session(self):
10251028
"""
@@ -1063,19 +1066,14 @@ def _serialize(self, nodes):
10631066
continue
10641067
# update the node model.
10651068
n.update_model()
1066-
1067-
nodes_data.update(n.model.to_dict)
1069+
node_dict = n.model.to_dict
10681070

10691071
if isinstance(n, SubGraph):
1070-
subgraph_node = n
1071-
while subgraph_node:
1072-
_subgraph_node = None
1073-
for _n in subgraph_node.children():
1074-
_n.update_model()
1075-
nodes_data.update(_n.model.to_dict)
1076-
if isinstance(_n, SubGraph):
1077-
_subgraph_node = _n
1078-
subgraph_node = _subgraph_node
1072+
children = n.children()
1073+
if children:
1074+
node_dict[n.model.id]['sub_graph'] = self._serialize(children)
1075+
1076+
nodes_data.update(node_dict)
10791077

10801078
for n_id, n_data in nodes_data.items():
10811079
serial_data['nodes'][n_id] = n_data
@@ -1104,27 +1102,29 @@ def _serialize(self, nodes):
11041102

11051103
return serial_data
11061104

1107-
def _deserialize(self, data, relative_pos=False, pos=None):
1105+
def _deserialize(self, data, relative_pos=False, pos=None, set_parent=True):
11081106
"""
11091107
deserialize node data.
11101108
(used internally by the node graph)
11111109
11121110
Args:
11131111
data (dict): node data.
11141112
relative_pos (bool): position node relative to the cursor.
1113+
set_parent (bool): set node parent to current node space.
11151114
11161115
Returns:
11171116
list[NodeGraphQt.Nodes]: list of node instances.
11181117
"""
11191118
nodes = {}
1120-
11211119
# build the nodes.
11221120
for n_id, n_data in data.get('nodes', {}).items():
11231121
identifier = n_data['type_']
11241122
NodeCls = self._node_factory.create_node_instance(identifier)
11251123
if NodeCls:
11261124
node = NodeCls()
11271125
node.NODE_NAME = n_data.get('name', node.NODE_NAME)
1126+
if 'parent_id' in n_data.keys():
1127+
n_data.pop('parent_id')
11281128
# set properties.
11291129
for prop in node.model.properties.keys():
11301130
if prop in n_data.keys():
@@ -1133,18 +1133,17 @@ def _deserialize(self, data, relative_pos=False, pos=None):
11331133
for prop, val in n_data.get('custom', {}).items():
11341134
node.model.set_property(prop, val)
11351135
nodes[n_id] = node
1136-
self.add_node(node, n_data.get('pos'))
1136+
self.add_node(node, n_data.get('pos'), unique_name=set_parent)
11371137
node.set_graph(self)
11381138

1139+
if isinstance(node, SubGraph):
1140+
sub_graph = n_data.get('sub_graph', None)
1141+
if sub_graph:
1142+
children = self._deserialize(sub_graph, relative_pos, pos, False)
1143+
[child.set_parent(node) for child in children]
1144+
11391145
if n_data.get('dynamic_port', None):
11401146
node.set_ports({'input_ports': n_data['input_ports'], 'output_ports': n_data['output_ports']})
1141-
node.model.parent_id = n_data.get('parent_id', None)
1142-
1143-
# set node parent
1144-
all_nodes = {}
1145-
all_nodes.update(nodes)
1146-
all_nodes.update(self._model.nodes)
1147-
[node.set_parent(all_nodes[node.parent_id]) for node in nodes.values() if node.parent_id is not None]
11481147

11491148
# build the connections.
11501149
for connection in data.get('connections', []):
@@ -1171,6 +1170,9 @@ def _deserialize(self, data, relative_pos=False, pos=None):
11711170
self._viewer.move_nodes([n.view for n in node_objs], pos=pos)
11721171
[setattr(n.model, 'pos', n.view.xy_pos) for n in node_objs]
11731172

1173+
if set_parent:
1174+
[node.set_parent(self._current_node_space) for node in node_objs]
1175+
11741176
return node_objs
11751177

11761178
def serialize_session(self):
@@ -1200,7 +1202,13 @@ def save_session(self, file_path):
12001202
Args:
12011203
file_path (str): path to the saved node layout.
12021204
"""
1203-
serialized_data = self._serialize(self.all_nodes())
1205+
root_node = self.root_node()
1206+
if root_node is not None:
1207+
nodes = root_node.children()
1208+
else:
1209+
nodes = self.all_nodes()
1210+
serialized_data = self._serialize(nodes)
1211+
12041212
node_space = self.get_node_space()
12051213
if node_space is not None:
12061214
node_space = node_space.id
@@ -1306,9 +1314,6 @@ def paste_nodes(self):
13061314
self.clear_selection()
13071315
nodes = self._deserialize(serial_data, relative_pos=True)
13081316
[n.set_selected(True) for n in nodes]
1309-
# set node parent
1310-
[n.set_parent(self._current_node_space) for n in nodes]
1311-
13121317
self._undo_stack.endMacro()
13131318

13141319
def duplicate_nodes(self, nodes):

example_auto_nodes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ def find_node(graph, node):
6060
print(graph.get_node_by_path(node.path()))
6161

6262

63+
def print_children(graph, node):
64+
children = node.children()
65+
print(len(children), children)
66+
67+
6368
if __name__ == '__main__':
6469
app = QtWidgets.QApplication([])
6570

@@ -97,6 +102,7 @@ def show_nodes_list(node):
97102
# setup node menu
98103
node_menu = graph.context_nodes_menu()
99104
node_menu.add_command('Enter Node', enter_node, node_class=SubGraphNode)
105+
node_menu.add_command('Print Children', print_children, node_class=SubGraphNode)
100106
node_menu.add_command('Print Functions', print_functions, node_class=ModuleNode)
101107
node_menu.add_command('Cook Node', cook_node, node_class=AutoNode)
102108
node_menu.add_command('Toggle Auto Cook', toggle_auto_cook, node_class=AutoNode)
@@ -108,7 +114,7 @@ def show_nodes_list(node):
108114

109115
# create test nodes
110116
graph.load_session(r'example_auto_nodes/networks/example_SubGraph.json')
111-
graph.get_node_by_path('/root/Input A').cook()
117+
graph.get_node_by_path('/root/Vector1').cook()
112118

113119
# widget used for the node graph.
114120
graph_widget = graph.widget

0 commit comments

Comments
 (0)