@@ -107,6 +107,11 @@ def __init__(self):
107107 'outputs' : NodePropWidgetEnum .HIDDEN .value ,
108108 }
109109
110+ # temp store connection constrains.
111+ # (deleted when node is added to the graph)
112+ self ._TEMP_accept_connection_types = {}
113+ self ._TEMP_reject_connection_types = {}
114+
110115 def __repr__ (self ):
111116 return '<{}(\' {}\' ) object at {}>' .format (
112117 self .__class__ .__name__ , self .name , self .id )
@@ -223,6 +228,85 @@ def get_tab_name(self, name):
223228 return
224229 return model .get_node_common_properties (self .type_ )[name ]['tab' ]
225230
231+ def add_port_accept_connection_type (
232+ self ,
233+ port_name , port_type , node_type ,
234+ accept_pname , accept_ptype , accept_ntype
235+ ):
236+ """
237+ Convenience function for adding to the "accept_connection_types" dict.
238+ If the node graph model is unavailable yet then we store it to a
239+ temp var that gets deleted.
240+
241+ Args:
242+ port_name (str): current port name.
243+ port_type (str): current port type.
244+ node_type (str): current port node type.
245+ accept_pname (str):port name to accept.
246+ accept_ptype (str): port type accept.
247+ accept_ntype (str):port node type to accept.
248+ """
249+ model = self ._graph_model
250+ if model :
251+ model .add_port_accept_connection_type (
252+ port_name , port_type , node_type ,
253+ accept_pname , accept_ptype , accept_ntype
254+ )
255+ return
256+
257+ connection_data = self ._TEMP_accept_connection_types
258+ keys = [node_type , port_type , port_name , accept_ntype ]
259+ for key in keys :
260+ if key not in connection_data .keys ():
261+ connection_data [key ] = {}
262+ connection_data = connection_data [key ]
263+
264+ if accept_ptype not in connection_data :
265+ connection_data [accept_ptype ] = set ([accept_pname ])
266+ else :
267+ connection_data [accept_ptype ].add (accept_pname )
268+
269+ def add_port_reject_connection_type (
270+ self ,
271+ port_name , port_type , node_type ,
272+ reject_pname , reject_ptype , reject_ntype
273+ ):
274+ """
275+ Convenience function for adding to the "reject_connection_types" dict.
276+ If the node graph model is unavailable yet then we store it to a
277+ temp var that gets deleted.
278+
279+ Args:
280+ port_name (str): current port name.
281+ port_type (str): current port type.
282+ node_type (str): current port node type.
283+ reject_pname:
284+ reject_ptype:
285+ reject_ntype:
286+
287+ Returns:
288+
289+ """
290+ model = self ._graph_model
291+ if model :
292+ model .add_port_reject_connection_type (
293+ port_name , port_type , node_type ,
294+ reject_pname , reject_ptype , reject_ntype
295+ )
296+ return
297+
298+ connection_data = self ._TEMP_reject_connection_types
299+ keys = [node_type , port_type , port_name , reject_ntype ]
300+ for key in keys :
301+ if key not in connection_data .keys ():
302+ connection_data [key ] = {}
303+ connection_data = connection_data [key ]
304+
305+ if reject_ptype not in connection_data :
306+ connection_data [reject_ptype ] = set ([reject_pname ])
307+ else :
308+ connection_data [reject_ptype ].add (reject_pname )
309+
226310 @property
227311 def properties (self ):
228312 """
@@ -352,6 +436,9 @@ class NodeGraphModel(object):
352436 def __init__ (self ):
353437 self .__common_node_props = {}
354438
439+ self .accept_connection_types = {}
440+ self .reject_connection_types = {}
441+
355442 self .nodes = {}
356443 self .session = ''
357444 self .acyclic = True
@@ -421,6 +508,96 @@ def get_node_common_properties(self, node_type):
421508 """
422509 return self .__common_node_props .get (node_type )
423510
511+ def add_port_accept_connection_type (
512+ self ,
513+ port_name , port_type , node_type ,
514+ accept_pname , accept_ptype , accept_ntype
515+ ):
516+ """
517+ Convenience function for adding to the "accept_connection_types" dict.
518+
519+ Args:
520+ port_name (str): current port name.
521+ port_type (str): current port type.
522+ node_type (str): current port node type.
523+ accept_pname (str):port name to accept.
524+ accept_ptype (str): port type accept.
525+ accept_ntype (str):port node type to accept.
526+ """
527+ connection_data = self .accept_connection_types
528+ keys = [node_type , port_type , port_name , accept_ntype ]
529+ for key in keys :
530+ if key not in connection_data .keys ():
531+ connection_data [key ] = {}
532+ connection_data = connection_data [key ]
533+
534+ if accept_ptype not in connection_data :
535+ connection_data [accept_ptype ] = set ([accept_pname ])
536+ else :
537+ connection_data [accept_ptype ].add (accept_pname )
538+
539+ def port_accept_connection_types (self , node_type , port_type , port_name ):
540+ """
541+ Convenience function for getting the accepted port types from the
542+ "accept_connection_types" dict.
543+
544+ Args:
545+ node_type (str):
546+ port_type (str):
547+ port_name (str):
548+
549+ Returns:
550+ dict: {<node_type>: {<port_type>: [<port_name>]}}
551+ """
552+ data = self .accept_connection_types .get (node_type ) or {}
553+ accepted_types = data .get (port_type ) or {}
554+ return accepted_types .get (port_name ) or {}
555+
556+ def add_port_reject_connection_type (
557+ self ,
558+ port_name , port_type , node_type ,
559+ reject_pname , reject_ptype , reject_ntype
560+ ):
561+ """
562+ Convenience function for adding to the "reject_connection_types" dict.
563+
564+ Args:
565+ port_name (str): current port name.
566+ port_type (str): current port type.
567+ node_type (str): current port node type.
568+ reject_pname (str): port name to reject.
569+ reject_ptype (str): port type to reject.
570+ reject_ntype (str): port node type to reject.
571+ """
572+ connection_data = self .reject_connection_types
573+ keys = [node_type , port_type , port_name , reject_ntype ]
574+ for key in keys :
575+ if key not in connection_data .keys ():
576+ connection_data [key ] = {}
577+ connection_data = connection_data [key ]
578+
579+ if reject_ptype not in connection_data :
580+ connection_data [reject_ptype ] = set ([reject_pname ])
581+ else :
582+ connection_data [reject_ptype ].add (reject_pname )
583+
584+ def port_reject_connection_types (self , node_type , port_type , port_name ):
585+ """
586+ Convenience function for getting the accepted port types from the
587+ "reject_connection_types" dict.
588+
589+ Args:
590+ node_type (str):
591+ port_type (str):
592+ port_name (str):
593+
594+ Returns:
595+ dict: {<node_type>: {<port_type>: [<port_name>]}}
596+ """
597+ data = self .reject_connection_types .get (node_type ) or {}
598+ rejected_types = data .get (port_type ) or {}
599+ return rejected_types .get (port_name ) or {}
600+
424601
425602if __name__ == '__main__' :
426603 p = PortModel (None )
0 commit comments