Commit 384e5c7
fix bugs on slow point analysis
1 parent b97b15b

9 files changed: +84 -55 lines

brainpy/analysis/highdim/slow_points.py (+4 -6)

@@ -355,8 +355,7 @@ def train(idx):
         return loss
 
       def batch_train(start_i, n_batch):
-        f = bm.make_loop(train, dyn_vars=dyn_vars, has_return=True)
-        return f(bm.arange(start_i, start_i + n_batch))
+        return bm.for_loop(train, dyn_vars, bm.arange(start_i, start_i + n_batch))
 
       # Run the optimization
       if self.verbose:
@@ -369,7 +368,7 @@ def batch_train(start_i, n_batch):
           break
         batch_idx_start = oidx * num_batch
         start_time = time.time()
-        (_, train_losses) = batch_train(start_i=batch_idx_start, n_batch=num_batch)
+        train_losses = batch_train(start_i=batch_idx_start, n_batch=num_batch)
         batch_time = time.time() - start_time
         opt_losses.append(train_losses)
 
@@ -722,8 +721,6 @@ def _generate_ds_cell_function(
     shared = DotDict(t=t, dt=dt, i=0)
 
     def f_cell(h: Dict):
-      target.clear_input()
-
       # update target variables
       for k, v in self.target_vars.items():
         v.value = (bm.asarray(h[k], dtype=v.dtype)
@@ -735,6 +732,7 @@ def f_cell(h: Dict):
         v.value = self.excluded_data[k]
 
       # add inputs
+      target.clear_input()
      if f_input is not None:
         f_input(shared)
 
@@ -743,7 +741,7 @@ def f_cell(h: Dict):
       target.update(*args)
 
       # get new states
-      new_h = {k: (v.value if v.batch_axis is None else jnp.squeeze(v.value, axis=v.batch_axis))
+      new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))
               for k, v in self.target_vars.items()}
       return new_h
 
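Note: the core change in this file replaces the two-step `bm.make_loop(...)` build-then-call pattern with a single `bm.for_loop(body, dyn_vars, operands)` call, which returns the stacked per-step results directly (hence the dropped `(_, train_losses)` unpacking). A minimal sketch of the new pattern, assuming only the positional signature used in the diff above:

import brainpy.math as bm

counter = bm.Variable(bm.zeros(1))  # loop-carried state the body mutates

def body(i):
  counter.value += i     # in-place update of a tracked Variable
  return counter.value   # per-step return, stacked along axis 0 by for_loop

# One call replaces make_loop(...)(xs); `out` has shape (5, 1).
out = bm.for_loop(body, [counter], bm.arange(5.))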

brainpy/analysis/lowdim/lowdim_bifurcation.py (+41 -10)

@@ -354,8 +354,17 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
     if with_return:
       return final_fps, final_pars, jacobians
 
-  def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
-                              plot_style=None, tol=0.001, show=False, dt=None, offset=1.):
+  def plot_limit_cycle_by_sim(
+      self,
+      duration=100,
+      with_plot: bool = True,
+      with_return: bool = False,
+      plot_style: dict = None,
+      tol: float = 0.001,
+      show: bool = False,
+      dt: float = None,
+      offset: float = 1.
+  ):
     global pyplot
     if pyplot is None: from matplotlib import pyplot
     utils.output('I am plotting the limit cycle ...')
@@ -400,10 +409,16 @@ def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
       if len(ps_limit_cycle[0]):
         for i, var in enumerate(self.target_var_names):
           pyplot.figure(var)
-          pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
-                      **plot_style, label='limit cycle (max)')
-          pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
-                      **plot_style, label='limit cycle (min)')
+          pyplot.plot(ps_limit_cycle[0],
+                      ps_limit_cycle[1],
+                      vs_limit_cycle[i]['max'],
+                      **plot_style,
+                      label='limit cycle (max)')
+          pyplot.plot(ps_limit_cycle[0],
+                      ps_limit_cycle[1],
+                      vs_limit_cycle[i]['min'],
+                      **plot_style,
+                      label='limit cycle (min)')
           pyplot.legend()
 
     elif len(self.target_par_names) == 1:
@@ -427,8 +442,16 @@ def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
 
 
 class FastSlow1D(Bifurcation1D):
-  def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
-               pars_update=None, resolutions=None, options=None):
+  def __init__(
+      self,
+      model,
+      fast_vars: dict,
+      slow_vars: dict,
+      fixed_vars: dict = None,
+      pars_update: dict = None,
+      resolutions=None,
+      options: dict = None
+  ):
     super(FastSlow1D, self).__init__(model=model,
                                      target_pars=slow_vars,
                                      target_vars=fast_vars,
@@ -510,8 +533,16 @@ def plot_trajectory(self, initials, duration, plot_durations=None,
 
 
 class FastSlow2D(Bifurcation2D):
-  def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
-               pars_update=None, resolutions=0.1, options=None):
+  def __init__(
+      self,
+      model,
+      fast_vars: dict,
+      slow_vars: dict,
+      fixed_vars: dict = None,
+      pars_update: dict = None,
+      resolutions=0.1,
+      options: dict = None
+  ):
     super(FastSlow2D, self).__init__(model=model,
                                      target_pars=slow_vars,
                                      target_vars=fast_vars,
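Note: these hunks only reflow the signatures and add type hints; the call pattern is unchanged. For reference, a usage sketch loosely following the Hindmarsh-Rose fast-slow example from the BrainPy docs (the equations and coefficients here are illustrative, not part of this commit):

import brainpy as bp

@bp.odeint
def int_V(V, t, y, z, Iext=0.5):
  return y - V ** 3 + 3 * V ** 2 - z + Iext  # fast membrane equation

@bp.odeint
def int_y(y, t, V):
  return 1. - 5. * V ** 2 - y                # fast recovery equation

@bp.odeint
def int_z(z, t, V, r=0.001):
  return r * (4. * (V + 1.6) - z)            # slow adaptation equation

# The slow variable is treated as a bifurcation parameter of the fast subsystem.
analyzer = bp.analysis.FastSlow2D(
  [int_V, int_y, int_z],
  fast_vars={'V': [-3., 3.], 'y': [-10., 5.]},
  slow_vars={'z': [-5., 5.]},
  resolutions={'z': 0.01},
)
analyzer.plot_bifurcation(show=True)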

brainpy/analysis/utils/model.py (+4 -3)

@@ -112,14 +112,14 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
 
     # variables
     assert isinstance(initial_vars, dict)
-    initial_vars = {k: bm.Variable(jnp.asarray(bm.as_device_array(v), dtype=jnp.float_))
+    initial_vars = {k: bm.Variable(jnp.asarray(bm.as_device_array(v), dtype=bm.dftype()))
                     for k, v in initial_vars.items()}
     self.register_implicit_vars(initial_vars)
 
     # parameters
     pars = dict() if pars is None else pars
     assert isinstance(pars, dict)
-    self.pars = [jnp.asarray(bm.as_device_array(v), dtype=jnp.float_)
+    self.pars = [jnp.asarray(bm.as_device_array(v), dtype=bm.dftype())
                  for k, v in pars.items()]
 
     # integrals
@@ -128,7 +128,8 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
     # runner
     self.runner = DSRunner(self,
                            monitors=list(initial_vars.keys()),
-                           dyn_vars=self.vars().unique(), dt=dt,
+                           dyn_vars=self.vars().unique(),
+                           dt=dt,
                            progress_bar=False)
 
   def update(self, sha):
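Note: the dtype switch is the substantive fix here. `jnp.float_` is pinned by JAX's global configuration, while `bm.dftype()` returns BrainPy's current default float type and therefore tracks `bm.enable_x64()`. A quick illustration, assuming the default single-precision setting at startup:

import brainpy.math as bm

print(bm.dftype())   # float32 under the default single-precision setting
bm.enable_x64()      # switch BrainPy (and JAX) to 64-bit mode
print(bm.dftype())   # float64; arrays built with dtype=bm.dftype() now follow suit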

brainpy/dyn/neurons/biological_models.py (+11 -14)

@@ -244,20 +244,17 @@ def __init__(
 
     # variables
     self.V = variable(self._V_initializer, mode, self.varshape)
-    if self._m_initializer is None:
-      self.m = bm.Variable(self.m_inf(self.V.value))
-    else:
-      self.m = variable(self._m_initializer, mode, self.varshape)
-    if self._h_initializer is None:
-      self.h = bm.Variable(self.h_inf(self.V.value))
-    else:
-      self.h = variable(self._h_initializer, mode, self.varshape)
-    if self._n_initializer is None:
-      self.n = bm.Variable(self.n_inf(self.V.value))
-    else:
-      self.n = variable(self._n_initializer, mode, self.varshape)
-    self.input = variable(bm.zeros, mode, self.varshape)
+    self.m = (bm.Variable(self.m_inf(self.V.value))
+              if m_initializer is None else
+              variable(self._m_initializer, mode, self.varshape))
+    self.h = (bm.Variable(self.h_inf(self.V.value))
+              if h_initializer is None else
+              variable(self._h_initializer, mode, self.varshape))
+    self.n = (bm.Variable(self.n_inf(self.V.value))
+              if n_initializer is None else
+              variable(self._n_initializer, mode, self.varshape))
     self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
+    self.input = variable(bm.zeros, mode, self.varshape)
 
     # integral
     if self.noise is None:
@@ -309,7 +306,7 @@ def dV(self, V, t, m, h, n, I_ext):
 
   @property
   def derivative(self):
-    return JointEq([self.dV, self.dm, self.dh, self.dn])
+    return JointEq(self.dV, self.dm, self.dh, self.dn)
 
   def update(self, tdi, x=None):
     t, dt = tdi['t'], tdi['dt']
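Note: besides the conditional-expression cleanup, `derivative` now passes the component equations to `JointEq` variadically instead of as a list; both forms build the same joint system. A minimal sketch of the variadic form on a hypothetical two-variable system (equations here are illustrative):

import brainpy as bp

def dv(v, t, w):
  return v - v ** 3 - w    # illustrative fast equation

def dw(w, t, v):
  return 0.1 * (v - w)     # illustrative slow equation

eq = bp.JointEq(dv, dw)    # variadic; JointEq([dv, dw]) is the older list form
integral = bp.odeint(eq, method='rk4')
v, w = integral(0.1, 0., t=0., dt=0.01)  # one integration step of the joint system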

brainpy/inputs/currents.py (+2 -2)

@@ -307,9 +307,9 @@ def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None,
 
   def _f(t):
     x.value = x + dt * ((mean - x) / tau) + sigma * dt_sqrt * rng.rand(n)
+    return x.value
 
-  f = bm.make_loop(_f, dyn_vars=[x, rng], out_vars=x)
-  noises = f(jnp.arange(t_start, t_end, dt))
+  noises = bm.for_loop(_f, [x, rng], jnp.arange(t_start, t_end, dt))
 
   t_end = duration if t_end is None else t_end
   i_start = int(t_start / dt)
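Note: the loop body is an Euler-Maruyama step of the OU equation, x <- x + dt * (mean - x) / tau + sigma * sqrt(dt) * xi; it now returns `x.value` so `for_loop` can stack the trace itself. A usage sketch of the public function, whose interface this change leaves untouched:

import brainpy as bp

# A 1000 ms Ornstein-Uhlenbeck current: mean 0.5, noise amplitude 0.1,
# correlation time 10 ms, one independent process (n=1).
current = bp.inputs.ou_process(mean=0.5, sigma=0.1, tau=10.,
                               duration=1000., dt=0.1, n=1)
print(current.shape)  # (10000, 1): time steps x processes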

brainpy/math/operators/utils.py (+4 -2)

@@ -17,12 +17,14 @@ def _check_brainpylib(ops_name):
     raise PackageMissingError(
       f'"{ops_name}" operator need "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}". \n'
       f'Please install it through:\n\n'
-      f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION} -U'
+      f'>>> pip install brainpylib=={_BRAINPYLIB_MINIMAL_VERSION}\n'
+      f'>>> # or \n'
+      f'>>> pip install brainpylib -U'
     )
   else:
     raise PackageMissingError(
       f'"brainpylib" must be installed when the user '
       f'wants to use "{ops_name}" operator. \n'
       f'Please install "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}" through:\n\n'
-      f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}'
+      f'>>> pip install brainpylib'
     )
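Note: the old hint was fragile in practice. An unquoted `pip install brainpylib>=<version>` makes most shells treat `>` as output redirection, so version-ranged commands need quoting; the new messages sidestep this by pinning with `==` or upgrading plainly with `-U`.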

brainpy/train/back_propagation.py (+3 -4)

@@ -520,7 +520,7 @@ def f_train(self, shared_args=None) -> Callable:
     shared_args_str = serialize_kwargs(shared_args)
     if shared_args_str not in self._f_train_compiled:
 
-      def train_step(x):
+      def train_step(*x):
         # t, i, input_, target_ = x
         res = self.f_grad(shared_args)(*x)
         self.optimizer.update(res[0])
@@ -529,8 +529,7 @@ def train_step(x):
       if self.jit[c.FIT_PHASE]:
         dyn_vars = self.target.vars()
         dyn_vars.update(self.dyn_vars)
-        f = bm.make_loop(train_step, dyn_vars=dyn_vars.unique(), has_return=True)
-        run_func = lambda all_inputs: f(all_inputs)[1]
+        run_func = lambda all_inputs: bm.for_loop(train_step, dyn_vars.unique(), all_inputs)
 
       else:
         def run_func(xs):
@@ -541,7 +540,7 @@ def run_func(xs):
           x = tree_map(lambda x: x[i], inputs, is_leaf=_is_jax_array)
           y = tree_map(lambda x: x[i], targets, is_leaf=_is_jax_array)
           # step at the i
-          loss = train_step((times[i], indices[i], x, y))
+          loss = train_step(times[i], indices[i], x, y)
           # append output and monitor
           losses.append(loss)
         return bm.asarray(losses)
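Note: `train_step` becomes `*x` because of how the loop now feeds it. With `bm.for_loop`, a tuple of operand sequences is unpacked into separate positional arguments at each step, whereas `make_loop` delivered one packed value; the non-JIT path drops the tuple parentheses to match. A toy sketch of that calling convention, as assumed here:

import brainpy.math as bm

acc = bm.Variable(bm.zeros(1))

def step(t, x):            # one positional argument per operand sequence
  acc.value += t * x
  return acc.value

ts = bm.arange(4.)
xs = bm.ones(4)
out = bm.for_loop(step, [acc], (ts, xs))  # operands are zipped step by step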

examples/analysis/4d_HH_model.py (+7 -4)

@@ -3,13 +3,15 @@
 import brainpy as bp
 import brainpy.math as bm
 
+
 I = 5.
 model = bp.dyn.neurons.HH(1)
-runner = bp.dyn.DSRunner(model, inputs=('input', I), monitors=['V'])
+runner = bp.dyn.DSRunner(model, inputs=(model.input, I), monitors=['V'])
 runner.run(100)
 bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True)
 
 # analysis
+bm.enable_x64()
 model = bp.dyn.neurons.HH(1, method='euler')
 finder = bp.analysis.SlowPointFinder(
   model,
@@ -18,13 +20,14 @@
   'm': model.m,
   'h': model.h,
   'n': model.n},
-  dt=1.
+  dt=10.
 )
-candidates = {'V': bm.random.normal(0., 5., (1000, model.num)) - 50.,
+finder.find_fps_with_opt_solver(
+  candidates={'V': bm.random.normal(0., 10., (1000, model.num)) - 50.,
               'm': bm.random.random((1000, model.num)),
               'h': bm.random.random((1000, model.num)),
               'n': bm.random.random((1000, model.num))}
-finder.find_fps_with_opt_solver(candidates=candidates)
+)
 finder.filter_loss(1e-7)
 finder.keep_unique(tolerance=1e-1)
 print('fixed_points: ', finder.fixed_points)
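Note: the example now enables 64-bit floats before building the finder, presumably because fixed-point optimization on the stiff Hodgkin-Huxley equations needs the extra precision to pass the tight `filter_loss(1e-7)` threshold; the widened candidate spread (normal scale 10 rather than 5) likewise gives the optimizer a broader set of starting states.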

examples/training/Song_2016_EI_RNN.py (+8 -10)

@@ -128,17 +128,17 @@ def __init__(self, num_input, num_hidden, num_output, num_batch,
     self.mask = bm.asarray(mask, dtype=bm.dftype())
 
     # input weight
-    self.w_ir = bm.TrainVar(w_ir(num_input, num_hidden))
+    self.w_ir = bm.TrainVar(w_ir((num_input, num_hidden)))
 
     # recurrent weight
     bound = 1 / num_hidden ** 0.5
-    self.w_rr = bm.TrainVar(w_rr(num_hidden, num_hidden))
+    self.w_rr = bm.TrainVar(w_rr((num_hidden, num_hidden)))
     self.w_rr[:, :self.e_size] /= (self.e_size / self.i_size)
     self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))
 
     # readout weight
     bound = 1 / self.e_size ** 0.5
-    self.w_ro = bm.TrainVar(w_ro(self.e_size, num_output))
+    self.w_ro = bm.TrainVar(w_ro((self.e_size, num_output)))
     self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))
 
     # variables
@@ -158,15 +158,13 @@ def make_update(self, h: bm.JaxArray, o: bm.JaxArray):
     def f(x):
       h.value = self.cell(x, h.value)
       o.value = self.readout(h.value[:, :self.e_size])
+      return h.value, o.value
 
     return f
 
   def predict(self, xs):
     self.h[:] = 0.
-    f = bm.make_loop(self.make_update(self.h, self.o),
-                     dyn_vars=self.vars(),
-                     out_vars=[self.h, self.o])
-    return f(xs)
+    return bm.for_loop(self.make_update(self.h, self.o), self.vars(), xs)
 
   def loss(self, xs, ys):
     hs, os = self.predict(xs)
@@ -247,7 +245,7 @@ def train(xs, ys):
   rnn_activity, action_pred = predict(inputs)
 
   # Compute performance
-  action_pred = action_pred.numpy()
+  action_pred = bm.as_numpy(action_pred)
   choice = np.argmax(action_pred[-1, 0, :])
   correct = choice == gt[-1]
 
@@ -257,7 +255,7 @@ def train(xs, ys):
   trial_infos[i] = trial_info
 
   # Log stimulus period activity
-  rnn_activity = rnn_activity.numpy()[:, 0, :]
+  rnn_activity = bm.as_numpy(rnn_activity)[:, 0, :]
   activity_dict[i] = rnn_activity
 
   # Compute stimulus selectivity for all units
@@ -312,7 +310,7 @@ def train(xs, ys):
   plt.show()
 
 # %%
-W = (bm.abs(net.w_rr) * net.mask).numpy()
+W = bm.as_numpy(bm.abs(net.w_rr) * net.mask)
 # Sort by selectivity
 W = W[:, ind_sort][ind_sort, :]
 wlim = np.max(np.abs(W))
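Note: two API fixes recur in this file. BrainPy initializers are called with a single shape tuple rather than separate dimension arguments, and `bm.as_numpy(...)` replaces the `.numpy()` method for converting results to NumPy arrays. A small sketch, with `bp.init.KaimingUniform` as a hypothetical initializer choice for illustration:

import brainpy as bp
import brainpy.math as bm

init = bp.init.KaimingUniform()   # hypothetical pick; any bp.init initializer works alike
w = bm.TrainVar(init((10, 20)))   # shape passed as one tuple, not as separate ints
w_np = bm.as_numpy(bm.abs(w))     # module-level conversion instead of .numpy()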

0 commit comments

Comments
 (0)