Skip to content

Commit 34d8764

Browse files
authored
Merge pull request #422 from brainpy/updates
upgrade Runner and Trainer for new style of ``DynamicalSystem.update()`` function
2 parents f36dd10 + 60eebd3 commit 34d8764

36 files changed

+528
-723
lines changed

.github/workflows/CI-models.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ jobs:
117117
strategy:
118118
fail-fast: false
119119
matrix:
120-
python-version: ["3.8", "3.9", "3.10", "3.11"]
120+
python-version: ["3.9", "3.10", "3.11"]
121121

122122
steps:
123123
- uses: actions/checkout@v2
@@ -128,8 +128,6 @@ jobs:
128128
- name: Install dependencies
129129
run: |
130130
python -m pip install numpy>=1.21.0
131-
python -m pip install "jaxlib==0.4.10" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
132-
python -m pip install jax==0.4.10
133131
python -m pip install -r requirements-dev.txt
134132
python -m pip install tqdm brainpylib
135133
pip uninstall brainpy -y

brainpy/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
Network = DynSysGroup
6666
# delays
6767
from brainpy._src.delay import (
68-
VariDelay as VariDelay,
68+
VarDelay as VarDelay,
6969
)
7070

7171
# building blocks
@@ -129,12 +129,15 @@
129129
from brainpy._add_deprecations import deprecation_getattr2
130130

131131
__deprecations = {
132+
'Module': ('brainpy.Module', 'brainpy.DynamicalSystem', DynamicalSystem),
133+
'Channel': ('brainpy.Channel', 'brainpy.dyn.IonChannel', dyn.IonChannel),
134+
'SynConn': ('brainpy.SynConn', 'brainpy.dyn.SynConn', dyn.SynConn),
132135
'Container': ('brainpy.Container', 'brainpy.DynSysGroup', DynSysGroup),
136+
133137
'optimizers': ('brainpy.optimizers', 'brainpy.optim', optim),
134138
'TensorCollector': ('brainpy.TensorCollector', 'brainpy.ArrayCollector', ArrayCollector),
135139
'SynSTP': ('brainpy.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP),
136140
'SynOut': ('brainpy.SynOut', 'brainpy.synapses.SynOut', synapses.SynOut),
137-
'SynConn': ('brainpy.SynConn', 'brainpy.dyn.SynConn', dyn.SynConn),
138141
'TwoEndConn': ('brainpy.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn),
139142
'CondNeuGroup': ('brainpy.CondNeuGroup', 'brainpy.syn.CondNeuGroup', dyn.CondNeuGroup),
140143
}

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# -*- coding: utf-8 -*-
22

3+
import inspect
34
import math
45
import time
6+
import warnings
57
from typing import Callable, Union, Dict, Sequence, Tuple
68

79
import jax.numpy as jnp
@@ -14,10 +16,12 @@
1416
from brainpy import optim, losses
1517
from brainpy._src.analysis import utils, base, constants
1618
from brainpy._src.dynsys import DynamicalSystem
19+
from brainpy._src.context import share
1720
from brainpy._src.runners import check_and_format_inputs, _f_ops
18-
from brainpy._src.tools.dicts import DotDict
1921
from brainpy.errors import AnalyzerError, UnsupportedError
2022
from brainpy.types import ArrayType
23+
from brainpy._src.deprecations import _input_deprecate_msg
24+
2125

2226
__all__ = [
2327
'SlowPointFinder',
@@ -123,7 +127,7 @@ def __init__(
123127
f_loss_batch: Callable = None,
124128
fun_inputs: Callable = None,
125129
):
126-
super(SlowPointFinder, self).__init__()
130+
super().__init__()
127131

128132
# static arguments
129133
if not isinstance(args, tuple):
@@ -514,7 +518,7 @@ def exclude_outliers(self, tolerance: float = 1e0):
514518
# Compute pairwise distances between all fixed points.
515519
distances = np.asarray(utils.euclidean_distance_jax(self.fixed_points, num_fps))
516520

517-
# Find second smallest element in each column of the pairwise distance matrix.
521+
# Find the second smallest element in each column of the pairwise distance matrix.
518522
# This corresponds to the closest neighbor for each fixed point.
519523
closest_neighbor = np.partition(distances, kth=1, axis=0)[1]
520524

@@ -636,11 +640,16 @@ def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=False)
636640
'L': L})
637641
return decompositions
638642

639-
def _step_func_input(self, shared):
643+
def _step_func_input(self):
640644
if self._inputs is None:
641645
return
642646
elif callable(self._inputs):
643-
self._inputs(shared)
647+
try:
648+
ba = inspect.signature(self._inputs).bind(dict())
649+
self._inputs(share.get_shargs())
650+
warnings.warn(_input_deprecate_msg, UserWarning)
651+
except TypeError:
652+
self._inputs()
644653
else:
645654
for ops, values in self._inputs['fixed'].items():
646655
for var, data in values:
@@ -650,7 +659,7 @@ def _step_func_input(self, shared):
650659
raise UnsupportedError
651660
for ops, values in self._inputs['functional'].items():
652661
for var, data in values:
653-
_f_ops(ops, var, data(shared))
662+
_f_ops(ops, var, data(share.get_shargs()))
654663
for ops, values in self._inputs['iterated'].items():
655664
if len(values) > 0:
656665
raise UnsupportedError
@@ -732,9 +741,10 @@ def _generate_ds_cell_function(
732741
):
733742
if dt is None: dt = bm.get_dt()
734743
if t is None: t = 0.
735-
shared = DotDict(t=t, dt=dt, i=0)
736744

737745
def f_cell(h: Dict):
746+
share.save(t=t, i=0, dt=dt)
747+
738748
# update target variables
739749
for k, v in self.target_vars.items():
740750
v.value = (bm.asarray(h[k], dtype=v.dtype)
@@ -747,11 +757,10 @@ def f_cell(h: Dict):
747757

748758
# add inputs
749759
target.clear_input()
750-
self._step_func_input(shared)
760+
self._step_func_input()
751761

752762
# call update functions
753-
args = (shared,) + self.args
754-
target(*args)
763+
target(*self.args)
755764

756765
# get new states
757766
new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))

brainpy/_src/analysis/lowdim/lowdim_analyzer.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def __init__(
9999
raise errors.AnalyzerError(f'{key} is not a dynamical variable in {self.model}.')
100100
value = self.target_vars[key]
101101
if value[0] > value[1]:
102-
raise errors.AnalyzerError(f'The range of variable {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')
102+
raise errors.AnalyzerError(
103+
f'The range of variable {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')
103104

104105
# fixed variables
105106
# ----------------
@@ -246,7 +247,7 @@ class Num1DAnalyzer(LowDimAnalyzer):
246247
"""
247248

248249
def __init__(self, *args, **kwargs):
249-
super(Num1DAnalyzer, self).__init__(*args, **kwargs)
250+
super().__init__(*args, **kwargs)
250251
self.x_var = self.target_var_names[0]
251252
if len(self.target_vars) < 1:
252253
raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system '
@@ -407,7 +408,7 @@ class Num2DAnalyzer(Num1DAnalyzer):
407408
"""
408409

409410
def __init__(self, *args, **kwargs):
410-
super(Num2DAnalyzer, self).__init__(*args, **kwargs)
411+
super().__init__(*args, **kwargs)
411412
if len(self.target_vars) < 2:
412413
raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system '
413414
f'with >= 2 variables. But we got {len(self.target_vars)} '
@@ -1028,7 +1029,7 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
10281029

10291030
class Num3DAnalyzer(Num2DAnalyzer):
10301031
def __init__(self, *args, **kwargs):
1031-
super(Num3DAnalyzer, self).__init__(*args, **kwargs)
1032+
super().__init__(*args, **kwargs)
10321033
if len(self.target_vars) < 3:
10331034
raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system '
10341035
f'with >= 3 variables. But we got {len(self.target_vars)} '
@@ -1045,7 +1046,3 @@ def F_fz(self):
10451046
f = partial(f, **(self.pars_update + self.fixed_vars))
10461047
self.analyzed_results[C.F_fz] = jax.jit(f, device=self.jit_device)
10471048
return self.analyzed_results[C.F_fz]
1048-
1049-
def fz_signs(self, pars=(), cache=False):
1050-
xyz = tuple(self.resolutions.values())
1051-
return utils.get_sign2(self.F_fz, *xyz, args=pars)

brainpy/_src/analysis/lowdim/lowdim_bifurcation.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ class Bifurcation1D(Num1DAnalyzer):
3131

3232
def __init__(self, model, target_pars, target_vars, fixed_vars=None,
3333
pars_update=None, resolutions=None, options=None):
34-
super(Bifurcation1D, self).__init__(model=model,
35-
target_pars=target_pars,
36-
target_vars=target_vars,
37-
fixed_vars=fixed_vars,
38-
pars_update=pars_update,
39-
resolutions=resolutions,
40-
options=options)
34+
super().__init__(model=model,
35+
target_pars=target_pars,
36+
target_vars=target_vars,
37+
fixed_vars=fixed_vars,
38+
pars_update=pars_update,
39+
resolutions=resolutions,
40+
options=options)
4141

4242
if len(self.target_pars) == 0:
4343
raise ValueError
@@ -146,13 +146,13 @@ class Bifurcation2D(Num2DAnalyzer):
146146

147147
def __init__(self, model, target_pars, target_vars, fixed_vars=None,
148148
pars_update=None, resolutions=None, options=None):
149-
super(Bifurcation2D, self).__init__(model=model,
150-
target_pars=target_pars,
151-
target_vars=target_vars,
152-
fixed_vars=fixed_vars,
153-
pars_update=pars_update,
154-
resolutions=resolutions,
155-
options=options)
149+
super().__init__(model=model,
150+
target_pars=target_pars,
151+
target_vars=target_vars,
152+
fixed_vars=fixed_vars,
153+
pars_update=pars_update,
154+
resolutions=resolutions,
155+
options=options)
156156

157157
if len(self.target_pars) == 0:
158158
raise ValueError
@@ -458,13 +458,13 @@ def __init__(
458458
resolutions=None,
459459
options: dict = None
460460
):
461-
super(FastSlow1D, self).__init__(model=model,
462-
target_pars=slow_vars,
463-
target_vars=fast_vars,
464-
fixed_vars=fixed_vars,
465-
pars_update=pars_update,
466-
resolutions=resolutions,
467-
options=options)
461+
super().__init__(model=model,
462+
target_pars=slow_vars,
463+
target_vars=fast_vars,
464+
fixed_vars=fixed_vars,
465+
pars_update=pars_update,
466+
resolutions=resolutions,
467+
options=options)
468468

469469
# standard integrators
470470
self._std_integrators = dict()
@@ -549,13 +549,13 @@ def __init__(
549549
resolutions=0.1,
550550
options: dict = None
551551
):
552-
super(FastSlow2D, self).__init__(model=model,
553-
target_pars=slow_vars,
554-
target_vars=fast_vars,
555-
fixed_vars=fixed_vars,
556-
pars_update=pars_update,
557-
resolutions=resolutions,
558-
options=options)
552+
super().__init__(model=model,
553+
target_pars=slow_vars,
554+
target_vars=fast_vars,
555+
fixed_vars=fixed_vars,
556+
pars_update=pars_update,
557+
resolutions=resolutions,
558+
options=options)
559559
# standard integrators
560560
self._std_integrators = dict()
561561
for key, intg in self.model.name2integral.items():

brainpy/_src/analysis/lowdim/lowdim_phase_plane.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ def __init__(self,
5555
if (target_pars is not None) and len(target_pars) > 0:
5656
raise errors.AnalyzerError(f'Phase plane analysis does not support "target_pars". '
5757
f'While we detect "target_pars={target_pars}".')
58-
super(PhasePlane1D, self).__init__(model=model,
59-
target_vars=target_vars,
60-
fixed_vars=fixed_vars,
61-
target_pars=target_pars,
62-
pars_update=pars_update,
63-
resolutions=resolutions,
64-
**kwargs)
58+
super().__init__(model=model,
59+
target_vars=target_vars,
60+
fixed_vars=fixed_vars,
61+
target_pars=target_pars,
62+
pars_update=pars_update,
63+
resolutions=resolutions,
64+
**kwargs)
6565
# utils.output(f'I am {PhasePlane1D.__name__}.')
6666

6767
def plot_vector_field(self, show=False, with_plot=True, with_return=False):
@@ -150,13 +150,13 @@ def __init__(self,
150150
if (target_pars is not None) and len(target_pars) > 0:
151151
raise errors.AnalyzerError(f'Phase plane analysis does not support "target_pars". '
152152
f'While we detect "target_pars={target_pars}".')
153-
super(PhasePlane2D, self).__init__(model=model,
154-
target_vars=target_vars,
155-
fixed_vars=fixed_vars,
156-
target_pars=target_pars,
157-
pars_update=pars_update,
158-
resolutions=resolutions,
159-
**kwargs)
153+
super().__init__(model=model,
154+
target_vars=target_vars,
155+
fixed_vars=fixed_vars,
156+
target_pars=target_pars,
157+
pars_update=pars_update,
158+
resolutions=resolutions,
159+
**kwargs)
160160

161161
@property
162162
def F_vmap_brentq_fy(self):
@@ -251,7 +251,7 @@ def plot_nullcline(self, with_plot=True, with_return=False,
251251
if with_plot:
252252
if x_style is None:
253253
x_style = dict(color='cornflowerblue', alpha=.7, fmt='.')
254-
line_args = (x_style.pop('fmt'), ) if 'fmt' in x_style else tuple()
254+
line_args = (x_style.pop('fmt'),) if 'fmt' in x_style else tuple()
255255
pyplot.plot(x_values_in_fx, y_values_in_fx, *line_args, **x_style, label=f"{self.x_var} nullcline")
256256

257257
# Nullcline of the y variable
@@ -263,7 +263,7 @@ def plot_nullcline(self, with_plot=True, with_return=False,
263263
if with_plot:
264264
if y_style is None:
265265
y_style = dict(color='lightcoral', alpha=.7, fmt='.')
266-
line_args = (y_style.pop('fmt'), ) if 'fmt' in y_style else tuple()
266+
line_args = (y_style.pop('fmt'),) if 'fmt' in y_style else tuple()
267267
pyplot.plot(x_values_in_fy, y_values_in_fy, *line_args, **y_style, label=f"{self.y_var} nullcline")
268268

269269
if with_plot:

brainpy/_src/analysis/utils/model.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from brainpy._src.math.environment import get_float
66
from brainpy._src.math.interoperability import as_jax
77
from brainpy._src.dynsys import DynamicalSystem
8+
from brainpy._src.context import share
89
from brainpy._src.runners import DSRunner
910
from brainpy._src.integrators.base import Integrator
1011
from brainpy._src.integrators.joint_eq import JointEq
@@ -126,16 +127,12 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
126127
self.integrals = integrals
127128

128129
# runner
129-
self.runner = DSRunner(self,
130-
monitors=list(initial_vars.keys()),
131-
dyn_vars=self.vars().unique(),
132-
dt=dt,
133-
progress_bar=False)
130+
self.runner = DSRunner(self, monitors=list(initial_vars.keys()), dt=dt, progress_bar=False)
134131

135-
def update(self, sha):
132+
def update(self):
136133
all_vars = list(self.implicit_vars.values())
137134
for key, intg in self.integrals.items():
138-
self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=sha['dt']))
135+
self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=share['dt']))
139136

140137
def __getattr__(self, item):
141138
child_vars = super(TrajectModel, self).__getattribute__('implicit_vars')

brainpy/_src/delay.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
__all__ = [
2323
'Delay',
24-
'VariDelay',
24+
'VarDelay',
2525
'DataDelay',
2626
'DelayAccess',
2727
]
@@ -432,7 +432,7 @@ def _check_target_sharding(sharding, ndim, mode: bm.Mode):
432432
return sharding
433433

434434

435-
class VariDelay(Delay):
435+
class VarDelay(Delay):
436436
"""Generate Delays for the given :py:class:`~.Variable` instance.
437437
438438
The data in this delay variable is arranged as::
@@ -690,7 +690,7 @@ def _init_data(self, length: int, batch_size: int = None):
690690
self.data[:] = self._init((length,) + self.target.shape, dtype=self.target.dtype)
691691

692692

693-
class DataDelay(VariDelay):
693+
class DataDelay(VarDelay):
694694
not_desc_params = ('time', 'entries')
695695

696696
def __init__(

0 commit comments

Comments
 (0)