Skip to content

Commit cce047c

Browse files
committed
update examples
1 parent 329b6e7 commit cce047c

File tree

12 files changed

+83
-102
lines changed

12 files changed

+83
-102
lines changed

examples/dynamics_analysis/2d_fitzhugh_nagumo_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ def dw(w, t, V, a=0.7, b=0.8):
3333
self.int_V = bp.odeint(dV, method=method)
3434
self.int_w = bp.odeint(dw, method=method)
3535

36-
def update(self, tdi):
37-
t, dt = tdi['t'], tdi['dt']
36+
def update(self):
37+
t = bp.share['t']
38+
dt = bp.share['dt']
3839
self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt)
3940
self.w.value = self.int_w(self.w, t, self.V, self.a, self.b, dt)
4041
self.Iext[:] = 0.

examples/dynamics_analysis/2d_mean_field_QIF.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class MeanFieldQIF(bp.DynamicalSystem):
1414
"""
1515

1616
def __init__(self, method='exp_auto'):
17-
super(MeanFieldQIF, self).__init__()
17+
super().__init__()
1818

1919
# parameters
2020
self.tau = 1. # the population time constant
@@ -38,8 +38,9 @@ def dv(v, t, r, Iext=0., eta=-5.0):
3838
self.int_r = bp.odeint(dr, method=method)
3939
self.int_v = bp.odeint(dv, method=method)
4040

41-
def update(self, tdi):
42-
t, dt = tdi['t'], tdi['dt']
41+
def update(self):
42+
t = bp.share['t']
43+
dt = bp.share['dt']
4344
self.r.value = self.int_r(self.r, t, self.v, self.delta, dt)
4445
self.v.value = self.int_v(self.v, t, self.r, self.Iext, self.eta, dt)
4546
self.Iext[:] = 0.

examples/dynamics_analysis/3d_reduced_trn_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
bp.math.set_platform('cpu')
88

99

10-
class ReducedTRNModel(bp.NeuDyn):
10+
class ReducedTRNModel(bp.dyn.NeuDyn):
1111
def __init__(self, size, name=None, T=36., method='rk4'):
12-
super(ReducedTRNModel, self).__init__(size=size, name=name)
12+
super().__init__(size=size, name=name)
1313

1414
self.IT_th = -3.
1515
self.b = 0.5

examples/dynamics_analysis/highdim_RNN_Analysis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(
2626
w_rr=bp.init.KaimingNormal(scale=1.),
2727
w_ro=bp.init.KaimingNormal(scale=1.)
2828
):
29-
super(RNNNet, self).__init__()
29+
super().__init__()
3030

3131
self.tau = 100
3232
self.num_input = num_input
@@ -64,7 +64,7 @@ def cell(self, x, h):
6464
def readout(self, h):
6565
return h @ self.w_ro + self.b_ro
6666

67-
def update(self, sha, x):
67+
def update(self, x):
6868
self.h.value = self.cell(x, self.h.value)
6969
return self.readout(self.h.value)
7070

examples/dynamics_simulation/COBA-v2.py renamed to examples/dynamics_simulation/COBA.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,6 @@ def __init__(self, scale=1.0, method='exp_auto'):
140140
# bm.set_host_device_count(num_device)
141141
# bm.sharding.set(mesh_axes=(bp.dyn.PNEU_AXIS,), mesh_shape=(num_device, ))
142142

143-
def run3():
144-
net = EICOBA_PreAlign(3200, 800)
145-
runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
146-
print(runner.run(100., eval_time=True))
147-
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
148-
149143

150144
def run1():
151145
with bm.environment(mode=bm.BatchingMode(10)):
@@ -167,7 +161,23 @@ def run2():
167161
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
168162

169163

164+
def run3():
165+
net = EICOBA_PreAlign(3200, 800)
166+
runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
167+
print(runner.run(100., eval_time=True))
168+
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
169+
170+
171+
172+
def run4():
173+
net = EICOBA_PostAlign(3200, 800)
174+
runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
175+
print(runner.run(100., eval_time=True))
176+
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
177+
178+
170179
if __name__ == '__main__':
171-
# run1()
180+
run1()
172181
run2()
173-
# run3()
182+
run3()
183+
run4()

examples/dynamics_simulation/hh_model.py

Lines changed: 10 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,27 @@
1111

1212
class HH(bp.dyn.CondNeuGroup):
1313
def __init__(self, size):
14-
super().__init__(size, keep_size=True)
14+
super().__init__(size)
1515

16-
self.INa = bp.channels.INa_HH1952(size, keep_size=True)
17-
self.IK = bp.channels.IK_HH1952(size, keep_size=True)
18-
self.IL = bp.channels.IL(size, E=-54.387, g_max=0.03, keep_size=True)
16+
self.INa = bp.channels.INa_HH1952(size)
17+
self.IK = bp.channels.IK_HH1952(size)
18+
self.IL = bp.channels.IL(size, E=-54.387, g_max=0.03)
1919

2020

2121
class HHv2(bp.dyn.CondNeuGroupLTC):
2222
def __init__(self, size):
23-
super().__init__(size, keep_size=True)
23+
super().__init__(size)
2424

2525
self.Na = bp.dyn.SodiumFixed(size, E=50.)
26-
self.Na.add(ina=bp.dyn.INa_HH1952v2(size, keep_size=True))
26+
self.Na.add_elem(ina=bp.dyn.INa_HH1952v2(size))
2727

2828
self.K = bp.dyn.PotassiumFixed(size, E=50.)
29-
self.K.add(ik=bp.dyn.IK_HH1952v2(size, keep_size=True))
30-
31-
self.IL = bp.dyn.IL(size, E=-54.387, g_max=0.03, keep_size=True)
32-
33-
self.KNa = bp.dyn.mixs(self.Na, self.K)
34-
self.KNa.add()
35-
36-
37-
29+
self.K.add_elem(ik=bp.dyn.IK_HH1952v2(size))
3830

31+
self.IL = bp.dyn.IL(size, E=-54.387, g_max=0.03)
3932

33+
self.KNa = bp.dyn.MixIons(self.Na, self.K)
34+
self.KNa.add_elem()
4035

4136

4237
# hh = HH(1)
@@ -52,26 +47,3 @@ def __init__(self, size):
5247
#
5348
# bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
5449

55-
56-
hh = HH((20, 10000))
57-
variables = hh.vars().unique()
58-
59-
60-
iis = np.arange(1000000000)
61-
62-
def f(i):
63-
bp.share.save(i=i, t=i * bm.get_dt(), dt=bm.get_dt())
64-
hh(5.)
65-
66-
67-
@pmap
68-
def run(vars):
69-
for v, d in vars.items():
70-
variables[v]._value = d
71-
bm.for_loop(f, bm.arange(1000000000))
72-
print('Compiling End')
73-
return hh.spike
74-
75-
76-
r = run(variables.dict())
77-
print(r.shape)

examples/dynamics_simulation/multi_scale_COBAHH.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,9 @@
77

88
import brainpy as bp
99
import brainpy.math as bm
10-
from brainpy.channels import INa_TM1991, IL
11-
from brainpy.synapses import Exponential, COBA
1210
from brainpy.connect import FixedProb
1311
from jax import vmap
1412

15-
comp_method = 'sparse'
1613

1714

1815
area_names = ['V1', 'V2', 'V4', 'TEO', 'TEpd']
@@ -47,8 +44,8 @@ class HH(bp.CondNeuGroup):
4744
def __init__(self, size):
4845
super(HH, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.))
4946
self.IK = IK(size, g_max=30., V_sh=-63.)
50-
self.INa = INa_TM1991(size, g_max=100., V_sh=-63.)
51-
self.IL = IL(size, E=-60., g_max=0.05)
47+
self.INa = bp.dyn.INa_TM1991(size, g_max=100., V_sh=-63.)
48+
self.IL = bp.dyn.IL(size, E=-60., g_max=0.05)
5249

5350

5451
class Network(bp.Network):

examples/dynamics_simulation/whole_brain_simulation_with_fhn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def bifurcation_analysis():
2121
pp.show_figure()
2222

2323

24-
class Network(bp.Network):
24+
class Network(bp.DynSysGroup):
2525
def __init__(self, signal_speed=20.):
2626
super(Network, self).__init__()
2727

@@ -36,12 +36,12 @@ def __init__(self, signal_speed=20.):
3636
delay_mat = bm.asarray(delay_mat)
3737
bm.fill_diagonal(delay_mat, 0)
3838

39-
self.fhn = bp.rates.FHN(
39+
self.fhn = bp.dyn.FHN(
4040
80,
4141
x_ou_sigma=0.01,
4242
y_ou_sigma=0.01,
4343
)
44-
self.coupling = bp.synapses.DiffusiveCoupling(
44+
self.coupling = bp.dyn.DiffusiveCoupling(
4545
self.fhn.x,
4646
self.fhn.x,
4747
var_to_output=self.fhn.input,
@@ -95,5 +95,5 @@ def net_analysis():
9595

9696
if __name__ == '__main__':
9797
# bifurcation_analysis()
98-
# net_simulation()
99-
net_analysis()
98+
net_simulation()
99+
# net_analysis()

examples/dynamics_simulation/whole_brain_simulation_with_sl_oscillator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
def bifurcation_analysis():
13-
model = bp.rates.StuartLandauOscillator(1, method='exp_auto')
13+
model = bp.dyn.StuartLandauOscillator(1, method='exp_auto')
1414
pp = bp.analysis.Bifurcation2D(
1515
model,
1616
target_vars={'x': [-2, 2], 'y': [-2, 2]},
@@ -22,7 +22,7 @@ def bifurcation_analysis():
2222
pp.show_figure()
2323

2424

25-
class Network(bp.Network):
25+
class Network(bp.DynSysGroup):
2626
def __init__(self, noise=0.14):
2727
super(Network, self).__init__()
2828

@@ -35,8 +35,8 @@ def __init__(self, noise=0.14):
3535
bm.fill_diagonal(conn_mat, 0)
3636
gc = 0.6 # global coupling strength
3737

38-
self.sl = bp.rates.StuartLandauOscillator(80, x_ou_sigma=noise, y_ou_sigma=noise)
39-
self.coupling = bp.synapses.DiffusiveCoupling(
38+
self.sl = bp.dyn.StuartLandauOscillator(80, x_ou_sigma=noise, y_ou_sigma=noise)
39+
self.coupling = bp.dyn.DiffusiveCoupling(
4040
self.sl.x, self.sl.x,
4141
var_to_output=self.sl.input,
4242
conn_mat=conn_mat * gc
@@ -87,6 +87,6 @@ def net_analysis():
8787

8888

8989
if __name__ == '__main__':
90-
bifurcation_analysis()
90+
# bifurcation_analysis()
9191
simulation()
92-
net_analysis()
92+
# net_analysis()

examples/dynamics_training/Song_2016_EI_RNN.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(
2727
w_rr=bp.init.KaimingUniform(scale=1.),
2828
w_ro=bp.init.KaimingUniform(scale=1.)
2929
):
30-
super(EI_RNN, self).__init__()
30+
super().__init__()
3131

3232
# parameters
3333
self.tau = 100

0 commit comments

Comments
 (0)