Skip to content

Commit 8ef3a56

Browse files
committed
fix ExponentialEuler and sparse_matmul op
1 parent 317ddf3 commit 8ef3a56

File tree

4 files changed

+12
-225
lines changed

4 files changed

+12
-225
lines changed

brainpy/integrators/ode/tests/test_ode_keywords_for_exp_euler.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -38,33 +38,3 @@ def func(m, t, dt):
3838
return dmdt
3939

4040
odeint(method='exponential_euler', show_code=True, f=func)
41-
42-
def test4(self):
43-
with pytest.raises(errors.CodeError):
44-
def func(m, t, m_new):
45-
alpha = 0.1 * (m_new + 40) / (1 - np.exp(-(m_new + 40) / 10))
46-
beta = 4.0 * np.exp(-(m_new + 65) / 18)
47-
dmdt = alpha * (1 - m) - beta * m
48-
return dmdt
49-
50-
odeint(method='exponential_euler', show_code=True, f=func)
51-
52-
def test5(self):
53-
with pytest.raises(errors.CodeError):
54-
def func(m, t, exp):
55-
alpha = 0.1 * (exp + 40) / (1 - np.exp(-(exp + 40) / 10))
56-
beta = 4.0 * np.exp(-(exp + 65) / 18)
57-
dmdt = alpha * (1 - m) - beta * m
58-
return dmdt
59-
60-
odeint(method='exponential_euler', show_code=True, f=func)
61-
62-
def test6(self):
63-
with pytest.raises(errors.CodeError):
64-
def func(math, t, exp):
65-
alpha = 0.1 * (exp + 40) / (1 - np.exp(-(exp + 40) / 10))
66-
beta = 4.0 * np.exp(-(exp + 65) / 18)
67-
dmdt = alpha * (1 - math) - beta * math
68-
return dmdt
69-
70-
odeint(method='exponential_euler', show_code=True, f=func)
Lines changed: 10 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# -*- coding: utf-8 -*-
22

33
import unittest
4-
import pytest
4+
5+
import matplotlib.pyplot as plt
56

67
import brainpy as bp
78
import brainpy.math as bm
89
from brainpy.integrators.ode.exponential import ExponentialEuler
910

10-
plt = None
11+
block = False
1112

1213

1314
class TestExpnentialEuler(unittest.TestCase):
@@ -32,205 +33,19 @@ def drivative(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C):
3233

3334
return dVdt, dmdt, dhdt, dndt
3435

35-
ExponentialEuler(f=drivative, show_code=True, dt=0.01, var_type='SCALAR')
36-
37-
def test_return_expr(self):
38-
def derivative(s, t, tau):
39-
return -s / tau
40-
41-
with pytest.raises(bp.errors.DiffEqError):
42-
ExponentialEuler(f=derivative, show_code=True, dt=0.01, var_type='SCALAR', )
43-
44-
def test_return_expr2(self):
45-
def derivative(s, v, t, tau):
46-
dv = -v + 1
47-
return -s / tau, dv
48-
49-
with pytest.raises(bp.errors.DiffEqError):
50-
ExponentialEuler(f=derivative, show_code=True, dt=0.01, var_type='SCALAR', )
51-
52-
def test_return_expr3(self):
53-
f = lambda s, t, tau: -s / tau
54-
with pytest.raises(bp.errors.AnalyzerError) as excinfo:
55-
ExponentialEuler(f=f, show_code=True, dt=0.01, var_type='SCALAR', )
56-
57-
def test_nonlinear_eq1_vdp(self):
58-
def vdp_derivative(x, y, t, mu):
59-
dx = mu * (x - x ** 3 / 3 - y)
60-
dy = x / mu
61-
return dx, dy
62-
63-
ExponentialEuler(f=vdp_derivative, show_code=True, dt=0.01)
64-
65-
def test_nonlinear_eq2_reduced_trn(self):
66-
T = 36.
67-
phi_m = phi_h = phi_n = 3 ** ((T - 36) / 10)
68-
# parameters of IT
69-
E_T = 120.
70-
phi_p = 5 ** ((T - 24) / 10)
71-
phi_q = 3 ** ((T - 24) / 10)
72-
p_half, p_k = -52., 7.4
73-
q_half, q_k = -80., -5.
74-
g_Na = 100.
75-
E_Na = 50.
76-
g_K = 10.
77-
# parameters of V
78-
C, Vth, area = 1., 20., 1.43e-4
79-
V_factor = 1e-3 / area
80-
81-
def reduced_trn_derivative(V, y, z, t, Isyn, b, rho_p, g_T, g_L, g_KL, E_L, E_KL, IT_th, NaK_th):
82-
# m channel
83-
t1 = 13. - V + NaK_th
84-
t1_exp = bm.exp(t1 / 4.)
85-
m_alpha_by_V = 0.32 * t1 / (t1_exp - 1.) # \alpha_m(V)
86-
m_alpha_by_V_diff = (-0.32 * (t1_exp - 1.) + 0.08 * t1 * t1_exp) / (t1_exp - 1.) ** 2 # \alpha_m'(V)
87-
t2 = V - 40. - NaK_th
88-
t2_exp = bm.exp(t2 / 5.)
89-
m_beta_by_V = 0.28 * t2 / (t2_exp - 1.) # \beta_m(V)
90-
m_beta_by_V_diff = (0.28 * (t2_exp - 1) - 0.056 * t2 * t2_exp) / (t2_exp - 1) ** 2 # \beta_m'(V)
91-
m_tau_by_V = 1. / phi_m / (m_alpha_by_V + m_beta_by_V) # \tau_m(V)
92-
m_inf_by_V = m_alpha_by_V / (m_alpha_by_V + m_beta_by_V) # \m_{\infty}(V)
93-
m_inf_by_V_diff = (m_alpha_by_V_diff * m_beta_by_V - m_alpha_by_V * m_beta_by_V_diff) / \
94-
(m_alpha_by_V + m_beta_by_V) ** 2 # \m_{\infty}'(V)
95-
96-
# h channel
97-
h_alpha_by_V = 0.128 * bm.exp((17. - V + NaK_th) / 18.) # \alpha_h(V)
98-
h_beta_by_V = 4. / (bm.exp((40. - V + NaK_th) / 5.) + 1.) # \beta_h(V)
99-
h_inf_by_V = h_alpha_by_V / (h_alpha_by_V + h_beta_by_V) # h_{\infty}(V)
100-
h_tau_by_V = 1. / phi_h / (h_alpha_by_V + h_beta_by_V) # \tau_h(V)
101-
h_alpha_by_y = 0.128 * bm.exp((17. - y + NaK_th) / 18.) # \alpha_h(y)
102-
t3 = bm.exp((40. - y + NaK_th) / 5.)
103-
h_beta_by_y = 4. / (t3 + 1.) # \beta_h(y)
104-
h_beta_by_y_diff = 0.8 * t3 / (1 + t3) ** 2 # \beta_h'(y)
105-
h_inf_by_y = h_alpha_by_y / (h_alpha_by_y + h_beta_by_y) # h_{\infty}(y)
106-
h_alpha_by_y_diff = - h_alpha_by_y / 18. # \alpha_h'(y)
107-
h_inf_by_y_diff = (h_alpha_by_y_diff * h_beta_by_y - h_alpha_by_y * h_beta_by_y_diff) / \
108-
(h_beta_by_y + h_alpha_by_y) ** 2 # h_{\infty}'(y)
109-
110-
# n channel
111-
t4 = (15. - V + NaK_th)
112-
n_alpha_by_V = 0.032 * t4 / (bm.exp(t4 / 5.) - 1.) # \alpha_n(V)
113-
n_beta_by_V = b * bm.exp((10. - V + NaK_th) / 40.) # \beta_n(V)
114-
n_tau_by_V = 1. / (n_alpha_by_V + n_beta_by_V) / phi_n # \tau_n(V)
115-
n_inf_by_V = n_alpha_by_V / (n_alpha_by_V + n_beta_by_V) # n_{\infty}(V)
116-
t5 = (15. - y + NaK_th)
117-
t5_exp = bm.exp(t5 / 5.)
118-
n_alpha_by_y = 0.032 * t5 / (t5_exp - 1.) # \alpha_n(y)
119-
t6 = bm.exp((10. - y + NaK_th) / 40.)
120-
n_beta_y = b * t6 # \beta_n(y)
121-
n_inf_by_y = n_alpha_by_y / (n_alpha_by_y + n_beta_y) # n_{\infty}(y)
122-
n_alpha_by_y_diff = (0.0064 * t5 * t5_exp - 0.032 * (t5_exp - 1.)) / (t5_exp - 1.) ** 2 # \alpha_n'(y)
123-
n_beta_by_y_diff = -n_beta_y / 40 # \beta_n'(y)
124-
n_inf_by_y_diff = (n_alpha_by_y_diff * n_beta_y - n_alpha_by_y * n_beta_by_y_diff) / \
125-
(n_alpha_by_y + n_beta_y) ** 2 # n_{\infty}'(y)
126-
127-
# p channel
128-
p_inf_by_V = 1. / (1. + bm.exp((p_half - V + IT_th) / p_k)) # p_{\infty}(V)
129-
p_tau_by_V = (3 + 1. / (bm.exp((V + 27. - IT_th) / 10.) +
130-
bm.exp(-(V + 102. - IT_th) / 15.))) / phi_p # \tau_p(V)
131-
t7 = bm.exp((p_half - y + IT_th) / p_k)
132-
p_inf_by_y = 1. / (1. + t7) # p_{\infty}(y)
133-
p_inf_by_y_diff = t7 / p_k / (1. + t7) ** 2 # p_{\infty}'(y)
134-
135-
# p channel
136-
q_inf_by_V = 1. / (1. + bm.exp((q_half - V + IT_th) / q_k)) # q_{\infty}(V)
137-
t8 = bm.exp((q_half - z + IT_th) / q_k)
138-
q_inf_by_z = 1. / (1. + t8) # q_{\infty}(z)
139-
q_inf_diff_z = t8 / q_k / (1. + t8) ** 2 # q_{\infty}'(z)
140-
q_tau_by_V = (85. + 1 / (bm.exp((V + 48. - IT_th) / 4.) +
141-
bm.exp(-(V + 407. - IT_th) / 50.))) / phi_q # \tau_q(V)
142-
143-
# ----
144-
# x
145-
# ----
146-
147-
gNa = g_Na * m_inf_by_V ** 3 * h_inf_by_y # gNa
148-
gK = g_K * n_inf_by_y ** 4 # gK
149-
gT = g_T * p_inf_by_y * p_inf_by_y * q_inf_by_z # gT
150-
FV = gNa + gK + gT + g_L + g_KL # dF/dV
151-
Fm = 3 * g_Na * h_inf_by_y * (V - E_Na) * m_inf_by_V * m_inf_by_V * m_inf_by_V_diff # dF/dvm
152-
t9 = C / m_tau_by_V
153-
t10 = FV + Fm
154-
t11 = t9 + FV
155-
rho_V = (t11 - bm.sqrt(bm.maximum(t11 ** 2 - 4 * t9 * t10, 0.))) / 2 / t10 # rho_V
156-
INa = gNa * (V - E_Na)
157-
IK = gK * (V - E_KL)
158-
IT = gT * (V - E_T)
159-
IL = g_L * (V - E_L)
160-
IKL = g_KL * (V - E_KL)
161-
Iext = V_factor * Isyn
162-
dVdt = rho_V * (-INa - IK - IT - IL - IKL + Iext) / C
163-
164-
# ----
165-
# y
166-
# ----
167-
168-
Fvh = g_Na * m_inf_by_V ** 3 * (V - E_Na) * h_inf_by_y_diff # dF/dvh
169-
Fvn = 4 * g_K * (V - E_KL) * n_inf_by_y ** 3 * n_inf_by_y_diff # dF/dvn
170-
f4 = Fvh + Fvn
171-
rho_h = (1 - rho_p) * Fvh / f4
172-
rho_n = (1 - rho_p) * Fvn / f4
173-
fh = (h_inf_by_V - h_inf_by_y) / h_tau_by_V / h_inf_by_y_diff
174-
fn = (n_inf_by_V - n_inf_by_y) / n_tau_by_V / n_inf_by_y_diff
175-
fp = (p_inf_by_V - p_inf_by_y) / p_tau_by_V / p_inf_by_y_diff
176-
dydt = rho_h * fh + rho_n * fn + rho_p * fp
177-
178-
# ----
179-
# z
180-
# ----
181-
182-
dzdt = (q_inf_by_V - q_inf_by_z) / q_tau_by_V / q_inf_diff_z
183-
184-
return dVdt, dydt, dzdt
185-
186-
with pytest.raises(bp.errors.DiffEqError):
187-
ExponentialEuler(f=reduced_trn_derivative, show_code=True, dt=0.01, timeout=5)
188-
189-
def test_nonlinear_eq3_adaptive_quadratic_if(self):
190-
def derivative(V, w, t, Iext, self):
191-
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + Iext) / self.tau
192-
dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w
193-
return dVdt, dwdt
194-
195-
ExponentialEuler(f=derivative, show_code=True, dt=0.01, timeout=5)
196-
197-
def test_nonlinear_eq4_exponentil_if(self):
198-
def derivative(V, t, Iext, self):
199-
exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
200-
dvdt = (- (V - self.V_rest) + exp_v + self.R * Iext) / self.tau
201-
return dvdt
202-
203-
ExponentialEuler(f=derivative, show_code=True, dt=0.01, timeout=5)
204-
205-
def test_nonlinear_eq5_morris_lecar(self):
206-
def derivative(V, W, t, I_ext, self):
207-
M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2))
208-
I_Ca = self.g_Ca * M_inf * (V - self.V_Ca)
209-
I_K = self.g_K * W * (V - self.V_K)
210-
I_Leak = self.g_leak * (V - self.V_leak)
211-
dVdt = (- I_Ca - I_K - I_Leak + I_ext) / self.C
212-
213-
tau_W = 1 / (self.phi * bm.cosh((V - self.V3) / (2 * self.V4)))
214-
W_inf = (1 / 2) * (1 + bm.tanh((V - self.V3) / self.V4))
215-
dWdt = (W_inf - W) / tau_W
216-
return dVdt, dWdt
217-
218-
ExponentialEuler(f=derivative, show_code=True, dt=0.01, timeout=5)
36+
with self.assertRaises(bp.errors.DiffEqError):
37+
ExponentialEuler(f=drivative, show_code=True, dt=0.01, var_type='SCALAR')
21938

22039
def test1(self):
22140
def dev(x, t):
22241
dx = bm.power(x, 3)
22342
return dx
22443

225-
ExponentialEuler(f=dev, show_code=True, dt=0.01, timeout=5)
44+
ExponentialEuler(f=dev, show_code=True, dt=0.01)
22645

22746

22847
class TestExpEulerAuto(unittest.TestCase):
22948
def test_hh_model(self):
230-
global plt
231-
if plt is None:
232-
import matplotlib.pyplot as plt
233-
23449
class HH(bp.dyn.NeuGroup):
23550
def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9.,
23651
gL=0.1, V_th=20., phi=5.0, name=None, method='exponential_euler'):
@@ -298,7 +113,7 @@ def update(self, t, dt):
298113
plt.plot(runner1.mon.ts, runner1.mon.V, label='V')
299114
plt.plot(runner1.mon.ts, runner1.mon.h, label='h')
300115
plt.plot(runner1.mon.ts, runner1.mon.n, label='n')
301-
# plt.show()
116+
plt.show(block=block)
302117

303118
hh2 = HH(1, method='exp_euler_auto')
304119
runner2 = bp.dyn.DSRunner(hh2, inputs=('input', 2.), monitors=['V', 'h', 'n'])
@@ -307,8 +122,10 @@ def update(self, t, dt):
307122
plt.plot(runner2.mon.ts, runner2.mon.V, label='V')
308123
plt.plot(runner2.mon.ts, runner2.mon.h, label='h')
309124
plt.plot(runner2.mon.ts, runner2.mon.n, label='n')
310-
plt.show()
125+
plt.show(block=block)
311126

312127
diff = (runner2.mon.V - runner1.mon.V).mean()
313128
self.assertTrue(diff < 1e0)
314129

130+
plt.close()
131+

brainpy/math/operators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ def _matmul_with_right_sparse(
738738
if dense.ndim == 2:
739739
A = dense[:, rows]
740740
prod = (A * values).T
741-
res = jops.segment_sum(prod, cols, shape).T
741+
res = jops.segment_sum(prod, cols, shape[1]).T
742742
else:
743743
prod = dense[rows] * values
744744
res = jops.segment_sum(prod, cols, shape[1])

brainpy/math/tests/test_oprators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def test_right_sparse_matmul2(self):
133133
sparse_B = {'data': values, 'index': (rows, cols), 'shape': (3, 4)}
134134
A = jnp.arange(3)
135135

136-
print(bm.sparse_matmul(A, [values, (rows, cols)], 4))
136+
print(bm.sparse_matmul(A, sparse_B))
137137
print(jnp.dot(A, B))
138138

139139
self.assertTrue(bm.array_equal(bm.sparse_matmul(A, sparse_B),

0 commit comments

Comments
 (0)