Skip to content

Commit 3086c69

Browse files
authored
Merge pull request #119 from PKU-NIP-Lab/whole-brain-modeling
fix bugs
2 parents d2f9254 + ed4e5e5 commit 3086c69

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+838
-412
lines changed

README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,16 +147,18 @@ runner.run(100.)
147147

148148

149149

150-
Numerical methods for delay differential equations (SDEs).
150+
Numerical methods for delay differential equations (SDEs).
151151

152152
```python
153-
xdelay = bm.FixedLenDelay(1, delay_len=1., before_t0=1., dt=0.01)
153+
xdelay = bm.TimeDelay(1, delay_len=1., before_t0=1., dt=0.01)
154+
154155

155156
@bp.ddeint(method='rk4', state_delays={'x': xdelay})
156157
def second_order_eq(x, y, t):
157-
dx = y
158-
dy = -y - 2*x - 0.5*xdelay(t-1)
159-
return dx, dy
158+
dx = y
159+
dy = -y - 2 * x - 0.5 * xdelay(t - 1)
160+
return dx, dy
161+
160162

161163
runner = bp.integrators.IntegratorRunner(second_order_eq, dt=0.01)
162164
runner.run(100.)

brainpy/analysis/utils/measurement.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import jax.numpy as jnp
44
import numpy as np
5+
from brainpy.tools.others import numba_jit
56

67

78
__all__ = [
@@ -10,7 +11,7 @@
1011
]
1112

1213

13-
# @tools.numba_jit
14+
@numba_jit
1415
def _f1(arr, grad, tol):
1516
condition = np.logical_and(grad[:-1] * grad[1:] <= 0, grad[:-1] >= 0)
1617
indexes = np.where(condition)[0]
@@ -19,7 +20,8 @@ def _f1(arr, grad, tol):
1920
length = np.max(data) - np.min(data)
2021
a = arr[indexes[-2]]
2122
b = arr[indexes[-1]]
22-
if np.abs(a - b) <= tol * length:
23+
# TODO: how to choose length threshold, 1e-3?
24+
if length > 1e-3 and np.abs(a - b) <= tol * length:
2325
return indexes[-2:]
2426
return np.array([-1, -1])
2527

brainpy/analysis/utils/model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ def model_transform(model):
4949
new_model = []
5050
for intg in model:
5151
if isinstance(intg.f, JointEq):
52-
new_model.extend([type(intg)(eq, var_type=intg.var_type, dt=intg.dt, dyn_var=intg.dyn_var)
53-
for eq in intg.f.eqs])
52+
new_model.extend([type(intg)(eq, var_type=intg.var_type, dt=intg.dt) for eq in intg.f.eqs])
5453
else:
5554
new_model.append(intg)
5655

brainpy/check.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# -*- coding: utf-8 -*-
2+
3+
4+
__all__ = [
5+
'is_checking',
6+
'turn_on',
7+
'turn_off',
8+
]
9+
10+
_check = True
11+
12+
13+
def is_checking():
14+
"""Whether the checking is turn on."""
15+
return _check
16+
17+
18+
def turn_on():
19+
"""Turn on the checking."""
20+
global _check
21+
_check = True
22+
23+
24+
def turn_off():
25+
"""Turn off the checking."""
26+
global _check
27+
_check = False

brainpy/datasets/chaotic_systems.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def mackey_glass_series(duration, dt=0.1, beta=2., gamma=1., tau=2., n=9.65,
167167
assert isinstance(inits, (bm.ndarray, jnp.ndarray))
168168

169169
rng = bm.random.RandomState(seed)
170-
xdelay = bm.FixedLenDelay(inits.shape, tau, dt=dt)
170+
xdelay = bm.TimeDelay(inits.shape, tau, dt=dt)
171171
xdelay.data = inits + 0.2 * (rng.random((xdelay.num_delay_step,) + inits.shape) - 0.5)
172172

173173
@ddeint(method=method, state_delays={'x': xdelay})

brainpy/dyn/neurons/rate_models.py

Lines changed: 6 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
__all__ = [
1111
'FHN',
1212
'FeedbackFHN',
13-
'MeanFieldQIF',
1413
]
1514

1615

@@ -197,7 +196,6 @@ def __init__(self,
197196
tau: Parameter = 12.5,
198197
mu: Parameter = 1.6886,
199198
v0: Parameter = -1,
200-
Vth: Parameter = 1.8,
201199
method: str = 'rk4',
202200
name: str = None):
203201
super(FeedbackFHN, self).__init__(size=size, name=name)
@@ -209,23 +207,21 @@ def __init__(self,
209207
self.tau = tau
210208
self.mu = mu # feedback strength
211209
self.v0 = v0 # resting potential
212-
self.Vth = Vth
213210

214211
# variables
215212
self.w = bm.Variable(bm.zeros(self.num))
216213
self.V = bm.Variable(bm.zeros(self.num))
217-
self.Vdelay = bm.FixedLenDelay(self.num, self.delay)
214+
self.Vdelay = bm.TimeDelay(self.num, self.delay, interp_method='round')
218215
self.input = bm.Variable(bm.zeros(self.num))
219-
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
220-
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
221216

222217
# integral
223-
self.integral = ddeint(method=method, f=self.derivative,
218+
self.integral = ddeint(method=method,
219+
f=self.derivative,
224220
state_delays={'V': self.Vdelay})
225221

226-
def dV(self, V, t, w, Vdelay):
222+
def dV(self, V, t, w):
227223
return (V - V * V * V / 3 - w + self.input +
228-
self.mu * (Vdelay(t - self.delay) - self.v0))
224+
self.mu * (self.Vdelay(t - self.delay) - self.v0))
229225

230226
def dw(self, w, t, V):
231227
return (V + self.a - self.b * w) / self.tau
@@ -235,129 +231,7 @@ def derivative(self):
235231
return JointEq([self.dV, self.dw])
236232

237233
def update(self, _t, _dt):
238-
V, w = self.integral(self.V, self.w, _t, Vdelay=self.Vdelay, dt=_dt)
239-
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth)
240-
self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)
234+
V, w = self.integral(self.V, self.w, _t, dt=_dt)
241235
self.V.value = V
242236
self.w.value = w
243237
self.input[:] = 0.
244-
245-
246-
class MeanFieldQIF(NeuGroup):
247-
r"""A mean-field model of a quadratic integrate-and-fire neuron population.
248-
249-
**Model Descriptions**
250-
251-
The QIF population mean-field model, which has been derived from a
252-
population of all-to-all coupled QIF neurons in [5]_.
253-
The model equations are given by:
254-
255-
.. math::
256-
257-
\begin{aligned}
258-
\tau \dot{r} &=\frac{\Delta}{\pi \tau}+2 r v \\
259-
\tau \dot{v} &=v^{2}+\bar{\eta}+I(t)+J r \tau-(\pi r \tau)^{2}
260-
\end{aligned}
261-
262-
where :math:`r` is the average firing rate and :math:`v` is the
263-
average membrane potential of the QIF population [5]_.
264-
265-
This mean-field model is an exact representation of the macroscopic
266-
firing rate and membrane potential dynamics of a spiking neural network
267-
consisting of QIF neurons with Lorentzian distributed background
268-
excitabilities. While the mean-field derivation is mathematically
269-
only valid for all-to-all coupled populations of infinite size, it
270-
has been shown that there is a close correspondence between the
271-
mean-field model and neural populations with sparse coupling and
272-
population sizes of a few thousand neurons [6]_.
273-
274-
**Model Parameters**
275-
276-
============= ============== ======== ========================
277-
**Parameter** **Init Value** **Unit** **Explanation**
278-
------------- -------------- -------- ------------------------
279-
tau 1 ms the population time constant
280-
eta -5. \ the mean of a Lorenzian distribution over the neural excitability in the population
281-
delta 1.0 \ the half-width at half maximum of the Lorenzian distribution over the neural excitability
282-
J 15 \ the strength of the recurrent coupling inside the population
283-
============= ============== ======== ========================
284-
285-
286-
References
287-
----------
288-
.. [5] E. Montbrió, D. Pazó, A. Roxin (2015) Macroscopic description for
289-
networks of spiking neurons. Physical Review X, 5:021028,
290-
https://doi.org/10.1103/PhysRevX.5.021028.
291-
.. [6] R. Gast, H. Schmidt, T.R. Knösche (2020) A Mean-Field Description
292-
of Bursting Dynamics in Spiking Neural Networks with Short-Term
293-
Adaptation. Neural Computation 32.9 (2020): 1615-1634.
294-
295-
"""
296-
297-
def __init__(self,
298-
size: Shape,
299-
tau: Parameter = 1.,
300-
eta: Parameter = -5.0,
301-
delta: Parameter = 1.0,
302-
J: Parameter = 15.,
303-
method: str = 'exp_auto',
304-
name: str = None):
305-
super(MeanFieldQIF, self).__init__(size=size, name=name)
306-
307-
# parameters
308-
self.tau = tau #
309-
self.eta = eta # the mean of a Lorenzian distribution over the neural excitability in the population
310-
self.delta = delta # the half-width at half maximum of the Lorenzian distribution over the neural excitability
311-
self.J = J # the strength of the recurrent coupling inside the population
312-
313-
# variables
314-
self.r = bm.Variable(bm.ones(1))
315-
self.V = bm.Variable(bm.ones(1))
316-
self.input = bm.Variable(bm.zeros(1))
317-
318-
# functions
319-
self.integral = odeint(self.derivative, method=method)
320-
321-
def dr(self, r, t, v):
322-
return (self.delta / (bm.pi * self.tau) + 2. * r * v) / self.tau
323-
324-
def dV(self, v, t, r):
325-
return (v ** 2 + self.eta + self.input + self.J * r * self.tau -
326-
(bm.pi * r * self.tau) ** 2) / self.tau
327-
328-
@property
329-
def derivative(self):
330-
return JointEq([self.dV, self.dr])
331-
332-
def update(self, _t, _dt):
333-
self.V.value, self.r.value = self.integral(self.V, self.r, _t, _dt)
334-
self.integral[:] = 0.
335-
336-
337-
338-
class VanDerPolOscillator(NeuGroup):
339-
pass
340-
341-
342-
class ThetaNeuron(NeuGroup):
343-
pass
344-
345-
346-
class MeanFieldQIFWithSFA(NeuGroup):
347-
pass
348-
349-
350-
class JansenRitModel(NeuGroup):
351-
pass
352-
353-
354-
class WilsonCowanModel(NeuGroup):
355-
pass
356-
357-
class StuartLandauOscillator(NeuGroup):
358-
pass
359-
360-
361-
class KuramotoOscillator(NeuGroup):
362-
pass
363-

brainpy/dyn/rates/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from .base import *
4+
from .fhn import *

brainpy/dyn/rates/base.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from brainpy.dyn.base import DynamicalSystem
4+
from brainpy.tools.others import to_size, size2num
5+
from brainpy.types import Shape
6+
7+
__all__ = [
8+
'RateModel',
9+
]
10+
11+
12+
class RateModel(DynamicalSystem):
13+
"""Base class of rate models."""
14+
15+
def __init__(self,
16+
size: Shape,
17+
name: str = None):
18+
super(RateModel, self).__init__(name=name)
19+
20+
self.size = to_size(size)
21+
self.num = size2num(self.size)
22+
23+
def update(self, _t, _dt):
24+
"""The function to specify the updating rule.
25+
26+
Parameters
27+
----------
28+
_t : float
29+
The current time.
30+
_dt : float
31+
The time step.
32+
"""
33+
raise NotImplementedError(f'Subclass of {self.__class__.__name__} must '
34+
f'implement "update" function.')

0 commit comments

Comments
 (0)