Skip to content

Commit 8b78bbf

Browse files
committed
update sub graph logic
1 parent 4d18d59 commit 8b78bbf

File tree

15 files changed

+482
-294
lines changed

15 files changed

+482
-294
lines changed

NodeGraphQt/base/graph.py

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def __init__(self, parent=None):
145145
tab = QtWidgets.QShortcut(QtGui.QKeySequence(QtCore.Qt.Key_Tab), self._viewer)
146146
tab.activated.connect(self._toggle_tab_search)
147147
self._viewer.need_show_tab_search.connect(self._toggle_tab_search)
148-
148+
149149
self._wire_signals()
150150
self._node_space_bar = node_space_bar(self)
151151

@@ -370,7 +370,9 @@ def widget(self):
370370

371371
layout = QtWidgets.QVBoxLayout(self._widget)
372372
layout.setContentsMargins(0, 0, 0, 0)
373-
layout.addWidget(self._node_space_bar)
373+
layout.setSpacing(0)
374+
if self.root_node() is not None:
375+
layout.addWidget(self._node_space_bar)
374376
layout.addWidget(self._viewer)
375377
return self._widget
376378

@@ -852,6 +854,8 @@ def delete_node(self, node):
852854
"""
853855
assert isinstance(node, NodeObject), \
854856
'node must be a instance of a NodeObject.'
857+
if node is self.root_node():
858+
return
855859
self.nodes_deleted.emit([node.id])
856860
if isinstance(node, SubGraph):
857861
self._undo_stack.beginMacro('delete sub graph')
@@ -868,10 +872,11 @@ def delete_nodes(self, nodes):
868872
Args:
869873
nodes (list[NodeGraphQt.BaseNode]): list of node instances.
870874
"""
875+
root_node = self.root_node()
871876
self.nodes_deleted.emit([n.id for n in nodes])
872877
self._undo_stack.beginMacro('delete nodes')
873878
[self.delete_nodes(n.children()) for n in nodes if isinstance(n, SubGraph)]
874-
[self._undo_stack.push(NodeRemovedCmd(self, n)) for n in nodes]
879+
[self._undo_stack.push(NodeRemovedCmd(self, n)) for n in nodes if n is not root_node]
875880
self._undo_stack.endMacro()
876881

877882
def delete_pipe(self, pipe):
@@ -948,18 +953,11 @@ def get_node_by_path(self, node_path):
948953
NodeGraphQt.NodeObject: node object.
949954
"""
950955
names = [name for name in node_path.split("/") if name]
951-
root = names.pop(0)
952-
node = self._current_node_space
956+
names.pop(0)
957+
958+
node = self.root_node()
953959
if node is None:
954-
node = self.get_node_by_name(root)
955-
if node is None:
956-
return None
957-
else:
958-
while True:
959-
parent_node = node.parent()
960-
if parent_node is None:
961-
break
962-
node = parent_node
960+
return None
963961

964962
for name in names:
965963
find = False
@@ -1036,8 +1034,12 @@ def clear_session(self):
10361034
"""
10371035
Clears the current node graph session.
10381036
"""
1037+
root_node = self.root_node()
10391038
for n in self.all_nodes():
1039+
if n is root_node:
1040+
continue
10401041
self._undo_stack.push(NodeRemovedCmd(self, n))
1042+
self.set_node_space(root_node)
10411043
self._undo_stack.clear()
10421044
self._model.session = None
10431045
self.session_changed.emit("")
@@ -1055,12 +1057,26 @@ def _serialize(self, nodes):
10551057
"""
10561058
serial_data = {'nodes': {}, 'connections': []}
10571059
nodes_data = {}
1060+
root_node = self.root_node()
10581061
for n in nodes:
1062+
if n is root_node:
1063+
continue
10591064
# update the node model.
10601065
n.update_model()
10611066

10621067
nodes_data.update(n.model.to_dict)
10631068

1069+
if isinstance(n, SubGraph):
1070+
subgraph_node = n
1071+
while subgraph_node:
1072+
_subgraph_node = None
1073+
for _n in subgraph_node.children():
1074+
_n.update_model()
1075+
nodes_data.update(_n.model.to_dict)
1076+
if isinstance(_n, SubGraph):
1077+
_subgraph_node = _n
1078+
subgraph_node = _subgraph_node
1079+
10641080
for n_id, n_data in nodes_data.items():
10651081
serial_data['nodes'][n_id] = n_data
10661082

@@ -1116,13 +1132,19 @@ def _deserialize(self, data, relative_pos=False, pos=None):
11161132
# set custom properties.
11171133
for prop, val in n_data.get('custom', {}).items():
11181134
node.model.set_property(prop, val)
1119-
11201135
nodes[n_id] = node
11211136
self.add_node(node, n_data.get('pos'))
11221137
node.set_graph(self)
11231138

1139+
if n_data.get('dynamic_port', None):
1140+
node.set_ports({'input_ports': n_data['input_ports'], 'output_ports': n_data['output_ports']})
1141+
node.model.parent_id = n_data.get('parent_id', None)
1142+
11241143
# set node parent
1125-
[node.set_parent_id(node.parent_id) for node in nodes.values()]
1144+
all_nodes = {}
1145+
all_nodes.update(nodes)
1146+
all_nodes.update(self._model.nodes)
1147+
[node.set_parent(all_nodes[node.parent_id]) for node in nodes.values() if node.parent_id is not None]
11261148

11271149
# build the connections.
11281150
for connection in data.get('connections', []):
@@ -1178,18 +1200,19 @@ def save_session(self, file_path):
11781200
Args:
11791201
file_path (str): path to the saved node layout.
11801202
"""
1181-
serliazed_data = self._serialize(self.all_nodes())
1203+
serialized_data = self._serialize(self.all_nodes())
11821204
node_space = self.get_node_space()
11831205
if node_space is not None:
11841206
node_space = node_space.id
1185-
serliazed_data['graph'] = {'node_space': node_space, 'pipe_layout': self._viewer.get_pipe_layout()}
1186-
serliazed_data['graph']['graph_rect'] = self._viewer.scene_rect()
1207+
serialized_data['graph'] = {'node_space': node_space, 'pipe_layout': self._viewer.get_pipe_layout()}
1208+
serialized_data['graph']['graph_rect'] = self._viewer.scene_rect()
11871209
file_path = file_path.strip()
11881210
with open(file_path, 'w') as file_out:
1189-
json.dump(serliazed_data, file_out, indent=2, separators=(',', ':'))
1211+
json.dump(serialized_data, file_out, indent=2, separators=(',', ':'))
11901212

11911213
self._model.session = file_path
11921214
self.session_changed.emit(file_path)
1215+
self._viewer.clear_key_state()
11931216

11941217
def load_session(self, file_path):
11951218
"""
@@ -1226,14 +1249,16 @@ def import_session(self, file_path):
12261249
self._deserialize(layout_data)
12271250

12281251
if 'graph' in layout_data.keys():
1229-
node_space_id = layout_data['graph']['node_space']
1230-
1252+
# node_space_id = layout_data['graph']['node_space']
1253+
12311254
# deserialize graph data
1232-
self.set_node_space(self.get_node_by_id(node_space_id))
1255+
# self.set_node_space(self.get_node_by_id(node_space_id))
1256+
self.set_node_space(self.root_node())
12331257
self._viewer.set_pipe_layout(layout_data['graph']['pipe_layout'])
12341258

12351259
self._viewer.set_scene_rect(layout_data['graph']['graph_rect'])
12361260

1261+
self.set_node_space(self.root_node())
12371262
self._undo_stack.clear()
12381263
self._model.session = file_path
12391264
self.session_changed.emit(file_path)
@@ -1424,4 +1449,13 @@ def set_graph_rect(self, rect):
14241449
Args:
14251450
rect (list): [x, y, width, height].
14261451
"""
1427-
self._viewer.set_scene_rect(rect)
1452+
self._viewer.set_scene_rect(rect)
1453+
1454+
def root_node(self):
1455+
"""
1456+
Get the graph root node.
1457+
1458+
Returns:
1459+
node (BaseNode): node object.
1460+
"""
1461+
return self.get_node_by_id('0' * 13)

NodeGraphQt/base/model.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(self, node):
2020
self.multi_connection = False
2121
self.visible = True
2222
self.connected_ports = defaultdict(list)
23+
self.data_type = 'None'
2324

2425
def __repr__(self):
2526
return '<{}(\'{}\') @ {}>'.format(
@@ -60,6 +61,7 @@ def __init__(self):
6061
self.selected = False
6162
self.visible = True
6263
self.parent_id = None
64+
self.dynamic_port = False
6365
self.width = 100.0
6466
self.height = 80.0
6567
self.pos = [0.0, 0.0]
@@ -239,12 +241,16 @@ def to_dict(self):
239241
input_ports = []
240242
output_ports = []
241243
for name, model in node_dict.pop('inputs').items():
242-
input_ports.append(name)
244+
if self.dynamic_port:
245+
input_ports.append({'name': name, 'multi_connection': model.multi_connection,
246+
'display_name': model.display_name, 'data_type': model.data_type})
243247
connected_ports = model.to_dict['connected_ports']
244248
if connected_ports:
245249
inputs[name] = connected_ports
246250
for name, model in node_dict.pop('outputs').items():
247-
output_ports.append(name)
251+
if self.dynamic_port:
252+
output_ports.append({'name': name, 'multi_connection': model.multi_connection,
253+
'display_name': model.display_name, 'data_type': model.data_type})
248254
connected_ports = model.to_dict['connected_ports']
249255
if connected_ports:
250256
outputs[name] = connected_ports
@@ -253,8 +259,9 @@ def to_dict(self):
253259
if outputs:
254260
node_dict['outputs'] = outputs
255261

256-
node_dict['input_ports'] = input_ports
257-
node_dict['output_ports'] = output_ports
262+
if self.dynamic_port:
263+
node_dict['input_ports'] = input_ports
264+
node_dict['output_ports'] = output_ports
258265

259266
custom_props = node_dict.pop('_custom_prop', {})
260267

0 commit comments

Comments
 (0)