Skip to content

Commit 5281ef7

Browse files
committed
update apis
1 parent e403b37 commit 5281ef7

File tree

3 files changed

+25
-21
lines changed

3 files changed

+25
-21
lines changed

brainpy/dyn/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,7 @@ def _init_weights(
996996
) -> Union[float, Array]:
997997
if comp_method not in ['sparse', 'dense']:
998998
raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}')
999-
if sparse_data not in ['csr', 'ij']:
999+
if sparse_data not in ['csr', 'ij', 'coo']:
10001000
raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {sparse_data}')
10011001
if self.conn is None:
10021002
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
@@ -1014,7 +1014,7 @@ def _init_weights(
10141014
if comp_method == 'sparse':
10151015
if sparse_data == 'csr':
10161016
conn_mask = self.conn.require('pre2post')
1017-
elif sparse_data == 'ij':
1017+
elif sparse_data in ['ij', 'coo']:
10181018
conn_mask = self.conn.require('post_ids', 'pre_ids')
10191019
else:
10201020
ValueError(f'Unknown sparse data type: {sparse_data}')

brainpy/dyn/layers/activate.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,30 @@
1-
from brainpy.dyn.base import DynamicalSystem
2-
from typing import Optional
3-
from brainpy.modes import Mode
41
from typing import Callable
2+
from typing import Optional
3+
4+
from brainpy.dyn.base import DynamicalSystem
5+
from brainpy.modes import Mode, training
56

67

78
class Activation(DynamicalSystem):
8-
r"""Applies a activation to the inputs
9+
r"""Applies an activation function to the inputs
910
1011
Parameters:
1112
----------
12-
activate_fun: Callable
13+
activate_fun: Callable, function
1314
The function of Activation
1415
name: str, Optional
1516
The name of the object
1617
mode: Mode
1718
Enable training this node or not. (default True).
1819
"""
1920

20-
def __init__(self,
21-
activate_fun: Callable,
22-
name: Optional[str] = None,
23-
mode: Optional[Mode] = None,
24-
**kwargs,
25-
):
21+
def __init__(
22+
self,
23+
activate_fun: Callable,
24+
name: Optional[str] = None,
25+
mode: Mode = training,
26+
**kwargs,
27+
):
2628
super().__init__(name, mode)
2729
self.activate_fun = activate_fun
2830
self.kwargs = kwargs

brainpy/dyn/layers/linear.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from brainpy.dyn.base import DynamicalSystem
1010
from brainpy.errors import MathError
1111
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
12-
from brainpy.modes import Mode, TrainingMode, BatchingMode, training, batching
12+
from brainpy.modes import Mode, TrainingMode, BatchingMode, training, batching
1313
from brainpy.tools.checking import check_initializer
1414
from brainpy.types import Array
1515

@@ -201,17 +201,19 @@ class Flatten(DynamicalSystem):
201201
mode: Mode
202202
Enable training this node or not. (default True)
203203
"""
204-
def __init__(self,
205-
name: Optional[str] = None,
206-
mode: Optional[Mode] = batching,
207-
):
204+
205+
def __init__(
206+
self,
207+
name: Optional[str] = None,
208+
mode: Optional[Mode] = batching,
209+
):
208210
super().__init__(name, mode)
209-
211+
210212
def update(self, shr, x):
211213
if isinstance(self.mode, BatchingMode):
212214
return x.reshape((x.shape[0], -1))
213215
else:
214216
return x.flatten()
215-
217+
216218
def reset_state(self, batch_size=None):
217-
pass
219+
pass

0 commit comments

Comments
 (0)