Skip to content

Commit 822d56f

Browse files
committed
feat: data pass of the Node is default SingleData
1 parent a270879 commit 822d56f

File tree

12 files changed

+49
-50
lines changed

12 files changed

+49
-50
lines changed

brainpy/nn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""Neural Networks (nn)"""
44

55
from .base import *
6-
from .constants import *
6+
from .datatypes import *
77
from .graph_flow import *
88
from .nodes import *
99
from .graph_flow import *

brainpy/nn/base.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@
2828
MathError)
2929
from brainpy.nn.algorithms.offline import OfflineAlgorithm
3030
from brainpy.nn.algorithms.online import OnlineAlgorithm
31-
from brainpy.nn.constants import (PASS_SEQUENCE,
32-
DATA_PASS_FUNC,
33-
DATA_PASS_TYPES)
31+
from brainpy.nn.datatypes import (DataType, SingleData, MultipleData)
3432
from brainpy.nn.graph_flow import (find_senders_and_receivers,
3533
find_entries_and_exits,
3634
detect_cycle,
@@ -83,13 +81,13 @@ def feedback(self):
8381
class Node(Base):
8482
"""Basic Node class for neural network building in BrainPy."""
8583

86-
'''Support multiple types of data pass, including "PASS_SEQUENCE" (by default),
87-
"PASS_NAME_DICT", "PASS_NODE_DICT" and user-customized type which registered
88-
by ``brainpy.nn.register_data_pass_type()`` function.
84+
'''Support multiple types of data pass, including "PassOnlyOne" (by default),
85+
"PassSequence", "PassNameDict", etc. and user-customized type which inherits
86+
from basic "SingleData" or "MultipleData".
8987
9088
This setting will change the feedforward/feedback input data which pass into
9189
the "call()" function and the sizes of the feedforward/feedback input data.'''
92-
data_pass_type = PASS_SEQUENCE
90+
data_pass = SingleData()
9391

9492
'''Offline fitting method.'''
9593
offline_fit_by: Union[Callable, OfflineAlgorithm]
@@ -115,11 +113,10 @@ def __init__(
115113
self._trainable = trainable
116114
self._state = None # the state of the current node
117115
self._fb_output = None # the feedback output of the current node
118-
# data pass function
119-
if self.data_pass_type not in DATA_PASS_FUNC:
120-
raise ValueError(f'Unsupported data pass type {self.data_pass_type}. '
121-
f'Only support {DATA_PASS_TYPES}')
122-
self.data_pass_func = DATA_PASS_FUNC[self.data_pass_type]
116+
# data pass
117+
if not isinstance(self.data_pass, DataType):
118+
raise ValueError(f'Unsupported data pass type {type(self.data_pass)}. '
119+
f'Only support {DataType.__class__}')
123120

124121
# super initialization
125122
super(Node, self).__init__(name=name)
@@ -129,11 +126,10 @@ def __init__(
129126
self._feedforward_shapes = {self.name: (None,) + tools.to_size(input_shape)}
130127

131128
def __repr__(self):
132-
name = type(self).__name__
133-
prefix = ' ' * (len(name) + 1)
134-
line1 = f"{name}(name={self.name}, forwards={self.feedforward_shapes}, \n"
135-
line2 = f"{prefix}feedbacks={self.feedback_shapes}, output={self.output_shape})"
136-
return line1 + line2
129+
return (f"{type(self).__name__}(name={self.name}, "
130+
f"forwards={self.feedforward_shapes}, "
131+
f"feedbacks={self.feedback_shapes}, "
132+
f"output={self.output_shape})")
137133

138134
def __call__(self, *args, **kwargs) -> Tensor:
139135
"""The main computation function of a Node.
@@ -298,7 +294,7 @@ def trainable(self, value: bool):
298294
@property
299295
def feedforward_shapes(self):
300296
"""Input data size."""
301-
return self.data_pass_func(self._feedforward_shapes)
297+
return self.data_pass.filter(self._feedforward_shapes)
302298

303299
@feedforward_shapes.setter
304300
def feedforward_shapes(self, size):
@@ -324,7 +320,7 @@ def set_feedforward_shapes(self, feedforward_shapes: Dict):
324320
@property
325321
def feedback_shapes(self):
326322
"""Output data size."""
327-
return self.data_pass_func(self._feedback_shapes)
323+
return self.data_pass.filter(self._feedback_shapes)
328324

329325
@feedback_shapes.setter
330326
def feedback_shapes(self, size):
@@ -530,8 +526,8 @@ def _check_inputs(self, ff, fb=None):
530526
f'batch size by ".initialize(num_batch)", or change the data '
531527
f'consistent with the data batch size {self.state.shape[0]}.')
532528
# data
533-
ff = self.data_pass_func(ff)
534-
fb = self.data_pass_func(fb)
529+
ff = self.data_pass.filter(ff)
530+
fb = self.data_pass.filter(fb)
535531
return ff, fb
536532

537533
def _call(self,
@@ -747,6 +743,8 @@ def set_state(self, state):
747743
class Network(Node):
748744
"""Basic Network class for neural network building in BrainPy."""
749745

746+
data_pass = MultipleData('sequence')
747+
750748
def __init__(self,
751749
nodes: Optional[Sequence[Node]] = None,
752750
ff_edges: Optional[Sequence[Tuple[Node]]] = None,
@@ -1145,8 +1143,8 @@ def _check_inputs(self, ff, fb=None):
11451143
check_shape_except_batch(size, fb[k].shape)
11461144

11471145
# data transformation
1148-
ff = self.data_pass_func(ff)
1149-
fb = self.data_pass_func(fb)
1146+
ff = self.data_pass.filter(ff)
1147+
fb = self.data_pass.filter(fb)
11501148
return ff, fb
11511149

11521150
def _call(self,
@@ -1208,12 +1206,12 @@ def _call(self,
12081206
def _call_a_node(self, node, ff, fb, monitors, forced_states,
12091207
parent_outputs, children_queue, ff_senders,
12101208
**shared_kwargs):
1211-
ff = node.data_pass_func(ff)
1209+
ff = node.data_pass.filter(ff)
12121210
if f'{node.name}.inputs' in monitors:
12131211
monitors[f'{node.name}.inputs'] = ff
12141212
# get the output results
12151213
if len(fb):
1216-
fb = node.data_pass_func(fb)
1214+
fb = node.data_pass.filter(fb)
12171215
if f'{node.name}.feedbacks' in monitors:
12181216
monitors[f'{node.name}.feedbacks'] = fb
12191217
parent_outputs[node] = node.forward(ff, fb, **shared_kwargs)
@@ -1440,7 +1438,7 @@ def plot_node_graph(self,
14401438
if len(nodes_untrainable):
14411439
proxie.append(Line2D([], [], color='white', marker='o',
14421440
markerfacecolor=untrainable_color))
1443-
labels.append('Untrainable')
1441+
labels.append('Nontrainable')
14441442
if len(ff_edges):
14451443
proxie.append(Line2D([], [], color=ff_color, linewidth=2))
14461444
labels.append('Feedforward')

brainpy/nn/nodes/ANN/batch_norm.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,15 @@
66
"""
77

88

9-
from typing import Sequence, Optional, Dict, Callable, Union
9+
from typing import Union
1010

1111
import jax.nn
1212
import jax.numpy as jnp
1313

14-
import brainpy.math as bm
1514
import brainpy
15+
import brainpy.math as bm
1616
from brainpy.initialize import ZeroInit, OneInit, Initializer
1717
from brainpy.nn.base import Node
18-
from brainpy.nn.constants import PASS_ONLY_ONE
19-
2018

2119
__all__ = [
2220
'BatchNorm',
@@ -40,8 +38,6 @@ class BatchNorm(Node):
4038
beta_init: an initializer generating the original translation matrix
4139
gamma_init: an initializer generating the original scaling matrix
4240
"""
43-
data_pass_type = PASS_ONLY_ONE
44-
4541
def __init__(self,
4642
axis: Union[int, tuple, list],
4743
epsilon: float = 1e-5,

brainpy/nn/nodes/ANN/dropout.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import brainpy.math as bm
44
from brainpy.nn.base import Node
5-
from brainpy.nn.constants import PASS_ONLY_ONE
65

76
__all__ = [
87
'Dropout'
@@ -37,8 +36,6 @@ class Dropout(Node):
3736
neural networks from overfitting." The journal of machine learning
3837
research 15.1 (2014): 1929-1958.
3938
"""
40-
data_pass_type = PASS_ONLY_ONE
41-
4239
def __init__(self, prob, seed=None, **kwargs):
4340
super(Dropout, self).__init__(**kwargs)
4441
self.prob = prob

brainpy/nn/nodes/ANN/rnn_cells.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
init_param,
1212
Initializer)
1313
from brainpy.nn.base import RecurrentNode
14+
from brainpy.nn.datatypes import MultipleData
1415
from brainpy.tools.checking import (check_integer,
1516
check_initializer,
1617
check_shape_consistency)
@@ -55,6 +56,7 @@ class VanillaRNN(RecurrentNode):
5556
Whether set the node is trainable.
5657
5758
"""
59+
data_pass = MultipleData('sequence')
5860

5961
def __init__(
6062
self,
@@ -169,6 +171,7 @@ class GRU(RecurrentNode):
169171
evaluation of gated recurrent neural networks on sequence modeling.
170172
arXiv preprint arXiv:1412.3555.
171173
"""
174+
data_pass = MultipleData('sequence')
172175

173176
def __init__(
174177
self,
@@ -302,6 +305,7 @@ class LSTM(RecurrentNode):
302305
exploration of recurrent network architectures." In International conference
303306
on machine learning, pp. 2342-2350. PMLR, 2015.
304307
"""
308+
data_pass = MultipleData('sequence')
305309

306310
def __init__(
307311
self,
@@ -391,16 +395,16 @@ def c(self, value):
391395

392396

393397
class ConvNDLSTM(RecurrentNode):
394-
pass
398+
data_pass = MultipleData('sequence')
395399

396400

397401
class Conv1DLSTM(ConvNDLSTM):
398-
pass
402+
data_pass = MultipleData('sequence')
399403

400404

401405
class Conv2DLSTM(ConvNDLSTM):
402-
pass
406+
data_pass = MultipleData('sequence')
403407

404408

405409
class Conv3DLSTM(ConvNDLSTM):
406-
pass
410+
data_pass = MultipleData('sequence')

brainpy/nn/nodes/RC/linear_readout.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import brainpy.math as bm
66
from brainpy.errors import MathError
77
from brainpy.initialize import Initializer
8+
from brainpy.nn.datatypes import MultipleData
89
from brainpy.nn.nodes.base.dense import Dense
910
from brainpy.tools.checking import check_shape_consistency
1011

@@ -27,6 +28,7 @@ class LinearReadout(Dense):
2728
trainable: bool
2829
Default is true.
2930
"""
31+
data_pass = MultipleData('sequence')
3032

3133
def __init__(self, num_unit: int, **kwargs):
3234
super(LinearReadout, self).__init__(num_unit=num_unit, **kwargs)

brainpy/nn/nodes/RC/nvar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import brainpy.math as bm
99
from brainpy.nn.base import RecurrentNode
10+
from brainpy.nn.datatypes import MultipleData
1011
from brainpy.tools.checking import (check_shape_consistency,
1112
check_integer,
1213
check_sequence)
@@ -61,6 +62,7 @@ class NVAR(RecurrentNode):
6162
https://doi.org/10.1038/s41467-021-25801-2
6263
6364
"""
65+
data_pass = MultipleData('sequence')
6466

6567
def __init__(
6668
self,

brainpy/nn/nodes/RC/reservoir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import brainpy.math as bm
66
from brainpy.initialize import Normal, ZeroInit, Initializer, init_param
77
from brainpy.nn.base import RecurrentNode
8+
from brainpy.nn.datatypes import MultipleData
89
from brainpy.tools.checking import (check_shape_consistency,
910
check_float,
1011
check_initializer,
@@ -90,6 +91,7 @@ class Reservoir(RecurrentNode):
9091
.. [1] Lukoševičius, Mantas. "A practical guide to applying echo state networks."
9192
Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 659-686.
9293
"""
94+
data_pass = MultipleData('sequence')
9395

9496
def __init__(
9597
self,

brainpy/nn/nodes/base/activation.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from brainpy.math import activations
66
from brainpy.nn.base import Node
7-
from brainpy.nn.constants import PASS_ONLY_ONE
87

98
__all__ = [
109
'Activation'
@@ -22,8 +21,6 @@ class Activation(Node):
2221
The settings for the activation function.
2322
"""
2423

25-
data_pass_type = PASS_ONLY_ONE
26-
2724
def __init__(self,
2825
activation: str = 'relu',
2926
fun_setting: Optional[Dict[str, Any]] = None,

brainpy/nn/nodes/base/dense.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from brainpy.errors import MathError
1010
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, init_param
1111
from brainpy.nn.base import Node
12+
from brainpy.nn.datatypes import MultipleData
1213
from brainpy.tools.checking import (check_shape_consistency,
1314
check_initializer)
1415
from brainpy.types import Tensor
@@ -40,6 +41,8 @@ class GeneralDense(Node):
4041
Enable training this node or not. (default True)
4142
"""
4243

44+
data_pass = MultipleData('sequence')
45+
4346
def __init__(
4447
self,
4548
num_unit: int,
@@ -123,6 +126,7 @@ class Dense(GeneralDense):
123126
trainable: bool
124127
Enable training this node or not. (default True)
125128
"""
129+
data_pass = MultipleData('sequence')
126130

127131
def __init__(
128132
self,

0 commit comments

Comments
 (0)