Skip to content

Commit b59940c

Browse files
committed
update apis and tests
1 parent 41c3ef9 commit b59940c

File tree

5 files changed

+66
-32
lines changed

5 files changed

+66
-32
lines changed

brainpy/base/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def vars(self, method='absolute', level=-1, include_self=True):
124124
v = getattr(node, k)
125125
if isinstance(v, math.Variable):
126126
if k not in node._excluded_vars:
127-
# if not k.startswith('_') and not k.endswith('_'):
128127
gather[f'{node_path}.{k}' if node_path else k] = v
129128
gather.update({f'{node_path}.{k}': v for k, v in node.implicit_vars.items()})
130129
return gather

brainpy/base/tests/test_collector.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,8 @@ def test_net_vars_2():
273273

274274
def test_hidden_variables():
275275
class BPClass(bp.base.Base):
276+
_excluded_vars = ('_rng_', )
277+
276278
def __init__(self):
277279
super(BPClass, self).__init__()
278280

brainpy/dyn/neurons/biological_models.py

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
from typing import Union, Callable
3+
from typing import Union, Callable, Optional
44

55
import brainpy.math as bm
66
from brainpy.dyn.base import NeuGroup
@@ -204,9 +204,9 @@ def __init__(
204204
V_th: Union[float, Tensor, Initializer, Callable] = 20.,
205205
C: Union[float, Tensor, Initializer, Callable] = 1.0,
206206
V_initializer: Union[Initializer, Callable, Tensor] = Uniform(-70, -60.),
207-
m_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.5),
208-
h_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.6),
209-
n_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.32),
207+
m_initializer: Optional[Union[Initializer, Callable, Tensor]] = None,
208+
h_initializer: Optional[Union[Initializer, Callable, Tensor]] = None,
209+
n_initializer: Optional[Union[Initializer, Callable, Tensor]] = None,
210210
noise: Union[float, Tensor, Initializer, Callable] = None,
211211
method: str = 'exp_auto',
212212
name: str = None,
@@ -233,20 +233,29 @@ def __init__(
233233
self.noise = init_noise(noise, self.varshape, num_vars=4)
234234

235235
# initializers
236-
check_initializer(m_initializer, 'm_initializer', allow_none=False)
237-
check_initializer(h_initializer, 'h_initializer', allow_none=False)
238-
check_initializer(n_initializer, 'n_initializer', allow_none=False)
236+
check_initializer(m_initializer, 'm_initializer', allow_none=True)
237+
check_initializer(h_initializer, 'h_initializer', allow_none=True)
238+
check_initializer(n_initializer, 'n_initializer', allow_none=True)
239239
check_initializer(V_initializer, 'V_initializer', allow_none=False)
240240
self._m_initializer = m_initializer
241241
self._h_initializer = h_initializer
242242
self._n_initializer = n_initializer
243243
self._V_initializer = V_initializer
244244

245245
# variables
246-
self.m = variable(self._m_initializer, mode, self.varshape)
247-
self.h = variable(self._h_initializer, mode, self.varshape)
248-
self.n = variable(self._n_initializer, mode, self.varshape)
249246
self.V = variable(self._V_initializer, mode, self.varshape)
247+
if self._m_initializer is None:
248+
self.m = bm.Variable(self.m_inf(self.V.value))
249+
else:
250+
self.m = variable(self._m_initializer, mode, self.varshape)
251+
if self._h_initializer is None:
252+
self.h = bm.Variable(self.h_inf(self.V.value))
253+
else:
254+
self.h = variable(self._h_initializer, mode, self.varshape)
255+
if self._n_initializer is None:
256+
self.n = bm.Variable(self.n_inf(self.V.value))
257+
else:
258+
self.n = variable(self._n_initializer, mode, self.varshape)
250259
self.input = variable(bm.zeros, mode, self.varshape)
251260
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
252261

@@ -256,32 +265,41 @@ def __init__(
256265
else:
257266
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
258267

268+
# m channel
269+
m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
270+
m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18)
271+
m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V))
272+
dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m
273+
274+
# h channel
275+
h_alpha = lambda self, V: 0.07 * bm.exp(-(V + 65) / 20.)
276+
h_beta = lambda self, V: 1 / (1 + bm.exp(-(V + 35) / 10))
277+
h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V))
278+
dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h
279+
280+
# n channel
281+
n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
282+
n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80)
283+
n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V))
284+
dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n
285+
259286
def reset_state(self, batch_size=None):
260-
self.m.value = variable(self._m_initializer, batch_size, self.varshape)
261-
self.h.value = variable(self._h_initializer, batch_size, self.varshape)
262-
self.n.value = variable(self._n_initializer, batch_size, self.varshape)
263287
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
288+
if self._m_initializer is None:
289+
self.m.value = self.m_inf(self.V.value)
290+
else:
291+
self.m.value = variable(self._m_initializer, batch_size, self.varshape)
292+
if self._h_initializer is None:
293+
self.h.value = self.h_inf(self.V.value)
294+
else:
295+
self.h.value = variable(self._h_initializer, batch_size, self.varshape)
296+
if self._n_initializer is None:
297+
self.n.value = self.n_inf(self.V.value)
298+
else:
299+
self.n.value = variable(self._n_initializer, batch_size, self.varshape)
264300
self.input.value = variable(bm.zeros, batch_size, self.varshape)
265301
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
266302

267-
def dm(self, m, t, V):
268-
alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
269-
beta = 4.0 * bm.exp(-(V + 65) / 18)
270-
dmdt = alpha * (1 - m) - beta * m
271-
return dmdt
272-
273-
def dh(self, h, t, V):
274-
alpha = 0.07 * bm.exp(-(V + 65) / 20.)
275-
beta = 1 / (1 + bm.exp(-(V + 35) / 10))
276-
dhdt = alpha * (1 - h) - beta * h
277-
return dhdt
278-
279-
def dn(self, n, t, V):
280-
alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
281-
beta = 0.125 * bm.exp(-(V + 65) / 80)
282-
dndt = alpha * (1 - n) - beta * n
283-
return dndt
284-
285303
def dV(self, V, t, m, h, n, I_ext):
286304
I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
287305
I_K = (self.gK * n ** 4.0) * (V - self.EK)

brainpy/visualization/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ def animate_2D(values,
9393
frame_delay=frame_delay, frame_step=frame_step, title_size=title_size,
9494
figsize=figsize, gif_dpi=gif_dpi, video_fps=video_fps, save_path=save_path, show=show)
9595

96+
@staticmethod
97+
def remove_axis(ax, *pos):
98+
from .plots import remove_axis
99+
return remove_axis(ax, *pos)
100+
96101
@staticmethod
97102
def plot_style1(fontsize=22,
98103
axes_edgecolor='black',

brainpy/visualization/plots.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
'raster_plot',
1818
'animate_2D',
1919
'animate_1D',
20+
'remove_axis',
2021
]
2122

2223

@@ -504,3 +505,12 @@ def frame(t):
504505
else:
505506
anim_result.save(save_path + '.mp4', writer='ffmpeg', fps=video_fps, bitrate=3000)
506507
return fig
508+
509+
510+
def remove_axis(ax, *pos):
511+
for p in pos:
512+
if p not in ['left', 'right', 'top', 'bottom']:
513+
raise ValueError
514+
ax.spine[p].set_visible(False)
515+
516+

0 commit comments

Comments
 (0)