Commit 4a525ef

Multiple functionality improvements (#197)

2 parents 11083a0 + 4e6f7ac

File tree

24 files changed (+330, -167 lines)

brainpy/dyn/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -12,4 +12,4 @@
 from .utils import *
 from .runners import *

-from . import neurons, synapses, channels, utils, runners
+from . import neurons, synapses, channels, rates, utils, runners

brainpy/dyn/base.py
Lines changed: 6 additions & 3 deletions

@@ -502,14 +502,17 @@ def __init__(self,
     # size
     if isinstance(size, (list, tuple)):
       if len(size) <= 0:
-        raise ModelBuildError('size must be int, or a tuple/list of int.')
+        raise ModelBuildError(f'size must be int, or a tuple/list of int. '
+                              f'But we got {type(size)}')
       if not isinstance(size[0], int):
-        raise ModelBuildError('size must be int, or a tuple/list of int.')
+        raise ModelBuildError('size must be int, or a tuple/list of int.'
+                              f'But we got {type(size)}')
       size = tuple(size)
     elif isinstance(size, int):
       size = (size,)
     else:
-      raise ModelBuildError('size must be int, or a tuple/list of int.')
+      raise ModelBuildError('size must be int, or a tuple/list of int.'
+                            f'But we got {type(size)}')
     self.size = size
     self.num = tools.size2num(size)
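For reference, a standalone sketch of the validation these hunks converge on (the helper name check_size is hypothetical, and ValueError stands in for brainpy's ModelBuildError so the snippet runs on its own):

def check_size(size):
  """Validate a group size the way the patched __init__ does."""
  if isinstance(size, (list, tuple)):
    if len(size) <= 0:
      raise ValueError(f'size must be int, or a tuple/list of int. '
                       f'But we got {type(size)}')
    if not isinstance(size[0], int):
      raise ValueError(f'size must be int, or a tuple/list of int. '
                       f'But we got {type(size)}')
    return tuple(size)
  elif isinstance(size, int):
    return (size,)
  else:
    raise ValueError(f'size must be int, or a tuple/list of int. '
                     f'But we got {type(size)}')

assert check_size([10, 10]) == (10, 10)
assert check_size(5) == (5,)
# check_size('5') now reports the offending type: "... But we got <class 'str'>"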

brainpy/dyn/neurons/__init__.py
Lines changed: 0 additions & 2 deletions

@@ -3,6 +3,4 @@
 from .biological_models import *
 from .fractional_models import *
 from .input_models import *
-from .noise_models import *
-from .rate_models import *
 from .reduced_models import *

brainpy/dyn/neurons/reduced_models.py
Lines changed: 57 additions & 42 deletions

@@ -5,8 +5,7 @@
 import brainpy.math as bm
 from brainpy.dyn.base import NeuGroup
 from brainpy.initialize import ZeroInit, OneInit, Initializer, init_param
-from brainpy.integrators.joint_eq import JointEq
-from brainpy.integrators.ode import odeint
+from brainpy.integrators import sdeint, odeint, JointEq
 from brainpy.tools.checking import check_initializer
 from brainpy.types import Shape, Tensor

@@ -46,31 +45,34 @@ class LIF(NeuGroup):

   - `(Brette, Romain. 2004) LIF phase locking <https://brainpy-examples.readthedocs.io/en/latest/neurons/Romain_2004_LIF_phase_locking.html>`_

-  **Model Parameters**
-
-  ============= ============== ======== =========================================
-  **Parameter** **Init Value** **Unit** **Explanation**
-  ------------- -------------- -------- -----------------------------------------
-  V_rest        0              mV       Resting membrane potential.
-  V_reset       -5             mV       Reset potential after spike.
-  V_th          20             mV       Threshold potential of spike.
-  tau           10             ms       Membrane time constant. Compute by R * C.
-  tau_ref       5              ms       Refractory period length.(ms)
-  ============= ============== ======== =========================================
-
-  **Neuron Variables**
-
-  ================== ================= =========================================================
-  **Variables name** **Initial Value** **Explanation**
-  ------------------ ----------------- ---------------------------------------------------------
-  V                  0                 Membrane potential.
-  input              0                 External and synaptic input current.
-  spike              False             Flag to mark whether the neuron is spiking.
-  refractory         False             Flag to mark whether the neuron is in refractory period.
-  t_last_spike       -1e7              Last spike time stamp.
-  ================== ================= =========================================================
-
-  **References**
+  Parameters
+  ----------
+  size: sequence of int, int
+    The size of the neuron group.
+  V_rest: float, JaxArray, ndarray, Initializer, callable
+    Resting membrane potential.
+  V_reset: float, JaxArray, ndarray, Initializer, callable
+    Reset potential after spike.
+  V_th: float, JaxArray, ndarray, Initializer, callable
+    Threshold potential of spike.
+  tau: float, JaxArray, ndarray, Initializer, callable
+    Membrane time constant.
+  tau_ref: float, JaxArray, ndarray, Initializer, callable
+    Refractory period length (ms).
+  V_initializer: JaxArray, ndarray, Initializer, callable
+    The initializer of membrane potential.
+  noise: JaxArray, ndarray, Initializer, callable
+    The noise added onto the membrane potential.
+  noise_type: str
+    The type of the provided noise. Can be `value` or `func`.
+  method: str
+    The numerical integration method.
+  name: str
+    The group name.
+
+  References
+  ----------

   .. [1] Abbott, Larry F. "Lapicque's introduction of the integrate-and-fire model
          neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304.

@@ -85,44 +87,57 @@ def __init__(
       tau: Union[float, Tensor, Initializer, Callable] = 10.,
       tau_ref: Union[float, Tensor, Initializer, Callable] = 1.,
       V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(),
+      noise: Union[float, Tensor, Initializer, Callable] = None,
+      noise_type: str = 'value',
+      keep_size: bool = False,
       method: str = 'exp_auto',
       name: str = None
   ):
     # initialization
     super(LIF, self).__init__(size=size, name=name)

     # parameters
-    self.V_rest = init_param(V_rest, self.num, allow_none=False)
-    self.V_reset = init_param(V_reset, self.num, allow_none=False)
-    self.V_th = init_param(V_th, self.num, allow_none=False)
-    self.tau = init_param(tau, self.num, allow_none=False)
-    self.tau_ref = init_param(tau_ref, self.num, allow_none=False)
+    self.keep_size = keep_size
+    self.noise_type = noise_type
+    if noise_type not in ['func', 'value']:
+      raise ValueError(f'noise_type only supports `func` and `value`, but we got {noise_type}')
+    size = self.size if keep_size else self.num
+    self.V_rest = init_param(V_rest, size, allow_none=False)
+    self.V_reset = init_param(V_reset, size, allow_none=False)
+    self.V_th = init_param(V_th, size, allow_none=False)
+    self.tau = init_param(tau, size, allow_none=False)
+    self.tau_ref = init_param(tau_ref, size, allow_none=False)
+    if noise_type == 'func':
+      self.noise = noise
+    else:
+      self.noise = init_param(noise, size, allow_none=True)

     # initializers
     check_initializer(V_initializer, 'V_initializer')
     self._V_initializer = V_initializer

     # variables
-    self.V = bm.Variable(init_param(V_initializer, (self.num,)))
-    self.input = bm.Variable(bm.zeros(self.num))
-    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
-    self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
-    self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool))
+    self.V = bm.Variable(init_param(V_initializer, size))
+    self.input = bm.Variable(bm.zeros(size))
+    self.spike = bm.Variable(bm.zeros(size, dtype=bool))
+    self.t_last_spike = bm.Variable(bm.ones(size) * -1e7)
+    self.refractory = bm.Variable(bm.zeros(size, dtype=bool))

     # integral
-    self.integral = odeint(method=method, f=self.derivative)
+    f = lambda V, t, I_ext: (-V + self.V_rest + I_ext) / self.tau
+    if self.noise is not None:
+      g = noise if (noise_type == 'func') else (lambda V, t, I_ext: self.noise / bm.sqrt(self.tau))
+      self.integral = sdeint(method=method, f=f, g=g)
+    else:
+      self.integral = odeint(method=method, f=f)

   def reset(self):
-    self.V.value = init_param(self._V_initializer, (self.num,))
+    self.V.value = init_param(self._V_initializer, self.size if self.keep_size else self.num)
     self.input[:] = 0
     self.spike[:] = False
     self.t_last_spike[:] = -1e7
     self.refractory[:] = False

-  def derivative(self, V, t, I_ext):
-    dvdt = (-V + self.V_rest + I_ext) / self.tau
-    return dvdt
-
   def update(self, t, dt):
     refractory = (t - self.t_last_spike) <= self.tau_ref
     V = self.integral(self.V, t, self.input, dt=dt)
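The new noise arguments switch the integrator: without noise the membrane equation stays an ODE solved by odeint; with noise it becomes the SDE dV = f(V, t, I_ext) dt + g dW solved by sdeint, where a scalar noise value is turned into the diffusion term noise / sqrt(tau). A usage sketch, assuming this LIF is re-exported as bp.dyn.LIF and with illustrative parameter values:

import brainpy as bp
import brainpy.math as bm

# Deterministic: dV/dt = (-V + V_rest + I_ext) / tau, integrated with odeint.
lif = bp.dyn.LIF(100, V_rest=0., V_reset=-5., V_th=20., tau=10.)

# Scalar noise: the diffusion term becomes noise / sqrt(tau), integrated with sdeint.
noisy = bp.dyn.LIF(100, noise=1.0, noise_type='value')

# Callable noise: the function is used directly as g(V, t, I_ext).
noisy_fn = bp.dyn.LIF(100, noise_type='func',
                      noise=lambda V, t, I_ext: 0.1 * bm.ones_like(V))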

brainpy/dyn/rates/__init__.py
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .noises import *
+from .populations import *
+from .couplings import *
+

brainpy/dyn/synapses/delay_coupling.py renamed to brainpy/dyn/rates/couplings.py
Lines changed: 21 additions & 17 deletions

@@ -25,7 +25,7 @@ class DelayCoupling(DynamicalSystem):
   ----------
   delay_var: Variable
     The delay variable.
-  output_var: Variable, sequence of Variable
+  target_var: Variable, sequence of Variable
     The target variables to output.
   conn_mat: JaxArray, ndarray
     The connection matrix.

@@ -40,7 +40,7 @@ class DelayCoupling(DynamicalSystem):
   def __init__(
       self,
       delay_var: bm.Variable,
-      output_var: Union[bm.Variable, Sequence[bm.Variable]],
+      target_var: Union[bm.Variable, Sequence[bm.Variable]],
       conn_mat: Tensor,
       required_shape: Tuple[int, ...],
       delay_steps: Optional[Union[int, Tensor, Initializer, Callable]] = None,

@@ -56,10 +56,10 @@ def __init__(
     self.delay_var = delay_var

     # output variables
-    if isinstance(output_var, bm.Variable):
-      output_var = [output_var]
-    check_sequence(output_var, 'output_var', elem_type=bm.Variable, allow_none=False)
-    self.output_var = output_var
+    if isinstance(target_var, bm.Variable):
+      target_var = [target_var]
+    check_sequence(target_var, 'output_var', elem_type=bm.Variable, allow_none=False)
+    self.output_var = target_var

     # Connection matrix
     self.conn_mat = bm.asarray(conn_mat)
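Because output_var was a public keyword, existing call sites need the new spelling. A hedged before/after sketch, borrowing the rates.FHN example from the docstring hunk below (the size and connection matrix are illustrative):

import brainpy.math as bm
from brainpy.dyn import rates

areas = rates.FHN(4, name='fhn')
Cmat = bm.random.rand(4, 4)

# Before this commit:
#   DiffusiveCoupling(areas.x, areas.x, output_var=areas.input, conn_mat=Cmat)
# After:
conn = rates.DiffusiveCoupling(areas.x, areas.x, target_var=areas.input,
                               conn_mat=Cmat)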
@@ -122,8 +122,9 @@ class DiffusiveCoupling(DelayCoupling):
   --------

   >>> import brainpy as bp
-  >>> areas = bp.dyn.RateFHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn')
-  >>> conn = bp.dyn.DiffusiveCoupling(areas.x, areas.x, areas.input,
+  >>> from brainpy.dyn import rates
+  >>> areas = rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn')
+  >>> conn = rates.DiffusiveCoupling(areas.x, areas.x, areas.input,
   >>>                                conn_mat=Cmat, delay_steps=Dmat,
   >>>                                initial_delay_data=bp.init.Uniform(0, 0.05))
   >>> net = bp.dyn.Network(areas, conn)

@@ -134,7 +135,7 @@ class DiffusiveCoupling(DelayCoupling):
     The first coupling variable, used for delay.
   coupling_var2: Variable
     Another coupling variable.
-  output_var: Variable, sequence of Variable
+  target_var: Variable, sequence of Variable
     The target variables to output.
   conn_mat: JaxArray, ndarray
     The connection matrix.

@@ -150,7 +151,7 @@ def __init__(
       self,
       coupling_var1: bm.Variable,
       coupling_var2: bm.Variable,
-      output_var: Union[bm.Variable, Sequence[bm.Variable]],
+      target_var: Union[bm.Variable, Sequence[bm.Variable]],
       conn_mat: Tensor,
       delay_steps: Optional[Union[int, Tensor, Initializer, Callable]] = None,
       initial_delay_data: Union[Initializer, Callable, Tensor, float, int, bool] = None,

@@ -171,7 +172,7 @@ def __init__(

     super(DiffusiveCoupling, self).__init__(
       delay_var=coupling_var1,
-      output_var=output_var,
+      target_var=target_var,
       conn_mat=conn_mat,
       required_shape=(coupling_var1.size, coupling_var2.size),
       delay_steps=delay_steps,

@@ -190,15 +191,18 @@ def update(self, t, dt):
     # delays
     if self.delay_type == 'none':
       diffusive = bm.expand_dims(self.coupling_var1, axis=1) - self.coupling_var2
+      diffusive = (self.conn_mat * diffusive).sum(axis=0)
     elif self.delay_type == 'array':
       f = vmap(lambda i: delay_var(self.delay_steps[i], bm.arange(self.coupling_var1.size)))  # (pre.num,)
       delays = f(bm.arange(self.coupling_var2.size).value)
       diffusive = delays.T - self.coupling_var2  # (post.num, pre.num)
+      diffusive = (self.conn_mat * diffusive).sum(axis=0)
     elif self.delay_type == 'int':
-      diffusive = bm.expand_dims(delay_var(self.delay_steps), axis=1) - self.coupling_var2
+      delayed_var = delay_var(self.delay_steps)
+      diffusive = bm.expand_dims(delayed_var, axis=1) - self.coupling_var2
+      diffusive = (self.conn_mat * diffusive).sum(axis=0)
     else:
       raise ValueError
-    diffusive = (self.conn_mat * diffusive).sum(axis=0)

     # output to target variable
     for target in self.output_var:
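For orientation, the reduction each branch now ends with, in plain NumPy (x stands in for coupling_var1, y for coupling_var2; only the no-delay branch is sketched):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=3)    # coupling_var1, one value per source node
y = rng.normal(size=3)    # coupling_var2, one value per target node
C = rng.random((3, 3))    # conn_mat, shape (pre, post)

diff = x[:, None] - y                # (pre, post) pairwise differences
coupled = (C * diff).sum(axis=0)     # weighted sum over sources -> (post,)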
@@ -221,7 +225,7 @@ class AdditiveCoupling(DelayCoupling):
   ----------
   coupling_var: Variable
     The coupling variable, used for delay.
-  output_var: Variable, sequence of Variable
+  target_var: Variable, sequence of Variable
     The target variables to output.
   conn_mat: JaxArray, ndarray
     The connection matrix.

@@ -236,7 +240,7 @@ def __init__(
   def __init__(
       self,
       coupling_var: bm.Variable,
-      output_var: Union[bm.Variable, Sequence[bm.Variable]],
+      target_var: Union[bm.Variable, Sequence[bm.Variable]],
       conn_mat: Tensor,
       delay_steps: Optional[Union[int, Tensor, Initializer, Callable]] = None,
       initial_delay_data: Union[Initializer, Callable, Tensor, float, int, bool] = None,

@@ -251,7 +255,7 @@ def __init__(

     super(AdditiveCoupling, self).__init__(
       delay_var=coupling_var,
-      output_var=output_var,
+      target_var=target_var,
       conn_mat=conn_mat,
       required_shape=(coupling_var.size, coupling_var.size),
       delay_steps=delay_steps,

@@ -267,7 +271,7 @@ def update(self, t, dt):

     # delay function
     if self.delay_steps is None:
-      additive = self.conn_mat * bm.expand_dims(self.coupling_var, axis=1)
+      additive = self.coupling_var @ self.conn_mat
     else:
       f = vmap(lambda i: delay_var(self.delay_steps[i], bm.arange(self.coupling_var.size)))  # (pre.num,)
       delays = f(bm.arange(self.coupling_var.size).value)  # (post.num, pre.num)
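In the no-delay branch, the old expression conn_mat * expand_dims(coupling_var, 1) produced a (pre, post) matrix, while the new coupling_var @ conn_mat also reduces over the source axis. A NumPy check of how the two forms relate:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=5)   # coupling_var
C = rng.random((5, 5))   # conn_mat, shape (pre, post)

old = C * x[:, None]     # (pre, post): weighted, but never summed
new = x @ C              # (post,): weighted and summed over pre
assert np.allclose(old.sum(axis=0), new)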
(A further file in this commit was renamed without changes.)
