Skip to content

Commit ca39c15

Browse files
authored
Merge pull request #339 from jchanvfx/pipe_connection_restriction_#285
Pipe connection restriction #285
2 parents 4c0632b + e937cc8 commit ca39c15

File tree

9 files changed

+696
-36
lines changed

9 files changed

+696
-36
lines changed

NodeGraphQt/base/graph.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,11 @@ def __init__(self, parent=None, **kwargs):
160160
kwargs.get('viewer') or NodeViewer(undo_stack=self._undo_stack))
161161
self._viewer.set_layout_direction(layout_direction)
162162

163+
# viewer needs a reference to the model port connection constrains
164+
# for the user interaction with the live pipe.
165+
self._viewer.accept_connection_types = self._model.accept_connection_types
166+
self._viewer.reject_connection_types = self._model.reject_connection_types
167+
163168
self._context_menu = {}
164169

165170
self._register_context_menu()
@@ -1143,6 +1148,39 @@ def create_node(self, node_type, name=None, selected=True, color=None,
11431148
node_attrs[node.type_][pname].update(pattrs)
11441149
self.model.set_node_common_properties(node_attrs)
11451150

1151+
accept_types = node.model.__dict__.pop(
1152+
'_TEMP_accept_connection_types'
1153+
)
1154+
for ptype, pdata in accept_types.get(node.type_, {}).items():
1155+
for pname, accept_data in pdata.items():
1156+
for accept_ntype, accept_ndata in accept_data.items():
1157+
for accept_ptype, accept_pnames in accept_ndata.items():
1158+
for accept_pname in accept_pnames:
1159+
self._model.add_port_accept_connection_type(
1160+
port_name=pname,
1161+
port_type=ptype,
1162+
node_type=node.type_,
1163+
accept_pname=accept_pname,
1164+
accept_ptype=accept_ptype,
1165+
accept_ntype=accept_ntype
1166+
)
1167+
reject_types = node.model.__dict__.pop(
1168+
'_TEMP_reject_connection_types'
1169+
)
1170+
for ptype, pdata in reject_types.get(node.type_, {}).items():
1171+
for pname, reject_data in pdata.items():
1172+
for reject_ntype, reject_ndata in reject_data.items():
1173+
for reject_ptype, reject_pnames in reject_ndata.items():
1174+
for reject_pname in reject_pnames:
1175+
self._model.add_port_reject_connection_type(
1176+
port_name=pname,
1177+
port_type=ptype,
1178+
node_type=node.type_,
1179+
reject_pname=reject_pname,
1180+
reject_ptype=reject_ptype,
1181+
reject_ntype=reject_ntype
1182+
)
1183+
11461184
node.NODE_NAME = self.get_unique_name(name or node.NODE_NAME)
11471185
node.model.name = node.NODE_NAME
11481186
node.model.selected = selected
@@ -1207,6 +1245,39 @@ def add_node(self, node, pos=None, selected=True, push_undo=True):
12071245
node_attrs[node.type_][pname].update(pattrs)
12081246
self.model.set_node_common_properties(node_attrs)
12091247

1248+
accept_types = node.model.__dict__.pop(
1249+
'_TEMP_accept_connection_types'
1250+
)
1251+
for ptype, pdata in accept_types.get(node.type_, {}).items():
1252+
for pname, accept_data in pdata.items():
1253+
for accept_ntype, accept_ndata in accept_data.items():
1254+
for accept_ptype, accept_pnames in accept_ndata.items():
1255+
for accept_pname in accept_pnames:
1256+
self._model.add_port_accept_connection_type(
1257+
port_name=pname,
1258+
port_type=ptype,
1259+
node_type=node.type_,
1260+
accept_pname=accept_pname,
1261+
accept_ptype=accept_ptype,
1262+
accept_ntype=accept_ntype
1263+
)
1264+
reject_types = node.model.__dict__.pop(
1265+
'_TEMP_reject_connection_types'
1266+
)
1267+
for ptype, pdata in reject_types.get(node.type_, {}).items():
1268+
for pname, reject_data in pdata.items():
1269+
for reject_ntype, reject_ndata in reject_data.items():
1270+
for reject_ptype, reject_pnames in reject_ndata.items():
1271+
for reject_pname in reject_pnames:
1272+
self._model.add_port_reject_connection_type(
1273+
port_name=pname,
1274+
port_type=ptype,
1275+
node_type=node.type_,
1276+
reject_pname=reject_pname,
1277+
reject_ptype=reject_ptype,
1278+
reject_ntype=reject_ntype
1279+
)
1280+
12101281
node._graph = self
12111282
node.NODE_NAME = self.get_unique_name(node.NODE_NAME)
12121283
node.model._graph_model = self.model
@@ -1554,6 +1625,10 @@ def _serialize(self, nodes):
15541625
serial_data['graph']['pipe_collision'] = self.pipe_collision()
15551626
serial_data['graph']['pipe_slicing'] = self.pipe_slicing()
15561627

1628+
# connection constrains.
1629+
serial_data['graph']['accept_connection_types'] = self.model.accept_connection_types
1630+
serial_data['graph']['reject_connection_types'] = self.model.reject_connection_types
1631+
15571632
# serialize nodes.
15581633
for n in nodes:
15591634
# update the node model.
@@ -1618,6 +1693,12 @@ def _deserialize(self, data, relative_pos=False, pos=None):
16181693
elif attr_name == 'pipe_slicing':
16191694
self.set_pipe_slicing(attr_value)
16201695

1696+
# connection constrains.
1697+
elif attr_name == 'accept_connection_types':
1698+
self.model.accept_connection_types = attr_value
1699+
elif attr_name == 'reject_connection_types':
1700+
self.model.reject_connection_types = attr_value
1701+
16211702
# build the nodes.
16221703
nodes = {}
16231704
for n_id, n_data in data.get('nodes', {}).items():

NodeGraphQt/base/model.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ def __init__(self):
107107
'outputs': NodePropWidgetEnum.HIDDEN.value,
108108
}
109109

110+
# temp store connection constrains.
111+
# (deleted when node is added to the graph)
112+
self._TEMP_accept_connection_types = {}
113+
self._TEMP_reject_connection_types = {}
114+
110115
def __repr__(self):
111116
return '<{}(\'{}\') object at {}>'.format(
112117
self.__class__.__name__, self.name, self.id)
@@ -223,6 +228,85 @@ def get_tab_name(self, name):
223228
return
224229
return model.get_node_common_properties(self.type_)[name]['tab']
225230

231+
def add_port_accept_connection_type(
232+
self,
233+
port_name, port_type, node_type,
234+
accept_pname, accept_ptype, accept_ntype
235+
):
236+
"""
237+
Convenience function for adding to the "accept_connection_types" dict.
238+
If the node graph model is unavailable yet then we store it to a
239+
temp var that gets deleted.
240+
241+
Args:
242+
port_name (str): current port name.
243+
port_type (str): current port type.
244+
node_type (str): current port node type.
245+
accept_pname (str):port name to accept.
246+
accept_ptype (str): port type accept.
247+
accept_ntype (str):port node type to accept.
248+
"""
249+
model = self._graph_model
250+
if model:
251+
model.add_port_accept_connection_type(
252+
port_name, port_type, node_type,
253+
accept_pname, accept_ptype, accept_ntype
254+
)
255+
return
256+
257+
connection_data = self._TEMP_accept_connection_types
258+
keys = [node_type, port_type, port_name, accept_ntype]
259+
for key in keys:
260+
if key not in connection_data.keys():
261+
connection_data[key] = {}
262+
connection_data = connection_data[key]
263+
264+
if accept_ptype not in connection_data:
265+
connection_data[accept_ptype] = set([accept_pname])
266+
else:
267+
connection_data[accept_ptype].add(accept_pname)
268+
269+
def add_port_reject_connection_type(
270+
self,
271+
port_name, port_type, node_type,
272+
reject_pname, reject_ptype, reject_ntype
273+
):
274+
"""
275+
Convenience function for adding to the "reject_connection_types" dict.
276+
If the node graph model is unavailable yet then we store it to a
277+
temp var that gets deleted.
278+
279+
Args:
280+
port_name (str): current port name.
281+
port_type (str): current port type.
282+
node_type (str): current port node type.
283+
reject_pname:
284+
reject_ptype:
285+
reject_ntype:
286+
287+
Returns:
288+
289+
"""
290+
model = self._graph_model
291+
if model:
292+
model.add_port_reject_connection_type(
293+
port_name, port_type, node_type,
294+
reject_pname, reject_ptype, reject_ntype
295+
)
296+
return
297+
298+
connection_data = self._TEMP_reject_connection_types
299+
keys = [node_type, port_type, port_name, reject_ntype]
300+
for key in keys:
301+
if key not in connection_data.keys():
302+
connection_data[key] = {}
303+
connection_data = connection_data[key]
304+
305+
if reject_ptype not in connection_data:
306+
connection_data[reject_ptype] = set([reject_pname])
307+
else:
308+
connection_data[reject_ptype].add(reject_pname)
309+
226310
@property
227311
def properties(self):
228312
"""
@@ -352,6 +436,9 @@ class NodeGraphModel(object):
352436
def __init__(self):
353437
self.__common_node_props = {}
354438

439+
self.accept_connection_types = {}
440+
self.reject_connection_types = {}
441+
355442
self.nodes = {}
356443
self.session = ''
357444
self.acyclic = True
@@ -421,6 +508,96 @@ def get_node_common_properties(self, node_type):
421508
"""
422509
return self.__common_node_props.get(node_type)
423510

511+
def add_port_accept_connection_type(
512+
self,
513+
port_name, port_type, node_type,
514+
accept_pname, accept_ptype, accept_ntype
515+
):
516+
"""
517+
Convenience function for adding to the "accept_connection_types" dict.
518+
519+
Args:
520+
port_name (str): current port name.
521+
port_type (str): current port type.
522+
node_type (str): current port node type.
523+
accept_pname (str):port name to accept.
524+
accept_ptype (str): port type accept.
525+
accept_ntype (str):port node type to accept.
526+
"""
527+
connection_data = self.accept_connection_types
528+
keys = [node_type, port_type, port_name, accept_ntype]
529+
for key in keys:
530+
if key not in connection_data.keys():
531+
connection_data[key] = {}
532+
connection_data = connection_data[key]
533+
534+
if accept_ptype not in connection_data:
535+
connection_data[accept_ptype] = set([accept_pname])
536+
else:
537+
connection_data[accept_ptype].add(accept_pname)
538+
539+
def port_accept_connection_types(self, node_type, port_type, port_name):
540+
"""
541+
Convenience function for getting the accepted port types from the
542+
"accept_connection_types" dict.
543+
544+
Args:
545+
node_type (str):
546+
port_type (str):
547+
port_name (str):
548+
549+
Returns:
550+
dict: {<node_type>: {<port_type>: [<port_name>]}}
551+
"""
552+
data = self.accept_connection_types.get(node_type) or {}
553+
accepted_types = data.get(port_type) or {}
554+
return accepted_types.get(port_name) or {}
555+
556+
def add_port_reject_connection_type(
557+
self,
558+
port_name, port_type, node_type,
559+
reject_pname, reject_ptype, reject_ntype
560+
):
561+
"""
562+
Convenience function for adding to the "reject_connection_types" dict.
563+
564+
Args:
565+
port_name (str): current port name.
566+
port_type (str): current port type.
567+
node_type (str): current port node type.
568+
reject_pname (str): port name to reject.
569+
reject_ptype (str): port type to reject.
570+
reject_ntype (str): port node type to reject.
571+
"""
572+
connection_data = self.reject_connection_types
573+
keys = [node_type, port_type, port_name, reject_ntype]
574+
for key in keys:
575+
if key not in connection_data.keys():
576+
connection_data[key] = {}
577+
connection_data = connection_data[key]
578+
579+
if reject_ptype not in connection_data:
580+
connection_data[reject_ptype] = set([reject_pname])
581+
else:
582+
connection_data[reject_ptype].add(reject_pname)
583+
584+
def port_reject_connection_types(self, node_type, port_type, port_name):
585+
"""
586+
Convenience function for getting the accepted port types from the
587+
"reject_connection_types" dict.
588+
589+
Args:
590+
node_type (str):
591+
port_type (str):
592+
port_name (str):
593+
594+
Returns:
595+
dict: {<node_type>: {<port_type>: [<port_name>]}}
596+
"""
597+
data = self.reject_connection_types.get(node_type) or {}
598+
rejected_types = data.get(port_type) or {}
599+
return rejected_types.get(port_name) or {}
600+
424601

425602
if __name__ == '__main__':
426603
p = PortModel(None)

0 commit comments

Comments
 (0)