Skip to content

Commit 417c1ed

Browse files
committed
Update biological_models.py
1 parent 21b4bc5 commit 417c1ed

File tree

1 file changed

+49
-31
lines changed

1 file changed

+49
-31
lines changed

brainpy/dyn/neurons/biological_models.py

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
from typing import Union, Callable
3+
from typing import Union, Callable, Optional
44

55
import brainpy.math as bm
66
from brainpy.dyn.base import NeuGroup
@@ -204,9 +204,9 @@ def __init__(
204204
V_th: Union[float, Array, Initializer, Callable] = 20.,
205205
C: Union[float, Array, Initializer, Callable] = 1.0,
206206
V_initializer: Union[Initializer, Callable, Array] = Uniform(-70, -60.),
207-
m_initializer: Union[Initializer, Callable, Array] = OneInit(0.5),
208-
h_initializer: Union[Initializer, Callable, Array] = OneInit(0.6),
209-
n_initializer: Union[Initializer, Callable, Array] = OneInit(0.32),
207+
m_initializer: Optional[Union[Initializer, Callable, Array]] = None,
208+
h_initializer: Optional[Union[Initializer, Callable, Array]] = None,
209+
n_initializer: Optional[Union[Initializer, Callable, Array]] = None,
210210
noise: Union[float, Array, Initializer, Callable] = None,
211211
method: str = 'exp_auto',
212212
name: str = None,
@@ -233,20 +233,29 @@ def __init__(
233233
self.noise = init_noise(noise, self.varshape, num_vars=4)
234234

235235
# initializers
236-
check_initializer(m_initializer, 'm_initializer', allow_none=False)
237-
check_initializer(h_initializer, 'h_initializer', allow_none=False)
238-
check_initializer(n_initializer, 'n_initializer', allow_none=False)
236+
check_initializer(m_initializer, 'm_initializer', allow_none=True)
237+
check_initializer(h_initializer, 'h_initializer', allow_none=True)
238+
check_initializer(n_initializer, 'n_initializer', allow_none=True)
239239
check_initializer(V_initializer, 'V_initializer', allow_none=False)
240240
self._m_initializer = m_initializer
241241
self._h_initializer = h_initializer
242242
self._n_initializer = n_initializer
243243
self._V_initializer = V_initializer
244244

245245
# variables
246-
self.m = variable(self._m_initializer, mode, self.varshape)
247-
self.h = variable(self._h_initializer, mode, self.varshape)
248-
self.n = variable(self._n_initializer, mode, self.varshape)
249246
self.V = variable(self._V_initializer, mode, self.varshape)
247+
if self._m_initializer is None:
248+
self.m = bm.Variable(self.m_inf(self.V.value))
249+
else:
250+
self.m = variable(self._m_initializer, mode, self.varshape)
251+
if self._h_initializer is None:
252+
self.h = bm.Variable(self.h_inf(self.V.value))
253+
else:
254+
self.h = variable(self._h_initializer, mode, self.varshape)
255+
if self._n_initializer is None:
256+
self.n = bm.Variable(self.n_inf(self.V.value))
257+
else:
258+
self.n = variable(self._n_initializer, mode, self.varshape)
250259
self.input = variable(bm.zeros, mode, self.varshape)
251260
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
252261

@@ -256,32 +265,41 @@ def __init__(
256265
else:
257266
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
258267

268+
# m channel
269+
m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
270+
m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18)
271+
m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V))
272+
dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m
273+
274+
# h channel
275+
h_alpha = lambda self, V: 0.07 * bm.exp(-(V + 65) / 20.)
276+
h_beta = lambda self, V: 1 / (1 + bm.exp(-(V + 35) / 10))
277+
h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V))
278+
dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h
279+
280+
# n channel
281+
n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
282+
n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80)
283+
n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V))
284+
dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n
285+
259286
def reset_state(self, batch_size=None):
260-
self.m.value = variable(self._m_initializer, batch_size, self.varshape)
261-
self.h.value = variable(self._h_initializer, batch_size, self.varshape)
262-
self.n.value = variable(self._n_initializer, batch_size, self.varshape)
263287
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
288+
if self._m_initializer is None:
289+
self.m.value = self.m_inf(self.V.value)
290+
else:
291+
self.m.value = variable(self._m_initializer, batch_size, self.varshape)
292+
if self._h_initializer is None:
293+
self.h.value = self.h_inf(self.V.value)
294+
else:
295+
self.h.value = variable(self._h_initializer, batch_size, self.varshape)
296+
if self._n_initializer is None:
297+
self.n.value = self.n_inf(self.V.value)
298+
else:
299+
self.n.value = variable(self._n_initializer, batch_size, self.varshape)
264300
self.input.value = variable(bm.zeros, batch_size, self.varshape)
265301
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
266302

267-
def dm(self, m, t, V):
268-
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
269-
beta = 4.0 * bm.exp(-(V + 65) / 18)
270-
dmdt = alpha * (1 - m) - beta * m
271-
return dmdt
272-
273-
def dh(self, h, t, V):
274-
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
275-
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
276-
dhdt = alpha * (1 - h) - beta * h
277-
return dhdt
278-
279-
def dn(self, n, t, V):
280-
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
281-
beta = 0.125 * bm.exp(-(V + 65) / 80)
282-
dndt = alpha * (1 - n) - beta * n
283-
return dndt
284-
285303
def dV(self, V, t, m, h, n, I_ext):
286304
I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
287305
I_K = (self.gK * n ** 4.0) * (V - self.EK)

0 commit comments

Comments
 (0)