Skip to content

Commit b9bce46

Browse files
authored
add Activation and Flatten class (#291)
add Activation and Flatten class
2 parents 1e9c8c2 + 9698606 commit b9bce46

File tree

3 files changed

+64
-2
lines changed

3 files changed

+64
-2
lines changed

brainpy/dyn/layers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77
from .rnncells import *
88
from .conv import *
99
from .normalization import *
10-
from .pooling import *
10+
from .pooling import *
11+
from .activate import *

brainpy/dyn/layers/activate.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from brainpy.dyn.base import DynamicalSystem
2+
from typing import Optional
3+
from brainpy.modes import Mode
4+
from typing import Callable
5+
6+
7+
class Activation(DynamicalSystem):
8+
r"""Applies a activation to the inputs
9+
10+
Parameters:
11+
----------
12+
activate_fun: Callable
13+
The function of Activation
14+
name: str, Optional
15+
The name of the object
16+
mode: Mode
17+
Enable training this node or not. (default True).
18+
"""
19+
20+
def __init__(self,
21+
activate_fun: Callable,
22+
name: Optional[str] = None,
23+
mode: Optional[Mode] = None,
24+
**kwargs,
25+
):
26+
super().__init__(name, mode)
27+
self.activate_fun = activate_fun
28+
self.kwargs = kwargs
29+
30+
def update(self, sha, x):
31+
return self.activate_fun(x, **self.kwargs)
32+
33+
def reset_state(self, batch_size=None):
34+
pass

brainpy/dyn/layers/linear.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
from brainpy.dyn.base import DynamicalSystem
1010
from brainpy.errors import MathError
1111
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
12-
from brainpy.modes import Mode, TrainingMode, training
12+
from brainpy.modes import Mode, TrainingMode, BatchingMode, training, batching
1313
from brainpy.tools.checking import check_initializer
1414
from brainpy.types import Array
1515

1616
__all__ = [
1717
'Dense',
18+
'Flatten'
1819
]
1920

2021

@@ -188,3 +189,29 @@ def offline_fit(self,
188189
bias, Wff = bm.split(weights, [1])
189190
self.W.value = Wff
190191
self.b.value = bias[0]
192+
193+
194+
class Flatten(DynamicalSystem):
195+
r"""Flattens a contiguous range of dims into 2D or 1D.
196+
197+
Parameters:
198+
----------
199+
name: str, Optional
200+
The name of the object
201+
mode: Mode
202+
Enable training this node or not. (default True)
203+
"""
204+
def __init__(self,
205+
name: Optional[str] = None,
206+
mode: Optional[Mode] = batching,
207+
):
208+
super().__init__(name, mode)
209+
210+
def update(self, shr, x):
211+
if isinstance(self.mode, BatchingMode):
212+
return x.reshape((x.shape[0], -1))
213+
else:
214+
return x.flatten()
215+
216+
def reset_state(self, batch_size=None):
217+
pass

0 commit comments

Comments
 (0)