Skip to content

Commit 1e9c8c2

Browse files
authored
fix bifurcation analysis bug and #287 (#289)
fix bifurcation analysis bug and #287
2 parents 147d3e8 + 6ff313f commit 1e9c8c2

File tree

9 files changed

+145
-29
lines changed

9 files changed

+145
-29
lines changed

brainpy/__init__.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "2.2.3.5"
4-
3+
__version__ = "2.2.3.6"
54

65
try:
76
import jaxlib
7+
88
del jaxlib
99
except ModuleNotFoundError:
1010
raise ModuleNotFoundError(
@@ -34,21 +34,17 @@
3434
3535
''') from None
3636

37-
3837
# fundamental modules
3938
from . import errors, tools, check, modes
4039

41-
4240
# "base" module
4341
from . import base
4442
from .base.base import Base
4543
from .base.collector import Collector, TensorCollector
4644

47-
4845
# math foundation
4946
from . import math
5047

51-
5248
# toolboxes
5349
from . import (
5450
connect, # synaptic connection
@@ -61,7 +57,6 @@
6157
algorithms, # online or offline training algorithms
6258
)
6359

64-
6560
# numerical integrators
6661
from . import integrators
6762
from .integrators import ode
@@ -72,7 +67,6 @@
7267
from .integrators.fde import fdeint
7368
from .integrators.joint_eq import JointEq
7469

75-
7670
# dynamics simulation
7771
from . import dyn
7872
from .dyn import (
@@ -82,10 +76,10 @@
8276
neurons, # neuron groups
8377
rates, # rate models
8478
synapses, # synaptic dynamics
85-
synouts, # synaptic output
79+
synouts, # synaptic output
8680
synplast, # synaptic plasticity
8781
)
88-
from brainpy.dyn.base import (
82+
from .dyn.base import (
8983
DynamicalSystem,
9084
Container,
9185
Sequential,
@@ -101,23 +95,33 @@
10195
)
10296
from .dyn.runners import *
10397

104-
10598
# dynamics training
10699
from . import train
107-
100+
from .train import (
101+
DSTrainer,
102+
OnlineTrainer, ForceTrainer,
103+
OfflineTrainer, RidgeTrainer,
104+
BPFF,
105+
BPTT,
106+
OnlineBPTT,
107+
)
108108

109109
# automatic dynamics analysis
110110
from . import analysis
111-
111+
from .analysis import (
112+
DSAnalyzer,
113+
PhasePlane1D, PhasePlane2D,
114+
Bifurcation1D, Bifurcation2D,
115+
FastSlow1D, FastSlow2D,
116+
SlowPointFinder,
117+
)
112118

113119
# running
114120
from . import running
115121

116-
117122
# "visualization" module, will be removed soon
118123
from .visualization import visualize
119124

120-
121125
# convenient access
122126
conn = connect
123127
init = initialize

brainpy/analysis/lowdim/lowdim_analyzer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,12 @@ def __init__(
137137
target_pars = dict()
138138
if not isinstance(target_pars, dict):
139139
raise errors.AnalyzerError('"target_pars" must be a dict with the format of {"par1": (val1, val2)}.')
140-
for key in target_pars.keys():
140+
for key, value in target_pars.items():
141141
if key not in self.model.parameters:
142142
raise errors.AnalyzerError(f'"{key}" is not a valid parameter in "{self.model}" model.')
143-
value = self.target_vars[key]
144143
if value[0] > value[1]:
145-
raise errors.AnalyzerError(f'The range of parameter {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')
144+
raise errors.AnalyzerError(
145+
f'The range of parameter {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')
146146

147147
self.target_pars = Collector(target_pars)
148148
self.target_par_names = list(self.target_pars.keys()) # list of target_pars

brainpy/analysis/lowdim/lowdim_bifurcation.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import jax.numpy as jnp
66
from jax import vmap
77
import numpy as np
8+
from copy import deepcopy
89

910
import brainpy.math as bm
1011
from brainpy import errors
@@ -79,7 +80,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
7980
pyplot.figure(self.x_var)
8081
for fp_type, points in container.items():
8182
if len(points['x']):
82-
plot_style = plotstyle.plot_schema[fp_type]
83+
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
8384
pyplot.plot(points['p'], points['x'], **plot_style, label=fp_type)
8485
pyplot.xlabel(self.target_par_names[0])
8586
pyplot.ylabel(self.x_var)
@@ -107,11 +108,12 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
107108
ax = fig.add_subplot(projection='3d')
108109
for fp_type, points in container.items():
109110
if len(points['x']):
110-
plot_style = plotstyle.plot_schema[fp_type]
111+
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
111112
xs = points['p0']
112113
ys = points['p1']
113114
zs = points['x']
114115
plot_style.pop('linestyle')
116+
plot_style['s'] = plot_style.pop('markersize', None)
115117
ax.scatter(xs, ys, zs, **plot_style, label=fp_type)
116118

117119
ax.set_xlabel(self.target_par_names[0])
@@ -299,7 +301,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
299301
pyplot.figure(var)
300302
for fp_type, points in container.items():
301303
if len(points['p']):
302-
plot_style = plotstyle.plot_schema[fp_type]
304+
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
303305
pyplot.plot(points['p'], points[var], **plot_style, label=fp_type)
304306
pyplot.xlabel(self.target_par_names[0])
305307
pyplot.ylabel(var)
@@ -331,11 +333,12 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
331333
ax = fig.add_subplot(projection='3d')
332334
for fp_type, points in container.items():
333335
if len(points['p0']):
334-
plot_style = plotstyle.plot_schema[fp_type]
336+
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
335337
xs = points['p0']
336338
ys = points['p1']
337339
zs = points[var]
338340
plot_style.pop('linestyle')
341+
plot_style['s'] = plot_style.pop('markersize', None)
339342
ax.scatter(xs, ys, zs, **plot_style, label=fp_type)
340343

341344
ax.set_xlabel(self.target_par_names[0])

brainpy/analysis/lowdim/lowdim_phase_plane.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
from jax import vmap
66

7+
from copy import deepcopy
78
import brainpy.math as bm
89
from brainpy import errors, math
910
from brainpy.analysis import stability, plotstyle, constants as C, utils
@@ -107,7 +108,7 @@ def plot_fixed_point(self, show=False, with_plot=True, with_return=False):
107108
if with_plot:
108109
for fp_type, points in container.items():
109110
if len(points):
110-
plot_style = plotstyle.plot_schema[fp_type]
111+
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
111112
pyplot.plot(points, [0] * len(points), **plot_style, label=fp_type)
112113
pyplot.legend()
113114
if show:
@@ -349,7 +350,7 @@ def plot_fixed_point(self, with_plot=True, with_return=False, show=False,
349350
if with_plot:
350351
for fp_type, points in container.items():
351352
if len(points['x']):
352-
plot_style = plotstyle.plot_schema[fp_type]
353+
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
353354
pyplot.plot(points['x'], points['y'], **plot_style, label=fp_type)
354355
pyplot.legend()
355356
if show:
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# -*- coding: utf-8 -*-
2+
3+
4+
import pytest
5+
pytest.skip('Test cannot pass in github action.', allow_module_level=True)
6+
import unittest
7+
8+
import brainpy as bp
9+
import brainpy.math as bm
10+
import matplotlib.pyplot as plt
11+
12+
block = False
13+
14+
15+
class FitzHughNagumoModel(bp.dyn.DynamicalSystem):
16+
def __init__(self, method='exp_auto'):
17+
super(FitzHughNagumoModel, self).__init__()
18+
19+
# parameters
20+
self.a = 0.7
21+
self.b = 0.8
22+
self.tau = 12.5
23+
24+
# variables
25+
self.V = bm.Variable(bm.zeros(1))
26+
self.w = bm.Variable(bm.zeros(1))
27+
self.Iext = bm.Variable(bm.zeros(1))
28+
29+
# functions
30+
def dV(V, t, w, Iext=0.):
31+
dV = V - V * V * V / 3 - w + Iext
32+
return dV
33+
34+
def dw(w, t, V, a=0.7, b=0.8):
35+
dw = (V + a - b * w) / self.tau
36+
return dw
37+
38+
self.int_V = bp.odeint(dV, method=method)
39+
self.int_w = bp.odeint(dw, method=method)
40+
41+
def update(self, tdi):
42+
t, dt = tdi['t'], tdi['dt']
43+
self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt)
44+
self.w.value = self.int_w(self.w, t, self.V, self.a, self.b, dt)
45+
self.Iext[:] = 0.
46+
47+
48+
class TestBifurcation1D(unittest.TestCase):
49+
def test_bifurcation_1d(self):
50+
bp.math.enable_x64()
51+
52+
@bp.odeint
53+
def int_x(x, t, a=1., b=1.):
54+
return bp.math.sin(a * x) + bp.math.cos(b * x)
55+
56+
pp = bp.analysis.PhasePlane1D(
57+
model=int_x,
58+
target_vars={'x': [-bp.math.pi, bp.math.pi]},
59+
resolutions=0.1
60+
)
61+
pp.plot_vector_field()
62+
pp.plot_fixed_point(show=True)
63+
64+
bf = bp.analysis.Bifurcation1D(
65+
model=int_x,
66+
target_vars={'x': [-bp.math.pi, bp.math.pi]},
67+
target_pars={'a': [0.5, 1.5], 'b': [0.5, 1.5]},
68+
resolutions={'a': 0.1, 'b': 0.1}
69+
)
70+
bf.plot_bifurcation(show=False)
71+
plt.show(block=block)
72+
plt.close()
73+
bp.math.disable_x64()
74+
75+
def test_bifurcation_2d(self):
76+
bp.math.enable_x64()
77+
78+
model = FitzHughNagumoModel()
79+
bif = bp.analysis.Bifurcation2D(
80+
model=model,
81+
target_vars={'V': [-3., 3.], 'w': [-1, 3.]},
82+
target_pars={'Iext': [0., 1.]},
83+
resolutions={'Iext': 0.1}
84+
)
85+
bif.plot_bifurcation()
86+
bif.plot_limit_cycle_by_sim()
87+
plt.show(block=block)
88+
89+
# bp.math.disable_x64()

brainpy/analysis/lowdim/tests/test_phase_plane.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import unittest
44

55
import brainpy as bp
6+
import matplotlib.pyplot as plt
67

78
block = False
89

910

1011
class TestPhasePlane(unittest.TestCase):
1112
def test_1d(self):
12-
import matplotlib.pyplot as plt
1313
bp.math.enable_x64()
1414

1515
@bp.odeint
@@ -30,8 +30,6 @@ def int_x(x, t, Iext):
3030
bp.math.disable_x64()
3131

3232
def test_2d_decision_making_model(self):
33-
import matplotlib.pyplot as plt
34-
3533
bp.math.enable_x64()
3634
gamma = 0.641 # Saturation factor for gating variable
3735
tau = 0.06 # Synaptic time constant [sec]

brainpy/analysis/plotstyle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
UNSTABLE_FOCUS_3D, UNSTABLE_CENTER_3D, UNKNOWN_3D)
1717

1818

19-
_markersize = 20
19+
_markersize = 10
2020

2121
plot_schema = {}
2222

brainpy/dyn/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ def register_delay(
182182
elif delay.num_delay_step - 1 < max_delay_step:
183183
self.global_delay_data[identifier][0].reset(delay_target, max_delay_step, initial_delay_data)
184184
else:
185-
self.global_delay_data[identifier] = (None, delay_target)
185+
if identifier not in self.global_delay_data:
186+
self.global_delay_data[identifier] = (None, delay_target)
186187
self.register_implicit_nodes(self.local_delay_vars)
187188
return delay_step
188189

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import unittest
4+
5+
import brainpy as bp
6+
7+
8+
class TestDynamicalSystem(unittest.TestCase):
9+
def test_delay(self):
10+
A = bp.neurons.LIF(1)
11+
B = bp.neurons.LIF(1)
12+
C = bp.neurons.LIF(1)
13+
A2B = bp.synapses.Exponential(A, B, bp.conn.All2All(), delay_step=1)
14+
A2C = bp.synapses.Exponential(A, C, bp.conn.All2All(), delay_step=None)
15+
net = bp.Network(A, B, C, A2B, A2C)
16+
17+
runner = bp.DSRunner(net,)
18+
runner.run(10.)
19+
20+

0 commit comments

Comments
 (0)