Skip to content

Commit f81498b

Browse files
authored
Support reset function in neuron and synapse models (#181)
Support reset functions in neuron and synapse models
2 parents 8d1e965 + 97a1cbb commit f81498b

File tree

15 files changed

+603
-278
lines changed

15 files changed

+603
-278
lines changed

brainpy/dyn/base.py

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,12 @@ def update(self, _t, _dt):
8686
Assume any dynamical system depends on the time variable ``t`` and
8787
the time step ``dt``.
8888
"""
89-
raise NotImplementedError('Must implement "update" function by user self.')
89+
raise NotImplementedError('Must implement "update" function by subclass self.')
90+
91+
def reset(self):
92+
"""Reset function which reset the whole variables in the model.
93+
"""
94+
raise NotImplementedError('Must implement "reset" function by subclass self.')
9095

9196

9297
class Container(DynamicalSystem):
@@ -143,6 +148,17 @@ def __getattr__(self, item):
143148
else:
144149
return super(Container, self).__getattribute__(item)
145150

151+
def reset(self):
152+
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
153+
neuron_groups = nodes.subset(NeuGroup)
154+
synapse_groups = nodes.subset(TwoEndConn)
155+
for node in neuron_groups.values():
156+
node.reset()
157+
for node in synapse_groups.values():
158+
node.reset()
159+
for node in (nodes - neuron_groups - synapse_groups).values():
160+
node.reset()
161+
146162

147163
class Network(Container):
148164
"""Base class to model network objects, an alias of Container.
@@ -196,6 +212,7 @@ class can automatically support your batched data.
196212
def __init__(self, size, delay, dtype=None, dt=None, **kwargs):
197213
# dt
198214
self.dt = bm.get_dt() if dt is None else dt
215+
self.dtype = dtype
199216

200217
# data size
201218
if isinstance(size, int): size = (size,)
@@ -213,6 +230,7 @@ def __init__(self, size, delay, dtype=None, dt=None, **kwargs):
213230
self.out_idx = bm.Variable(bm.array([0], dtype=bm.uint32))
214231
self.in_idx = bm.Variable(bm.array([self.num_step - 1], dtype=bm.uint32))
215232
self.data = bm.Variable(bm.zeros((self.num_step,) + self.size, dtype=dtype))
233+
self.num = 1
216234

217235
else: # non-uniform delay
218236
self.uniform_delay = False
@@ -237,6 +255,12 @@ def __init__(self, size, delay, dtype=None, dt=None, **kwargs):
237255

238256
super(ConstantDelay, self).__init__(**kwargs)
239257

258+
def reset(self):
259+
"""Reset the variables."""
260+
self.in_idx[:] = self.num_step - 1
261+
self.out_idx[:] = 0
262+
self.data[:] = 0
263+
240264
@property
241265
def oldest(self):
242266
return self.pull()
@@ -265,12 +289,6 @@ def update(self, _t=None, _dt=None, **kwargs):
265289
self.in_idx[:] = (self.in_idx + 1) % self.num_step
266290
self.out_idx[:] = (self.out_idx + 1) % self.num_step
267291

268-
def reset(self):
269-
"""Reset the variables."""
270-
self.in_idx[:] = self.num_step - 1
271-
self.out_idx[:] = 0
272-
self.data[:] = 0
273-
274292

275293
class NeuGroup(DynamicalSystem):
276294
"""Base class to model neuronal groups.
@@ -476,15 +494,15 @@ def register_delay(
476494
self.local_delay_vars[name] = self.global_delay_vars[name]
477495
else:
478496
if self.global_delay_vars[name].num_delay_step - 1 < max_delay_step:
479-
self.global_delay_vars[name].init(delay_target, max_delay_step, initial_delay_data)
497+
self.global_delay_vars[name].reset(delay_target, max_delay_step, initial_delay_data)
480498
self.register_implicit_nodes(self.global_delay_vars)
481499
return delay_step
482500

483-
def get_delay(
501+
def get_delay_data(
484502
self,
485503
name: str,
486-
delay_step: Union[int, bm.JaxArray, bm.ndarray],
487-
indices=None,
504+
delay_step: Union[int, bm.JaxArray, jnp.DeviceArray],
505+
indices: Union[int, bm.JaxArray, jnp.DeviceArray] = None,
488506
):
489507
"""Get delay data according to the provided delay steps.
490508
@@ -494,7 +512,7 @@ def get_delay(
494512
The delay variable name.
495513
delay_step: int, JaxArray, ndarray
496514
The delay length.
497-
indices: optional, JaxArray, ndarray
515+
indices: optional, int, JaxArray, ndarray
498516
The indices of the delay.
499517
500518
Returns
@@ -522,10 +540,31 @@ def get_delay(
522540
def update_delay(
523541
self,
524542
name: str,
525-
delay_target: Union[int, bm.JaxArray, bm.ndarray]
543+
delay_data: Union[float, bm.JaxArray, jnp.ndarray]
544+
):
545+
"""Update the delay according to the delay data.
546+
547+
Parameters
548+
----------
549+
name: str
550+
The name of the delay.
551+
delay_data: float, JaxArray, ndarray
552+
The delay data to update at the current time.
553+
"""
554+
if name in self.local_delay_vars:
555+
return self.local_delay_vars[name].update(delay_data)
556+
else:
557+
if name not in self.global_delay_vars:
558+
raise ValueError(f'{name} is not defined in delay variables.')
559+
560+
def reset_delay(
561+
self,
562+
name: str,
563+
delay_target: Union[bm.JaxArray, jnp.DeviceArray]
526564
):
565+
"""Reset the delay variable."""
527566
if name in self.local_delay_vars:
528-
return self.local_delay_vars[name].update(delay_target)
567+
return self.local_delay_vars[name].reset(delay_target)
529568
else:
530569
if name not in self.global_delay_vars:
531570
raise ValueError(f'{name} is not defined in delay variables.')

brainpy/dyn/neurons/biological_models.py

Lines changed: 69 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -142,37 +142,41 @@ class HH(NeuGroup):
142142
>>> plt.yticks([])
143143
>>> plt.show()
144144
145-
146-
**Model Parameters**
147-
148-
============= ============== ======== ====================================
149-
**Parameter** **Init Value** **Unit** **Explanation**
150-
------------- -------------- -------- ------------------------------------
151-
V_th 20. mV the spike threshold.
152-
C 1. ufarad capacitance.
153-
E_Na 50. mV reversal potential of sodium.
154-
E_K -77. mV reversal potential of potassium.
155-
E_leak 54.387 mV reversal potential of unspecific.
156-
g_Na 120. msiemens conductance of sodium channel.
157-
g_K 36. msiemens conductance of potassium channel.
158-
g_leak .03 msiemens conductance of unspecific channels.
159-
============= ============== ======== ====================================
160-
161-
**Model Variables**
162-
163-
================== ================= =========================================================
164-
**Variables name** **Initial Value** **Explanation**
165-
------------------ ----------------- ---------------------------------------------------------
166-
V -65 Membrane potential.
167-
m 0.05 gating variable of the sodium ion channel.
168-
n 0.32 gating variable of the potassium ion channel.
169-
h 0.60 gating variable of the sodium ion channel.
170-
input 0 External and synaptic input current.
171-
spike False Flag to mark whether the neuron is spiking.
172-
t_last_spike -1e7 Last spike time stamp.
173-
================== ================= =========================================================
174-
175-
**References**
145+
Parameters
146+
----------
147+
size: sequence of int, int
148+
The size of the neuron group.
149+
ENa: float, JaxArray, ndarray, Initializer, callable
150+
The reversal potential of sodium. Default is 50 mV.
151+
gNa: float, JaxArray, ndarray, Initializer, callable
152+
The maximum conductance of sodium channel. Default is 120 msiemens.
153+
EK: float, JaxArray, ndarray, Initializer, callable
154+
The reversal potential of potassium. Default is -77 mV.
155+
gK: float, JaxArray, ndarray, Initializer, callable
156+
The maximum conductance of potassium channel. Default is 36 msiemens.
157+
EL: float, JaxArray, ndarray, Initializer, callable
158+
The reversal potential of learky channel. Default is -54.387 mV.
159+
gL: float, JaxArray, ndarray, Initializer, callable
160+
The conductance of learky channel. Default is 0.03 msiemens.
161+
V_th: float, JaxArray, ndarray, Initializer, callable
162+
The threshold of the membrane spike. Default is 20 mV.
163+
C: float, JaxArray, ndarray, Initializer, callable
164+
The membrane capacitance. Default is 1 ufarad.
165+
V_initializer: JaxArray, ndarray, Initializer, callable
166+
The initializer of membrane potential.
167+
m_initializer: JaxArray, ndarray, Initializer, callable
168+
The initializer of m channel.
169+
h_initializer: JaxArray, ndarray, Initializer, callable
170+
The initializer of h channel.
171+
n_initializer: JaxArray, ndarray, Initializer, callable
172+
The initializer of n channel.
173+
method: str
174+
The numerical integration method.
175+
name: str
176+
The group name.
177+
178+
References
179+
----------
176180
177181
.. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description
178182
of membrane current and its application to conduction and excitation
@@ -214,22 +218,37 @@ def __init__(
214218
self.C = init_param(C, self.num, allow_none=False)
215219
self.V_th = init_param(V_th, self.num, allow_none=False)
216220

217-
# variables
221+
# initializers
218222
check_initializer(m_initializer, 'm_initializer', allow_none=False)
219223
check_initializer(h_initializer, 'h_initializer', allow_none=False)
220224
check_initializer(n_initializer, 'n_initializer', allow_none=False)
221225
check_initializer(V_initializer, 'V_initializer', allow_none=False)
222-
self.m = bm.Variable(init_param(m_initializer, (self.num,)))
223-
self.h = bm.Variable(init_param(h_initializer, (self.num,)))
224-
self.n = bm.Variable(init_param(n_initializer, (self.num,)))
225-
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
226+
self._m_initializer = m_initializer
227+
self._h_initializer = h_initializer
228+
self._n_initializer = n_initializer
229+
self._V_initializer = V_initializer
230+
231+
# variables
232+
self.m = bm.Variable(init_param(self._m_initializer, (self.num,)))
233+
self.h = bm.Variable(init_param(self._h_initializer, (self.num,)))
234+
self.n = bm.Variable(init_param(self._n_initializer, (self.num,)))
235+
self.V = bm.Variable(init_param(self._V_initializer, (self.num,)))
226236
self.input = bm.Variable(bm.zeros(self.num))
227237
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
228238
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
229239

230240
# integral
231241
self.integral = odeint(method=method, f=self.derivative)
232242

243+
def reset(self):
244+
self.m.value = init_param(self._m_initializer, (self.num,))
245+
self.h.value = init_param(self._h_initializer, (self.num,))
246+
self.n.value = init_param(self._n_initializer, (self.num,))
247+
self.V.value = init_param(self._V_initializer, (self.num,))
248+
self.input[:] = 0
249+
self.spike[:] = False
250+
self.t_last_spike[:] = -1e7
251+
233252
def dm(self, m, t, V):
234253
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
235254
beta = 4.0 * bm.exp(-(V + 65) / 18)
@@ -336,20 +355,8 @@ class MorrisLecar(NeuGroup):
336355
V_th 10 mV The spike threshold.
337356
============= ============== ======== =======================================================
338357
339-
**Model Variables**
340-
341-
================== ================= =========================================================
342-
**Variables name** **Initial Value** **Explanation**
343-
------------------ ----------------- ---------------------------------------------------------
344-
V -20 Membrane potential.
345-
W 0.02 Gating variable, refers to the fraction of
346-
opened K+ channels.
347-
input 0 External and synaptic input current.
348-
spike False Flag to mark whether the neuron is spiking.
349-
t_last_spike -1e7 Last spike time stamp.
350-
================== ================= =========================================================
351-
352-
**References**
358+
References
359+
----------
353360
354361
.. [1] Meier, Stephen R., Jarrett L. Lancaster, and Joseph M. Starobin.
355362
"Bursting regimes in a reaction-diffusion system with action
@@ -398,9 +405,13 @@ def __init__(
398405
self.phi = init_param(phi, self.num, allow_none=False)
399406
self.V_th = init_param(V_th, self.num, allow_none=False)
400407

401-
# vars
408+
# initializers
402409
check_initializer(V_initializer, 'V_initializer', allow_none=False)
403410
check_initializer(W_initializer, 'W_initializer', allow_none=False)
411+
self._W_initializer = W_initializer
412+
self._V_initializer = V_initializer
413+
414+
# variables
404415
self.W = bm.Variable(init_param(W_initializer, (self.num,)))
405416
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
406417
self.input = bm.Variable(bm.zeros(self.num))
@@ -410,6 +421,13 @@ def __init__(
410421
# integral
411422
self.integral = odeint(method=method, f=self.derivative)
412423

424+
def reset(self):
425+
self.W.value = init_param(self._W_initializer, (self.num,))
426+
self.V.value = init_param(self._V_initializer, (self.num,))
427+
self.input.value = bm.zeros(self.num)
428+
self.spike.value = bm.zeros(self.num, dtype=bool)
429+
self.t_last_spike.value = bm.ones(self.num) * -1e7
430+
413431
def dV(self, V, t, W, I_ext):
414432
M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2))
415433
I_Ca = self.g_Ca * M_inf * (V - self.V_Ca)

brainpy/dyn/neurons/fractional_models.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,15 @@ def __init__(
110110
self.Vth = init_param(Vth, self.num, allow_none=False)
111111
self.delta = init_param(delta, self.num, allow_none=False)
112112

113-
# variables
113+
# initializers
114114
check_initializer(V_initializer, 'V_initializer', allow_none=False)
115115
check_initializer(w_initializer, 'w_initializer', allow_none=False)
116116
check_initializer(y_initializer, 'y_initializer', allow_none=False)
117+
self._V_initializer = V_initializer
118+
self._w_initializer = w_initializer
119+
self._y_initializer = y_initializer
120+
121+
# variables
117122
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
118123
self.w = bm.Variable(init_param(w_initializer, (self.num,)))
119124
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
@@ -127,6 +132,16 @@ def __init__(
127132
num_memory=num_memory,
128133
inits=[self.V, self.w, self.y])
129134

135+
def reset(self):
136+
self.V.value = init_param(self._V_initializer, (self.num,))
137+
self.w.value = init_param(self._w_initializer, (self.num,))
138+
self.y.value = init_param(self._y_initializer, (self.num,))
139+
self.input[:] = 0
140+
self.spike[:] = False
141+
self.t_last_spike[:] = -1e7
142+
# integral function reset
143+
self.integral.reset([self.V, self.w, self.y])
144+
130145
def dV(self, V, t, w, y):
131146
return V - V ** 3 / 3 - w + y + self.input
132147

@@ -149,14 +164,6 @@ def update(self, _t, _dt):
149164
self.y.value = y
150165
self.input[:] = 0.
151166

152-
def set_init(self, values: dict):
153-
for k, v in values.items():
154-
if k not in self.integral.inits:
155-
raise ValueError(f'Variable "{k}" is not defined in this model.')
156-
variable = getattr(self, k)
157-
variable[:] = v
158-
self.integral.inits[k][:] = v
159-
160167

161168
class FractionalIzhikevich(FractionalNeuron):
162169
r"""Fractional-order Izhikevich model [10]_.
@@ -248,9 +255,13 @@ def __init__(
248255
self.R = init_param(R, self.num, allow_none=False)
249256
self.V_th = init_param(V_th, self.num, allow_none=False)
250257

251-
# variables
258+
# initializers
252259
check_initializer(V_initializer, 'V_initializer', allow_none=False)
253260
check_initializer(u_initializer, 'u_initializer', allow_none=False)
261+
self._V_initializer = V_initializer
262+
self._u_initializer = u_initializer
263+
264+
# variables
254265
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
255266
self.u = bm.Variable(init_param(u_initializer, (self.num,)))
256267
self.input = bm.Variable(bm.zeros(self.num))
@@ -264,6 +275,15 @@ def __init__(
264275
num_step=num_step,
265276
inits=[self.V, self.u])
266277

278+
def reset(self):
279+
self.V.value = init_param(self._V_initializer, (self.num,))
280+
self.u.value = init_param(self._u_initializer, (self.num,))
281+
self.input[:] = 0
282+
self.spike[:] = False
283+
self.t_last_spike[:] = -1e7
284+
# integral function reset
285+
self.integral.reset([self.V, self.u])
286+
267287
def dV(self, V, t, u, I_ext):
268288
dVdt = self.f * V * V + self.g * V + self.h - u + self.R * I_ext
269289
return dVdt / self.tau
@@ -284,11 +304,3 @@ def update(self, _t, _dt):
284304
self.u.value = bm.where(spikes, u + self.d, u)
285305
self.spike.value = spikes
286306
self.input[:] = 0.
287-
288-
def set_init(self, values: dict):
289-
for k, v in values.items():
290-
if k not in self.integral.inits:
291-
raise ValueError(f'Variable "{k}" is not defined in this model.')
292-
variable = getattr(self, k)
293-
variable[:] = v
294-
self.integral.inits[k][:] = v

0 commit comments

Comments
 (0)