Skip to content

Commit a270879

Browse files
committed
feat: concatenate multiple inputs of the node with data pass of SingleData
1 parent 6ca4d34 commit a270879

File tree

2 files changed

+280
-11
lines changed

2 files changed

+280
-11
lines changed

brainpy/nn/operations.py

Lines changed: 94 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
"""
2222

2323
from itertools import product
24-
from typing import Union, Sequence
24+
from typing import Union, Sequence, Set
2525

2626
from brainpy.nn import graph_flow
2727
from brainpy.nn.base import Node, Network, FrozenNetwork
28+
from brainpy.nn.datatypes import SingleData
2829
from brainpy.nn.nodes.base import Select, Concat
2930
from brainpy.types import Tensor
3031

@@ -48,8 +49,8 @@ def _retrieve_nodes_and_edges(senders: Union[Node, Sequence[Node]],
4849

4950
# check receivers
5051
if isinstance(receivers, (tuple, list)):
51-
raise ValueError('Cannot concatenate a list/tuple of receivers. '
52-
'Please use set to wrap multiple receivers instead.')
52+
raise TypeError('Cannot concatenate a list/tuple of receivers. '
53+
'Please use set to wrap multiple receivers instead.')
5354
elif isinstance(receivers, set):
5455
receivers = list(receivers)
5556
elif isinstance(receivers, Node):
@@ -105,6 +106,74 @@ def _retrieve_nodes_and_edges(senders: Union[Node, Sequence[Node]],
105106
return all_nodes, all_ff_edges, all_fb_edges, all_senders, all_receivers
106107

107108

109+
def _reorganize_many2one(ff_edges, fb_edges):
110+
"""Reorganize the many-to-one connections.
111+
112+
If some node whose "data_type" is :py:class:`brainpy.nn.datatypes.SingleData` receives
113+
multiple feedforward or feedback connections, we should concatenate all feedforward
114+
inputs (or feedback inputs) into one instance of :py:class:`brainpy.nn.Concat`, then
115+
the new Concat instance feeds into this node.
116+
117+
"""
118+
from brainpy.nn.nodes.base import Concat
119+
120+
new_nodes = []
121+
122+
# find parents according to the child
123+
ff_senders = dict()
124+
for edge in ff_edges:
125+
sender, receiver = edge
126+
if receiver not in ff_senders:
127+
ff_senders[receiver] = [sender]
128+
else:
129+
ff_senders[receiver].append(sender)
130+
for receiver, senders in ff_senders.items():
131+
if isinstance(receiver.data_pass, SingleData):
132+
if len(senders) > 1:
133+
concat_nodes = [node for node in senders if isinstance(node, Concat)]
134+
if len(concat_nodes) == 1:
135+
concat = concat_nodes[0]
136+
for sender in senders:
137+
if sender != concat:
138+
ff_edges.remove((sender, receiver))
139+
ff_edges.add((sender, concat))
140+
else:
141+
concat = Concat()
142+
for sender in senders:
143+
ff_edges.remove((sender, receiver))
144+
ff_edges.add((sender, concat))
145+
ff_edges.add((concat, receiver))
146+
new_nodes.append(concat)
147+
148+
# find parents according to the child
149+
fb_senders = dict()
150+
for edge in fb_edges:
151+
sender, receiver = edge
152+
if receiver not in fb_senders:
153+
fb_senders[receiver] = [sender]
154+
else:
155+
fb_senders[receiver].append(sender)
156+
for receiver, senders in fb_senders.items():
157+
if isinstance(receiver.data_pass, SingleData):
158+
if len(senders) > 1:
159+
concat_nodes = [node for node in senders if isinstance(node, Concat)]
160+
if len(concat_nodes) == 1:
161+
concat = concat_nodes[0]
162+
for sender in senders:
163+
if sender != concat:
164+
fb_edges.remove((sender, receiver))
165+
ff_edges.add((sender, concat))
166+
else:
167+
concat = Concat()
168+
for sender in senders:
169+
fb_edges.remove((sender, receiver))
170+
ff_edges.add((sender, concat))
171+
fb_edges.add((concat, receiver))
172+
new_nodes.append(concat)
173+
174+
return new_nodes, ff_edges, fb_edges
175+
176+
108177
def merge(
109178
node: Node,
110179
*other_nodes: Node,
@@ -170,6 +239,10 @@ def merge(
170239
elif isinstance(n, Node):
171240
all_nodes.add(n)
172241

242+
# reorganize
243+
new_nodes, all_ff_edges, all_fb_edges = _reorganize_many2one(all_ff_edges, all_fb_edges)
244+
all_nodes.update(new_nodes)
245+
173246
# detect cycles in the graph flow
174247
all_nodes = tuple(all_nodes)
175248
all_ff_edges = tuple(all_ff_edges)
@@ -198,8 +271,8 @@ def merge(
198271

199272

200273
def ff_connect(
201-
senders: Union[Node, Sequence[Node]],
202-
receivers: Union[Node, Sequence[Node]],
274+
senders: Union[Node, Sequence[Node], Set[Node]],
275+
receivers: Union[Node, Set[Node]],
203276
inplace: bool = False,
204277
name: str = None,
205278
need_detect_cycle=True
@@ -246,7 +319,7 @@ def ff_connect(
246319
247320
- In the case of "one-to-many" feedforward connection, `node2` only support
248321
a set of node. Using list or tuple to wrap multiple receivers will concatenate
249-
all nodes in the receiver end. This will cause errors.
322+
all nodes in the receiver end. This will cause errors::
250323
251324
# wrong operation of one-to-many
252325
network = node_in >> {node1, node2, ..., node_N}
@@ -296,6 +369,10 @@ def ff_connect(
296369
# all inputs from subgraph 2.
297370
all_ff_edges |= new_ff_edges
298371

372+
# reorganize
373+
new_nodes, all_ff_edges, all_fb_edges = _reorganize_many2one(all_ff_edges, all_fb_edges)
374+
all_nodes.update(new_nodes)
375+
299376
# detect cycles in the graph flow
300377
all_nodes = tuple(all_nodes)
301378
all_ff_edges = tuple(all_ff_edges)
@@ -326,8 +403,8 @@ def ff_connect(
326403

327404

328405
def fb_connect(
329-
senders: Union[Node, Sequence[Node]],
330-
receivers: Union[Node, Sequence[Node]],
406+
senders: Union[Node, Sequence[Node], Set[Node]],
407+
receivers: Union[Node, Set[Node]],
331408
inplace: bool = False,
332409
name: str = None,
333410
need_detect_cycle=True
@@ -380,10 +457,10 @@ def fb_connect(
380457
f'support feedback connections.')
381458

382459
# detect feedforward cycle
383-
all_nodes = tuple(all_nodes)
384-
all_ff_edges = tuple(all_ff_edges)
385460
if need_detect_cycle:
386-
if graph_flow.detect_cycle(all_nodes, all_ff_edges):
461+
all_nodes1 = list(all_nodes)
462+
all_ff_edges1 = tuple(all_ff_edges)
463+
if graph_flow.detect_cycle(all_nodes1, all_ff_edges1):
387464
raise ValueError('We detect cycles in feedforward connections. '
388465
'Maybe you should replace some connection with '
389466
'as feedback ones.')
@@ -394,7 +471,13 @@ def fb_connect(
394471
# all inputs from subgraph 2.
395472
all_fb_edges |= new_fb_edges
396473

474+
# reorganize
475+
new_nodes, all_ff_edges, all_fb_edges = _reorganize_many2one(all_ff_edges, all_fb_edges)
476+
all_nodes.update(new_nodes)
477+
397478
# detect cycles in the graph flow
479+
all_nodes = tuple(all_nodes)
480+
all_ff_edges = tuple(all_ff_edges)
398481
all_fb_edges = tuple(all_fb_edges)
399482
if need_detect_cycle:
400483
if graph_flow.detect_cycle(all_nodes, all_fb_edges):
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from unittest import TestCase
4+
5+
import brainpy as bp
6+
7+
8+
class TestFF(TestCase):
9+
def test_one2one(self):
10+
i = bp.nn.Input(1)
11+
r = bp.nn.Reservoir(10)
12+
model = i >> r
13+
print(model.lnodes)
14+
self.assertTrue(model.ff_senders[r][0] == i)
15+
self.assertTrue(model.ff_receivers[i][0] == r)
16+
17+
def test_many2one1(self):
18+
i1 = bp.nn.Input(1)
19+
i2 = bp.nn.Input(2)
20+
i3 = bp.nn.Input(3)
21+
r = bp.nn.Reservoir(10)
22+
model = [i1, i2, i3] >> r
23+
self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat))
24+
self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat))
25+
self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat))
26+
27+
def test_many2one2(self):
28+
i1 = bp.nn.Input(1)
29+
i2 = bp.nn.Input(2)
30+
i3 = bp.nn.Input(3)
31+
r = bp.nn.Reservoir(10)
32+
model = (i1, i2, i3) >> r
33+
self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat))
34+
self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat))
35+
self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat))
36+
37+
def test_many2one3(self):
38+
i1 = bp.nn.Input(1)
39+
i2 = bp.nn.Input(2)
40+
i3 = bp.nn.Input(3)
41+
r = bp.nn.Reservoir(10)
42+
model = {i1, i2, i3} >> r
43+
self.assertTrue(model.ff_receivers[i1][0] == r)
44+
self.assertTrue(model.ff_receivers[i2][0] == r)
45+
self.assertTrue(model.ff_receivers[i3][0] == r)
46+
47+
def test_one2many1(self):
48+
i = bp.nn.Input(1)
49+
o1 = bp.nn.Dense(3)
50+
o2 = bp.nn.Dense(4)
51+
o3 = bp.nn.Dense(5)
52+
with self.assertRaises(TypeError):
53+
model = i >> [o1, o2, o3]
54+
55+
def test_one2many2(self):
56+
i = bp.nn.Input(1)
57+
o1 = bp.nn.Dense(3)
58+
o2 = bp.nn.Dense(4)
59+
o3 = bp.nn.Dense(5)
60+
with self.assertRaises(TypeError):
61+
model = i >> (o1, o2, o3)
62+
63+
def test_one2many3(self):
64+
i = bp.nn.Input(1)
65+
o1 = bp.nn.Dense(3)
66+
o2 = bp.nn.Dense(4)
67+
o3 = bp.nn.Dense(5)
68+
model = i >> {o1, o2, o3}
69+
# model.plot_node_graph()
70+
self.assertTrue(model.ff_senders[o1][0] == i)
71+
self.assertTrue(model.ff_senders[o2][0] == i)
72+
self.assertTrue(model.ff_senders[o3][0] == i)
73+
74+
def test_many2many1(self):
75+
i1 = bp.nn.Input(1)
76+
i2 = bp.nn.Input(2)
77+
i3 = bp.nn.Input(3)
78+
79+
o1 = bp.nn.Dense(3)
80+
o2 = bp.nn.Dense(4)
81+
o3 = bp.nn.Dense(5)
82+
83+
model = bp.nn.ff_connect([i1, i2, i3], {o1, o2, o3})
84+
85+
self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat))
86+
self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat))
87+
self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat))
88+
89+
self.assertTrue(isinstance(model.ff_senders[o1][0], bp.nn.Concat))
90+
self.assertTrue(isinstance(model.ff_senders[o2][0], bp.nn.Concat))
91+
self.assertTrue(isinstance(model.ff_senders[o3][0], bp.nn.Concat))
92+
93+
def test_many2many2(self):
94+
i1 = bp.nn.Input(1)
95+
i2 = bp.nn.Input(2)
96+
i3 = bp.nn.Input(3)
97+
98+
o1 = bp.nn.Dense(3)
99+
o2 = bp.nn.Dense(4)
100+
o3 = bp.nn.Dense(5)
101+
102+
model = bp.nn.ff_connect((i1, i2, i3), {o1, o2, o3})
103+
104+
self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat))
105+
self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat))
106+
self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat))
107+
108+
self.assertTrue(isinstance(model.ff_senders[o1][0], bp.nn.Concat))
109+
self.assertTrue(isinstance(model.ff_senders[o2][0], bp.nn.Concat))
110+
self.assertTrue(isinstance(model.ff_senders[o3][0], bp.nn.Concat))
111+
112+
def test_many2many3(self):
113+
i1 = bp.nn.Input(1)
114+
i2 = bp.nn.Input(2)
115+
i3 = bp.nn.Input(3)
116+
117+
o1 = bp.nn.Dense(3)
118+
o2 = bp.nn.Dense(4)
119+
o3 = bp.nn.Dense(5)
120+
121+
model = bp.nn.ff_connect({i1, i2, i3}, {o1, o2, o3})
122+
model.plot_node_graph()
123+
124+
self.assertTrue(len(model.ff_receivers[i1]) == 3)
125+
self.assertTrue(len(model.ff_receivers[i2]) == 3)
126+
self.assertTrue(len(model.ff_receivers[i3]) == 3)
127+
128+
self.assertTrue(len(model.ff_senders[o1]) == 3)
129+
self.assertTrue(len(model.ff_senders[o2]) == 3)
130+
self.assertTrue(len(model.ff_senders[o3]) == 3)
131+
132+
def test_many2one4(self):
133+
i1 = bp.nn.Input(1)
134+
i2 = bp.nn.Input(2)
135+
i3 = bp.nn.Input(3)
136+
137+
ii = bp.nn.Input(3)
138+
139+
model = {i1, i2, i3} >> ii
140+
model.plot_node_graph()
141+
142+
self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat))
143+
self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat))
144+
self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat))
145+
146+
def test_many2one5(self):
147+
i1 = bp.nn.Input(1)
148+
i2 = bp.nn.Input(2)
149+
i3 = bp.nn.Input(3)
150+
ii = bp.nn.Input(3)
151+
152+
model = (i1 >> ii) & (i2 >> ii)
153+
# model.plot_node_graph()
154+
self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat))
155+
self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat))
156+
self.assertTrue(len(model.ff_senders[ii]) == 1)
157+
self.assertTrue(isinstance(model.ff_senders[ii][0], bp.nn.Concat))
158+
159+
model = model & (i3 >> ii)
160+
# model.plot_node_graph()
161+
self.assertTrue(isinstance(model.ff_receivers[i1][0], bp.nn.Concat))
162+
self.assertTrue(isinstance(model.ff_receivers[i2][0], bp.nn.Concat))
163+
self.assertTrue(isinstance(model.ff_receivers[i3][0], bp.nn.Concat))
164+
self.assertTrue(len(model.ff_senders[ii]) == 1)
165+
self.assertTrue(isinstance(model.ff_senders[ii][0], bp.nn.Concat))
166+
167+
168+
class TestFB(TestCase):
169+
def test_many2one(self):
170+
class FBNode(bp.nn.Node):
171+
def init_fb_conn(self):
172+
pass
173+
174+
i1 = FBNode()
175+
i2 = FBNode()
176+
i3 = FBNode()
177+
i4 = FBNode()
178+
179+
model = (i1 >> i2 >> i3) & (i1 << i2) & (i1 << i3)
180+
model.plot_node_graph()
181+
182+
model = model & (i3 >> i4) & (i1 << i4)
183+
model.plot_node_graph()
184+
185+
186+

0 commit comments

Comments
 (0)