Skip to content

Commit 7c56adf

Browse files
committed
upgrade brainpy.analysis for new version of DynamicalSystem
1 parent cce047c commit 7c56adf

File tree

5 files changed

+64
-61
lines changed

5 files changed

+64
-61
lines changed

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# -*- coding: utf-8 -*-
22

3+
import inspect
34
import math
45
import time
6+
import warnings
57
from typing import Callable, Union, Dict, Sequence, Tuple
68

79
import jax.numpy as jnp
@@ -18,6 +20,8 @@
1820
from brainpy._src.runners import check_and_format_inputs, _f_ops
1921
from brainpy.errors import AnalyzerError, UnsupportedError
2022
from brainpy.types import ArrayType
23+
from brainpy._src.deprecations import _input_deprecate_msg
24+
2125

2226
__all__ = [
2327
'SlowPointFinder',
@@ -514,7 +518,7 @@ def exclude_outliers(self, tolerance: float = 1e0):
514518
# Compute pairwise distances between all fixed points.
515519
distances = np.asarray(utils.euclidean_distance_jax(self.fixed_points, num_fps))
516520

517-
# Find second smallest element in each column of the pairwise distance matrix.
521+
# Find the second smallest element in each column of the pairwise distance matrix.
518522
# This corresponds to the closest neighbor for each fixed point.
519523
closest_neighbor = np.partition(distances, kth=1, axis=0)[1]
520524

@@ -640,7 +644,12 @@ def _step_func_input(self):
640644
if self._inputs is None:
641645
return
642646
elif callable(self._inputs):
643-
self._inputs(share.get_shargs())
647+
try:
648+
ba = inspect.signature(self._inputs).bind(dict())
649+
self._inputs(share.get_shargs())
650+
warnings.warn(_input_deprecate_msg, UserWarning)
651+
except TypeError:
652+
self._inputs()
644653
else:
645654
for ops, values in self._inputs['fixed'].items():
646655
for var, data in values:

brainpy/_src/analysis/lowdim/lowdim_analyzer.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def __init__(
9999
raise errors.AnalyzerError(f'{key} is not a dynamical variable in {self.model}.')
100100
value = self.target_vars[key]
101101
if value[0] > value[1]:
102-
raise errors.AnalyzerError(f'The range of variable {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')
102+
raise errors.AnalyzerError(
103+
f'The range of variable {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')
103104

104105
# fixed variables
105106
# ----------------
@@ -246,7 +247,7 @@ class Num1DAnalyzer(LowDimAnalyzer):
246247
"""
247248

248249
def __init__(self, *args, **kwargs):
249-
super(Num1DAnalyzer, self).__init__(*args, **kwargs)
250+
super().__init__(*args, **kwargs)
250251
self.x_var = self.target_var_names[0]
251252
if len(self.target_vars) < 1:
252253
raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system '
@@ -407,7 +408,7 @@ class Num2DAnalyzer(Num1DAnalyzer):
407408
"""
408409

409410
def __init__(self, *args, **kwargs):
410-
super(Num2DAnalyzer, self).__init__(*args, **kwargs)
411+
super().__init__(*args, **kwargs)
411412
if len(self.target_vars) < 2:
412413
raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system '
413414
f'with >= 2 variables. But we got {len(self.target_vars)} '
@@ -1028,7 +1029,7 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
10281029

10291030
class Num3DAnalyzer(Num2DAnalyzer):
10301031
def __init__(self, *args, **kwargs):
1031-
super(Num3DAnalyzer, self).__init__(*args, **kwargs)
1032+
super().__init__(*args, **kwargs)
10321033
if len(self.target_vars) < 3:
10331034
raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system '
10341035
f'with >= 3 variables. But we got {len(self.target_vars)} '
@@ -1045,7 +1046,3 @@ def F_fz(self):
10451046
f = partial(f, **(self.pars_update + self.fixed_vars))
10461047
self.analyzed_results[C.F_fz] = jax.jit(f, device=self.jit_device)
10471048
return self.analyzed_results[C.F_fz]
1048-
1049-
def fz_signs(self, pars=(), cache=False):
1050-
xyz = tuple(self.resolutions.values())
1051-
return utils.get_sign2(self.F_fz, *xyz, args=pars)

brainpy/_src/analysis/lowdim/lowdim_bifurcation.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ class Bifurcation1D(Num1DAnalyzer):
3131

3232
def __init__(self, model, target_pars, target_vars, fixed_vars=None,
3333
pars_update=None, resolutions=None, options=None):
34-
super(Bifurcation1D, self).__init__(model=model,
35-
target_pars=target_pars,
36-
target_vars=target_vars,
37-
fixed_vars=fixed_vars,
38-
pars_update=pars_update,
39-
resolutions=resolutions,
40-
options=options)
34+
super().__init__(model=model,
35+
target_pars=target_pars,
36+
target_vars=target_vars,
37+
fixed_vars=fixed_vars,
38+
pars_update=pars_update,
39+
resolutions=resolutions,
40+
options=options)
4141

4242
if len(self.target_pars) == 0:
4343
raise ValueError
@@ -146,13 +146,13 @@ class Bifurcation2D(Num2DAnalyzer):
146146

147147
def __init__(self, model, target_pars, target_vars, fixed_vars=None,
148148
pars_update=None, resolutions=None, options=None):
149-
super(Bifurcation2D, self).__init__(model=model,
150-
target_pars=target_pars,
151-
target_vars=target_vars,
152-
fixed_vars=fixed_vars,
153-
pars_update=pars_update,
154-
resolutions=resolutions,
155-
options=options)
149+
super().__init__(model=model,
150+
target_pars=target_pars,
151+
target_vars=target_vars,
152+
fixed_vars=fixed_vars,
153+
pars_update=pars_update,
154+
resolutions=resolutions,
155+
options=options)
156156

157157
if len(self.target_pars) == 0:
158158
raise ValueError
@@ -458,13 +458,13 @@ def __init__(
458458
resolutions=None,
459459
options: dict = None
460460
):
461-
super(FastSlow1D, self).__init__(model=model,
462-
target_pars=slow_vars,
463-
target_vars=fast_vars,
464-
fixed_vars=fixed_vars,
465-
pars_update=pars_update,
466-
resolutions=resolutions,
467-
options=options)
461+
super().__init__(model=model,
462+
target_pars=slow_vars,
463+
target_vars=fast_vars,
464+
fixed_vars=fixed_vars,
465+
pars_update=pars_update,
466+
resolutions=resolutions,
467+
options=options)
468468

469469
# standard integrators
470470
self._std_integrators = dict()
@@ -549,13 +549,13 @@ def __init__(
549549
resolutions=0.1,
550550
options: dict = None
551551
):
552-
super(FastSlow2D, self).__init__(model=model,
553-
target_pars=slow_vars,
554-
target_vars=fast_vars,
555-
fixed_vars=fixed_vars,
556-
pars_update=pars_update,
557-
resolutions=resolutions,
558-
options=options)
552+
super().__init__(model=model,
553+
target_pars=slow_vars,
554+
target_vars=fast_vars,
555+
fixed_vars=fixed_vars,
556+
pars_update=pars_update,
557+
resolutions=resolutions,
558+
options=options)
559559
# standard integrators
560560
self._std_integrators = dict()
561561
for key, intg in self.model.name2integral.items():

brainpy/_src/analysis/lowdim/lowdim_phase_plane.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ def __init__(self,
5555
if (target_pars is not None) and len(target_pars) > 0:
5656
raise errors.AnalyzerError(f'Phase plane analysis does not support "target_pars". '
5757
f'While we detect "target_pars={target_pars}".')
58-
super(PhasePlane1D, self).__init__(model=model,
59-
target_vars=target_vars,
60-
fixed_vars=fixed_vars,
61-
target_pars=target_pars,
62-
pars_update=pars_update,
63-
resolutions=resolutions,
64-
**kwargs)
58+
super().__init__(model=model,
59+
target_vars=target_vars,
60+
fixed_vars=fixed_vars,
61+
target_pars=target_pars,
62+
pars_update=pars_update,
63+
resolutions=resolutions,
64+
**kwargs)
6565
# utils.output(f'I am {PhasePlane1D.__name__}.')
6666

6767
def plot_vector_field(self, show=False, with_plot=True, with_return=False):
@@ -150,13 +150,13 @@ def __init__(self,
150150
if (target_pars is not None) and len(target_pars) > 0:
151151
raise errors.AnalyzerError(f'Phase plane analysis does not support "target_pars". '
152152
f'While we detect "target_pars={target_pars}".')
153-
super(PhasePlane2D, self).__init__(model=model,
154-
target_vars=target_vars,
155-
fixed_vars=fixed_vars,
156-
target_pars=target_pars,
157-
pars_update=pars_update,
158-
resolutions=resolutions,
159-
**kwargs)
153+
super().__init__(model=model,
154+
target_vars=target_vars,
155+
fixed_vars=fixed_vars,
156+
target_pars=target_pars,
157+
pars_update=pars_update,
158+
resolutions=resolutions,
159+
**kwargs)
160160

161161
@property
162162
def F_vmap_brentq_fy(self):
@@ -251,7 +251,7 @@ def plot_nullcline(self, with_plot=True, with_return=False,
251251
if with_plot:
252252
if x_style is None:
253253
x_style = dict(color='cornflowerblue', alpha=.7, fmt='.')
254-
line_args = (x_style.pop('fmt'), ) if 'fmt' in x_style else tuple()
254+
line_args = (x_style.pop('fmt'),) if 'fmt' in x_style else tuple()
255255
pyplot.plot(x_values_in_fx, y_values_in_fx, *line_args, **x_style, label=f"{self.x_var} nullcline")
256256

257257
# Nullcline of the y variable
@@ -263,7 +263,7 @@ def plot_nullcline(self, with_plot=True, with_return=False,
263263
if with_plot:
264264
if y_style is None:
265265
y_style = dict(color='lightcoral', alpha=.7, fmt='.')
266-
line_args = (y_style.pop('fmt'), ) if 'fmt' in y_style else tuple()
266+
line_args = (y_style.pop('fmt'),) if 'fmt' in y_style else tuple()
267267
pyplot.plot(x_values_in_fy, y_values_in_fy, *line_args, **y_style, label=f"{self.y_var} nullcline")
268268

269269
if with_plot:

brainpy/_src/analysis/utils/model.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from brainpy._src.math.environment import get_float
66
from brainpy._src.math.interoperability import as_jax
77
from brainpy._src.dynsys import DynamicalSystem
8+
from brainpy._src.context import share
89
from brainpy._src.runners import DSRunner
910
from brainpy._src.integrators.base import Integrator
1011
from brainpy._src.integrators.joint_eq import JointEq
@@ -126,16 +127,12 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
126127
self.integrals = integrals
127128

128129
# runner
129-
self.runner = DSRunner(self,
130-
monitors=list(initial_vars.keys()),
131-
dyn_vars=self.vars().unique(),
132-
dt=dt,
133-
progress_bar=False)
130+
self.runner = DSRunner(self, monitors=list(initial_vars.keys()), dt=dt, progress_bar=False)
134131

135-
def update(self, sha):
132+
def update(self):
136133
all_vars = list(self.implicit_vars.values())
137134
for key, intg in self.integrals.items():
138-
self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=sha['dt']))
135+
self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=share['dt']))
139136

140137
def __getattr__(self, item):
141138
child_vars = super(TrajectModel, self).__getattribute__('implicit_vars')

0 commit comments

Comments
 (0)