2121"""
2222
2323from itertools import product
24- from typing import Union , Sequence
24+ from typing import Union , Sequence , Set
2525
2626from brainpy .nn import graph_flow
2727from brainpy .nn .base import Node , Network , FrozenNetwork
28+ from brainpy .nn .datatypes import SingleData
2829from brainpy .nn .nodes .base import Select , Concat
2930from brainpy .types import Tensor
3031
@@ -48,8 +49,8 @@ def _retrieve_nodes_and_edges(senders: Union[Node, Sequence[Node]],
4849
4950 # check receivers
5051 if isinstance (receivers , (tuple , list )):
51- raise ValueError ('Cannot concatenate a list/tuple of receivers. '
52- 'Please use set to wrap multiple receivers instead.' )
52+ raise TypeError ('Cannot concatenate a list/tuple of receivers. '
53+ 'Please use set to wrap multiple receivers instead.' )
5354 elif isinstance (receivers , set ):
5455 receivers = list (receivers )
5556 elif isinstance (receivers , Node ):
@@ -105,6 +106,74 @@ def _retrieve_nodes_and_edges(senders: Union[Node, Sequence[Node]],
105106 return all_nodes , all_ff_edges , all_fb_edges , all_senders , all_receivers
106107
107108
109+ def _reorganize_many2one (ff_edges , fb_edges ):
110+ """Reorganize the many-to-one connections.
111+
112+ If some node whose "data_type" is :py:class:`brainpy.nn.datatypes.SingleData` receives
113+ multiple feedforward or feedback connections, we should concatenate all feedforward
114+ inputs (or feedback inputs) into one instance of :py:class:`brainpy.nn.Concat`, then
115+ the new Concat instance feeds into this node.
116+
117+ """
118+ from brainpy .nn .nodes .base import Concat
119+
120+ new_nodes = []
121+
122+ # find parents according to the child
123+ ff_senders = dict ()
124+ for edge in ff_edges :
125+ sender , receiver = edge
126+ if receiver not in ff_senders :
127+ ff_senders [receiver ] = [sender ]
128+ else :
129+ ff_senders [receiver ].append (sender )
130+ for receiver , senders in ff_senders .items ():
131+ if isinstance (receiver .data_pass , SingleData ):
132+ if len (senders ) > 1 :
133+ concat_nodes = [node for node in senders if isinstance (node , Concat )]
134+ if len (concat_nodes ) == 1 :
135+ concat = concat_nodes [0 ]
136+ for sender in senders :
137+ if sender != concat :
138+ ff_edges .remove ((sender , receiver ))
139+ ff_edges .add ((sender , concat ))
140+ else :
141+ concat = Concat ()
142+ for sender in senders :
143+ ff_edges .remove ((sender , receiver ))
144+ ff_edges .add ((sender , concat ))
145+ ff_edges .add ((concat , receiver ))
146+ new_nodes .append (concat )
147+
148+ # find parents according to the child
149+ fb_senders = dict ()
150+ for edge in fb_edges :
151+ sender , receiver = edge
152+ if receiver not in fb_senders :
153+ fb_senders [receiver ] = [sender ]
154+ else :
155+ fb_senders [receiver ].append (sender )
156+ for receiver , senders in fb_senders .items ():
157+ if isinstance (receiver .data_pass , SingleData ):
158+ if len (senders ) > 1 :
159+ concat_nodes = [node for node in senders if isinstance (node , Concat )]
160+ if len (concat_nodes ) == 1 :
161+ concat = concat_nodes [0 ]
162+ for sender in senders :
163+ if sender != concat :
164+ fb_edges .remove ((sender , receiver ))
165+ ff_edges .add ((sender , concat ))
166+ else :
167+ concat = Concat ()
168+ for sender in senders :
169+ fb_edges .remove ((sender , receiver ))
170+ ff_edges .add ((sender , concat ))
171+ fb_edges .add ((concat , receiver ))
172+ new_nodes .append (concat )
173+
174+ return new_nodes , ff_edges , fb_edges
175+
176+
108177def merge (
109178 node : Node ,
110179 * other_nodes : Node ,
@@ -170,6 +239,10 @@ def merge(
170239 elif isinstance (n , Node ):
171240 all_nodes .add (n )
172241
242+ # reorganize
243+ new_nodes , all_ff_edges , all_fb_edges = _reorganize_many2one (all_ff_edges , all_fb_edges )
244+ all_nodes .update (new_nodes )
245+
173246 # detect cycles in the graph flow
174247 all_nodes = tuple (all_nodes )
175248 all_ff_edges = tuple (all_ff_edges )
@@ -198,8 +271,8 @@ def merge(
198271
199272
200273def ff_connect (
201- senders : Union [Node , Sequence [Node ]],
202- receivers : Union [Node , Sequence [Node ]],
274+ senders : Union [Node , Sequence [Node ], Set [ Node ] ],
275+ receivers : Union [Node , Set [Node ]],
203276 inplace : bool = False ,
204277 name : str = None ,
205278 need_detect_cycle = True
@@ -246,7 +319,7 @@ def ff_connect(
246319
247320 - In the case of "one-to-many" feedforward connection, `node2` only support
248321 a set of node. Using list or tuple to wrap multiple receivers will concatenate
249- all nodes in the receiver end. This will cause errors.
322+ all nodes in the receiver end. This will cause errors::
250323
251324 # wrong operation of one-to-many
252325 network = node_in >> {node1, node2, ..., node_N}
@@ -296,6 +369,10 @@ def ff_connect(
296369 # all inputs from subgraph 2.
297370 all_ff_edges |= new_ff_edges
298371
372+ # reorganize
373+ new_nodes , all_ff_edges , all_fb_edges = _reorganize_many2one (all_ff_edges , all_fb_edges )
374+ all_nodes .update (new_nodes )
375+
299376 # detect cycles in the graph flow
300377 all_nodes = tuple (all_nodes )
301378 all_ff_edges = tuple (all_ff_edges )
@@ -326,8 +403,8 @@ def ff_connect(
326403
327404
328405def fb_connect (
329- senders : Union [Node , Sequence [Node ]],
330- receivers : Union [Node , Sequence [Node ]],
406+ senders : Union [Node , Sequence [Node ], Set [ Node ] ],
407+ receivers : Union [Node , Set [Node ]],
331408 inplace : bool = False ,
332409 name : str = None ,
333410 need_detect_cycle = True
@@ -380,10 +457,10 @@ def fb_connect(
380457 f'support feedback connections.' )
381458
382459 # detect feedforward cycle
383- all_nodes = tuple (all_nodes )
384- all_ff_edges = tuple (all_ff_edges )
385460 if need_detect_cycle :
386- if graph_flow .detect_cycle (all_nodes , all_ff_edges ):
461+ all_nodes1 = list (all_nodes )
462+ all_ff_edges1 = tuple (all_ff_edges )
463+ if graph_flow .detect_cycle (all_nodes1 , all_ff_edges1 ):
387464 raise ValueError ('We detect cycles in feedforward connections. '
388465 'Maybe you should replace some connection with '
389466 'as feedback ones.' )
@@ -394,7 +471,13 @@ def fb_connect(
394471 # all inputs from subgraph 2.
395472 all_fb_edges |= new_fb_edges
396473
474+ # reorganize
475+ new_nodes , all_ff_edges , all_fb_edges = _reorganize_many2one (all_ff_edges , all_fb_edges )
476+ all_nodes .update (new_nodes )
477+
397478 # detect cycles in the graph flow
479+ all_nodes = tuple (all_nodes )
480+ all_ff_edges = tuple (all_ff_edges )
398481 all_fb_edges = tuple (all_fb_edges )
399482 if need_detect_cycle :
400483 if graph_flow .detect_cycle (all_nodes , all_fb_edges ):
0 commit comments