
Commit 75f6fce ("updates")
Parent: d36273d

22 files changed: +647, -549 lines

brainpy/__init__.py

Lines changed: 356 additions & 356 deletions
Large diffs are not rendered by default.

brainpy/_src/connect/random_conn.py

Lines changed: 1 addition & 1 deletion

@@ -129,7 +129,7 @@ def build_csr(self):
   def build_mat(self):
     pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio
     mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state
-    mat = jnp.asarray(mat)
+    mat = bm.asarray(mat)
     if not self.include_self:
       mat = bm.fill_diagonal(mat, False)
     return mat.astype(MAT_DTYPE)
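The changed line keeps the boolean connection matrix as a brainpy.math array instead of converting it to a raw JAX array before bm.fill_diagonal is applied. A minimal standalone sketch of the same probability-based construction, using brainpy.math directly (the sizes and probabilities below are made up; the real build_mat() reads them from self):

import brainpy.math as bm

pre_num, post_num = 8, 8
prob, pre_ratio = 0.2, 0.5
rng = bm.random.RandomState(42)

# which presynaptic cells participate at all, then per-pair Bernoulli connectivity
pre_state = rng.uniform(size=(pre_num, 1)) < pre_ratio
mat = bm.asarray((rng.uniform(size=(pre_num, post_num)) < prob) * pre_state)
# the real build_mat() additionally clears the diagonal with bm.fill_diagonal
# when include_self is False, then casts to MAT_DTYPE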

brainpy/_src/dyn/layers/linear.py

Lines changed: 7 additions & 10 deletions

@@ -81,7 +81,11 @@ def __repr__(self):
            f'num_out={self.num_out}, '
            f'mode={self.mode})')
 
-  def update(self, sha, x):
+  def update(self, *args):
+    if len(args) == 1:
+      sha, x = dict(), bm.as_jax(args[0])
+    else:
+      sha, x = args[0], bm.as_jax(args[1])
     res = x @ self.W
     if self.b is not None:
       res += self.b
@@ -102,7 +106,7 @@ def online_init(self):
       num_input = self.num_in
     else:
       num_input = self.num_in + 1
-    self.online_fit_by.initialize(feature_in=num_input, feature_out=self.num_out, identifier=self.name)
+    self.online_fit_by.register_target(feature_in=num_input, identifier=self.name)
 
   def online_fit(self,
                  target: ArrayType,
@@ -139,13 +143,6 @@ def online_fit(self,
       self.b += db[0]
     self.W += dW
 
-  def offline_init(self):
-    if self.b is None:
-      num_input = self.num_in + 1
-    else:
-      num_input = self.num_in
-    self.offline_fit_by.initialize(feature_in=num_input, feature_out=self.num_out, identifier=self.name)
-
   def offline_fit(self,
                   target: ArrayType,
                   fit_record: Dict[str, ArrayType]):
@@ -176,7 +173,7 @@ def offline_fit(self,
     xs = jnp.concatenate([jnp.ones(xs.shape[:2] + (1,)), xs], axis=-1)  # (..., 1 + num_ff_input)
 
     # solve weights by offline training methods
-    weights = self.offline_fit_by(self.name, target, xs, ys)
+    weights = self.offline_fit_by(target, xs, ys)
 
     # assign trained weights
     if self.b is None:
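With update(*args), the layer can be called either with the input alone or with the shared-argument dict followed by the input. A minimal usage sketch, assuming the Dense layer defined in this file is exposed as bp.layers.Dense (the constructor values below are illustrative, not part of this diff):

import brainpy as bp
import brainpy.math as bm

layer = bp.layers.Dense(num_in=10, num_out=3)   # assumed public alias of this Dense class
x = bm.random.rand(32, 10)

out_a = layer.update(dict(), x)   # old calling convention: shared args + input
out_b = layer.update(x)           # new calling convention: input only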

brainpy/_src/dyn/layers/reservoir.py

Lines changed: 28 additions & 21 deletions

@@ -6,7 +6,7 @@
 
 import brainpy.math as bm
 from brainpy._src.initialize import Normal, ZeroInit, Initializer, parameter, variable
-from brainpy.check import is_float, is_initializer, is_string
+from brainpy import check
 from brainpy.tools import to_size
 from brainpy.types import ArrayType
 from .base import Layer
@@ -36,8 +36,9 @@ class Reservoir(Layer):
     A float between 0 and 1.
   activation : str, callable, optional
     Reservoir activation function.
+
     - If a str, should be a :py:mod:`brainpy.math.activations` function name.
-    - If a callable, should be an element-wise operator on tensor.
+    - If a callable, should be an element-wise operator.
   activation_type : str
     - If "internal" (default), then leaky integration happens on states transformed
       by the activation function:
@@ -66,9 +67,12 @@ class Reservoir(Layer):
     neurons connected to other reservoir neurons, including themselves.
     Must be in [0, 1], by default 0.1
   comp_type: str
-    The connectivity type, can be "dense" or "sparse".
+    The connectivity type, can be "dense" or "sparse", "jit".
+
+    - ``"dense"`` means the connectivity matrix is a dense matrix.
+    - ``"sparse"`` means the connectivity matrix is a CSR sparse matrix.
   spectral_radius : float, optional
-    Spectral radius of recurrent weight matrix, by default None
+    Spectral radius of recurrent weight matrix, by default None.
   noise_rec : float, optional
     Gain of noise applied to reservoir internal states, by default 0.0
   noise_in : float, optional
@@ -118,37 +122,38 @@ def __init__(
     self.num_unit = num_out
     assert num_out > 0, f'Must be a positive integer, but we got {num_out}'
     self.leaky_rate = leaky_rate
-    is_float(leaky_rate, 'leaky_rate', 0., 1.)
-    self.activation = getattr(bm.activations, activation)
+    check.is_float(leaky_rate, 'leaky_rate', 0., 1.)
+    self.activation = getattr(bm.activations, activation) if isinstance(activation, str) else activation
+    check.is_callable(self.activation, allow_none=False)
     self.activation_type = activation_type
-    is_string(activation_type, 'activation_type', ['internal', 'external'])
+    check.is_string(activation_type, 'activation_type', ['internal', 'external'])
     self.rng = bm.random.default_rng(seed)
-    is_float(spectral_radius, 'spectral_radius', allow_none=True)
+    check.is_float(spectral_radius, 'spectral_radius', allow_none=True)
     self.spectral_radius = spectral_radius
 
     # initializations
-    is_initializer(Win_initializer, 'ff_initializer', allow_none=False)
-    is_initializer(Wrec_initializer, 'rec_initializer', allow_none=False)
-    is_initializer(b_initializer, 'bias_initializer', allow_none=True)
+    check.is_initializer(Win_initializer, 'ff_initializer', allow_none=False)
+    check.is_initializer(Wrec_initializer, 'rec_initializer', allow_none=False)
+    check.is_initializer(b_initializer, 'bias_initializer', allow_none=True)
     self._Win_initializer = Win_initializer
     self._Wrec_initializer = Wrec_initializer
     self._b_initializer = b_initializer
 
     # connectivity
-    is_float(in_connectivity, 'ff_connectivity', 0., 1.)
-    is_float(rec_connectivity, 'rec_connectivity', 0., 1.)
+    check.is_float(in_connectivity, 'ff_connectivity', 0., 1.)
+    check.is_float(rec_connectivity, 'rec_connectivity', 0., 1.)
     self.ff_connectivity = in_connectivity
     self.rec_connectivity = rec_connectivity
-    is_string(comp_type, 'conn_type', ['dense', 'sparse'])
+    check.is_string(comp_type, 'conn_type', ['dense', 'sparse', 'jit'])
     self.comp_type = comp_type
 
     # noises
-    is_float(noise_in, 'noise_ff')
-    is_float(noise_rec, 'noise_rec')
+    check.is_float(noise_in, 'noise_ff')
+    check.is_float(noise_rec, 'noise_rec')
     self.noise_ff = noise_in
     self.noise_rec = noise_rec
     self.noise_type = noise_type
-    is_string(noise_type, 'noise_type', ['normal', 'uniform'])
+    check.is_string(noise_type, 'noise_type', ['normal', 'uniform'])
 
     # initialize feedforward weights
     weight_shape = (input_shape[-1], self.num_unit)
@@ -170,7 +175,7 @@ def __init__(
     conn_mat = self.rng.random(recurrent_shape) > self.rec_connectivity
     self.Wrec[conn_mat] = 0.
     if self.spectral_radius is not None:
-      current_sr = max(abs(jnp.linalg.eig(self.Wrec)[0]))
+      current_sr = max(abs(jnp.linalg.eig(bm.as_jax(self.Wrec))[0]))
       self.Wrec *= self.spectral_radius / current_sr
     if self.comp_type == 'sparse' and self.rec_connectivity < 1.:
       self.rec_pres, self.rec_posts = jnp.where(jnp.logical_not(bm.as_jax(conn_mat)))
@@ -186,11 +191,13 @@ def __init__(
   def reset_state(self, batch_size=None):
     self.state.value = variable(jnp.zeros, batch_size, self.output_shape)
 
-  def update(self, sha, x):
+  def update(self, *args):
     """Feedforward output."""
     # inputs
-    x = jnp.concatenate(x, axis=-1)
-    if self.noise_ff > 0: x += self.noise_ff * self.rng.uniform(-1, 1, x.shape)
+    x = args[0] if len(args) == 1 else args[1]
+    x = bm.as_jax(x)
+    if self.noise_ff > 0:
+      x += self.noise_ff * self.rng.uniform(-1, 1, x.shape)
     if self.comp_type == 'sparse' and self.ff_connectivity < 1.:
       sparse = {'data': self.Win,
                 'index': (self.ff_pres, self.ff_posts),
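The spectral-radius hunk rescales the recurrent weights so that their largest eigenvalue magnitude equals the requested radius. A standalone sketch of that step outside the Reservoir class (plain NumPy is used for the eigenvalue computation to keep the sketch CPU-friendly; the commit itself calls jnp.linalg.eig on bm.as_jax(self.Wrec)):

import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
Wrec = jnp.asarray(rng.normal(scale=0.1, size=(100, 100)))   # made-up recurrent weights

spectral_radius = 0.9                                        # requested radius
current_sr = np.max(np.abs(np.linalg.eigvals(np.asarray(Wrec))))
Wrec = Wrec * (spectral_radius / current_sr)                 # rescale to the target radius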

brainpy/_src/initialize/random_inits.py

Lines changed: 5 additions & 5 deletions

@@ -114,7 +114,7 @@ def __init__(self, mean=0., scale=1., seed=None):
   def __call__(self, *shape, dtype=None):
     shape = _format_shape(shape)
     weights = self.rng.normal(size=shape, loc=self.mean, scale=self.scale)
-    return jnp.asarray(weights, dtype=dtype)
+    return bm.as_jax(weights, dtype=dtype)
 
   def __repr__(self):
     return f'{self.__class__.__name__}(scale={self.scale}, rng={self.rng})'
@@ -140,7 +140,7 @@ def __init__(self, min_val: float = 0., max_val: float = 1., seed=None):
   def __call__(self, shape, dtype=None):
     shape = _format_shape(shape)
     r = self.rng.uniform(low=self.min_val, high=self.max_val, size=shape)
-    return jnp.asarray(r, dtype=dtype)
+    return bm.as_jax(r, dtype=dtype)
 
   def __repr__(self):
     return (f'{self.__class__.__name__}(min_val={self.min_val}, '
@@ -180,14 +180,14 @@ def __call__(self, shape, dtype=None):
     variance = (self.scale / denominator).astype(dtype)
     if self.distribution == "truncated_normal":
       stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype)
-      return self.rng.truncated_normal(-2, 2, shape, dtype) * stddev
+      res = self.rng.truncated_normal(-2, 2, shape, dtype) * stddev
     elif self.distribution == "normal":
       res = self.rng.randn(*shape) * jnp.sqrt(variance).astype(dtype)
     elif self.distribution == "uniform":
       res = self.rng.uniform(low=-1, high=1, size=shape) * jnp.sqrt(3 * variance).astype(dtype)
     else:
       raise ValueError("invalid distribution for variance scaling initializer")
-    return jnp.asarray(res, dtype=dtype)
+    return bm.as_jax(res, dtype=dtype)
 
   def __repr__(self):
     name = self.__class__.__name__
@@ -336,7 +336,7 @@ def __call__(self, shape, dtype=None):
     q_mat = q_mat.T
     q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
     q_mat = jnp.moveaxis(q_mat, 0, self.axis)
-    return self.scale * jnp.asarray(q_mat, dtype=dtype)
+    return self.scale * bm.as_jax(q_mat, dtype=dtype)
 
   def __repr__(self):
     return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, rng={self.rng})'
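All four hunks change the initializers' return value from jnp.asarray(...) to bm.as_jax(...), so arrays produced by brainpy's own random generator are converted to JAX arrays through brainpy.math. A minimal usage sketch, assuming the first hunk belongs to the Normal initializer (its import path is taken from the reservoir.py hunk above; the shape and seed are made up):

from brainpy._src.initialize import Normal  # assumed to be the class changed in the first hunk

init = Normal(mean=0., scale=1., seed=42)
w = init(10, 5)   # __call__(*shape) builds the weights; after this commit it returns a JAX array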

brainpy/_src/optimizers/tests/test_scheduler.py

Lines changed: 65 additions & 65 deletions

@@ -30,70 +30,70 @@ def test2(self, last_epoch):
       self.assertTrue(lr1 == lr2)
 
 
-class TestStepLR(parameterized.TestCase):
-
-  @parameterized.named_parameters(
-    {'testcase_name': f'last_epoch={last_epoch}',
-     'last_epoch': last_epoch}
-    for last_epoch in [-1, 0, 5, 10]
-  )
-  def test1(self, last_epoch):
-    scheduler1 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch)
-    scheduler2 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch)
-
-    for i in range(1, 25):
-      lr1 = scheduler1(i + last_epoch)
-      lr2 = scheduler2()
-      scheduler2.step_epoch()
-      print(f'{scheduler2.last_epoch}, {lr1:.4f}, {lr2:.4f}')
-      self.assertTrue(lr1 == lr2)
-
-
-class TestCosineAnnealingLR(unittest.TestCase):
-  def test1(self):
-    max_epoch = 50
-    iters = 200
-    sch = scheduler.CosineAnnealingLR(0.1, T_max=5, eta_min=0, last_epoch=-1)
-    all_lr1 = [[], []]
-    all_lr2 = [[], []]
-    for epoch in range(max_epoch):
-      for batch in range(iters):
-        all_lr1[0].append(epoch + batch / iters)
-        all_lr1[1].append(sch())
-        sch.step_epoch()
-      all_lr2[0].append(epoch)
-      all_lr2[1].append(sch())
-      sch.step_epoch()
-    plt.subplot(211)
-    plt.plot(jax.numpy.asarray(all_lr1[0]), jax.numpy.asarray(all_lr1[1]))
-    plt.subplot(212)
-    plt.plot(jax.numpy.asarray(all_lr2[0]), jax.numpy.asarray(all_lr2[1]))
-    plt.show()
-    plt.close()
-
-
-class TestCosineAnnealingWarmRestarts(unittest.TestCase):
-  def test1(self):
-    max_epoch = 50
-    iters = 200
-    sch = scheduler.CosineAnnealingWarmRestarts(0.1,
-                                                iters,
-                                                T_0=5,
-                                                T_mult=1,
-                                                last_call=-1)
-    all_lr1 = []
-    all_lr2 = []
-    for epoch in range(max_epoch):
-      for batch in range(iters):
-        all_lr1.append(sch())
-        sch.step_call()
-      all_lr2.append(sch())
-      sch.step_epoch()
-    plt.subplot(211)
-    plt.plot(jax.numpy.asarray(all_lr1))
-    plt.subplot(212)
-    plt.plot(jax.numpy.asarray(all_lr2))
-    plt.show()
-    plt.close()
+# class TestStepLR(parameterized.TestCase):
+#
+#   @parameterized.named_parameters(
+#     {'testcase_name': f'last_epoch={last_epoch}',
+#      'last_epoch': last_epoch}
+#     for last_epoch in [-1, 0, 5, 10]
+#   )
+#   def test1(self, last_epoch):
+#     scheduler1 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch)
+#     scheduler2 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch)
+#
+#     for i in range(1, 25):
+#       lr1 = scheduler1(i + last_epoch)
+#       lr2 = scheduler2()
+#       scheduler2.step_epoch()
+#       print(f'{scheduler2.last_epoch}, {lr1:.4f}, {lr2:.4f}')
+#       self.assertTrue(lr1 == lr2)
+#
+#
+# class TestCosineAnnealingLR(unittest.TestCase):
+#   def test1(self):
+#     max_epoch = 50
+#     iters = 200
+#     sch = scheduler.CosineAnnealingLR(0.1, T_max=5, eta_min=0, last_epoch=-1)
+#     all_lr1 = [[], []]
+#     all_lr2 = [[], []]
+#     for epoch in range(max_epoch):
+#       for batch in range(iters):
+#         all_lr1[0].append(epoch + batch / iters)
+#         all_lr1[1].append(sch())
+#         sch.step_epoch()
+#       all_lr2[0].append(epoch)
+#       all_lr2[1].append(sch())
+#       sch.step_epoch()
+#     plt.subplot(211)
+#     plt.plot(jax.numpy.asarray(all_lr1[0]), jax.numpy.asarray(all_lr1[1]))
+#     plt.subplot(212)
+#     plt.plot(jax.numpy.asarray(all_lr2[0]), jax.numpy.asarray(all_lr2[1]))
+#     plt.show()
+#     plt.close()
+#
+#
+# class TestCosineAnnealingWarmRestarts(unittest.TestCase):
+#   def test1(self):
+#     max_epoch = 50
+#     iters = 200
+#     sch = scheduler.CosineAnnealingWarmRestarts(0.1,
+#                                                 iters,
+#                                                 T_0=5,
+#                                                 T_mult=1,
+#                                                 last_call=-1)
+#     all_lr1 = []
+#     all_lr2 = []
+#     for epoch in range(max_epoch):
+#       for batch in range(iters):
+#         all_lr1.append(sch())
+#         sch.step_call()
+#       all_lr2.append(sch())
+#       sch.step_epoch()
+#     plt.subplot(211)
+#     plt.plot(jax.numpy.asarray(all_lr1))
+#     plt.subplot(212)
+#     plt.plot(jax.numpy.asarray(all_lr2))
+#     plt.show()
+#     plt.close()
 
 
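For reference, a minimal sketch of the StepLR usage exercised by the now-disabled TestStepLR case, mirroring the removed test code; the import path below is an assumption (the test file imports `scheduler` from its own package), not something shown in this diff:

from brainpy._src.optimizers import scheduler  # assumed import path

sch_by_index = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=-1)  # queried with an explicit epoch index
sch_stateful = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=-1)  # advanced via step_epoch()

for i in range(1, 25):
  lr1 = sch_by_index(i - 1)   # rate at epoch i - 1
  lr2 = sch_stateful()        # rate at the stateful scheduler's current epoch
  sch_stateful.step_epoch()
  print(f'{lr1:.4f}, {lr2:.4f}')  # the removed test asserted lr1 == lr2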
