
Commit df1897b

Merge pull request #358 from chaoming0625/master
Automatic transformations of any function/object using `brainpy.math.Variable`
2 parents 59bfc0d + 542fea4 commit df1897b

63 files changed: +1975 -1605 lines changed
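As the merge title suggests, `brainpy.math` transformations can trace `brainpy.math.Variable` instances used by ordinary Python functions and objects. A minimal sketch of the idea (the `counter` and `step` names below are illustrative, not taken from this commit):

  import brainpy.math as bm

  counter = bm.Variable(bm.zeros(1))

  def step(x):
    counter.value += 1          # in-place write to a brainpy.math.Variable
    return x + counter

  jitted_step = bm.jit(step)    # the Variable used by `step` is picked up by the transform
  print(jitted_step(bm.ones(1)))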

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -223,3 +223,4 @@ cython_debug/
 /examples/training_snn_models/logs/
 /examples/training_snn_models/data/
 /docs/tutorial_advanced/data/
+/my_tests/

brainpy/__init__.py

Lines changed: 9 additions & 13 deletions
@@ -132,10 +132,9 @@
 
 # deprecated
 from brainpy._src.math.object_transform.base import (Base as Base,
-                                                     DynVarCollector,
+                                                     ArrayCollector,
                                                      Collector as Collector, )
-globals()['ArrayCollector'] = DynVarCollector
-globals()['TensorCollector'] = DynVarCollector
+globals()['TensorCollector'] = ArrayCollector
 
 train.__dict__['DSTrainer'] = DSTrainer
 train.__dict__['BPTT'] = BPTT
@@ -150,8 +149,8 @@
 base.base.__dict__['BrainPyObject'] = BrainPyObject
 base.base.__dict__['Base'] = Base
 base.collector.__dict__['Collector'] = Collector
-base.collector.__dict__['ArrayCollector'] = DynVarCollector
-base.collector.__dict__['TensorCollector'] = DynVarCollector
+base.collector.__dict__['ArrayCollector'] = ArrayCollector
+base.collector.__dict__['TensorCollector'] = ArrayCollector
 base.function.__dict__['FunAsObject'] = math.FunAsObject
 base.function.__dict__['Function'] = math.FunAsObject
 base.io.__dict__['save_as_h5'] = checkpoints.io.save_as_h5
@@ -162,14 +161,12 @@
 base.io.__dict__['load_by_npz'] = checkpoints.io.load_by_npz
 base.io.__dict__['load_by_pkl'] = checkpoints.io.load_by_pkl
 base.io.__dict__['load_by_mat'] = checkpoints.io.load_by_mat
-base.naming.__dict__['check_name_uniqueness'] = tools.check_name_uniqueness
-base.naming.__dict__['clear_name_cache'] = tools.clear_name_cache
-base.naming.__dict__['get_unique_name'] = tools.get_unique_name
+base.naming.__dict__['clear_name_cache'] = math.clear_name_cache
 base.__dict__['BrainPyObject'] = BrainPyObject
 base.__dict__['Base'] = Base
 base.__dict__['Collector'] = Collector
-base.__dict__['ArrayCollector'] = DynVarCollector
-base.__dict__['TensorCollector'] = DynVarCollector
+base.__dict__['ArrayCollector'] = ArrayCollector
+base.__dict__['TensorCollector'] = ArrayCollector
 base.__dict__['FunAsObject'] = math.FunAsObject
 base.__dict__['Function'] = math.FunAsObject
 base.__dict__['save_as_h5'] = checkpoints.io.save_as_h5
@@ -180,9 +177,7 @@
 base.__dict__['load_by_npz'] = checkpoints.io.load_by_npz
 base.__dict__['load_by_pkl'] = checkpoints.io.load_by_pkl
 base.__dict__['load_by_mat'] = checkpoints.io.load_by_mat
-base.__dict__['check_name_uniqueness'] = tools.check_name_uniqueness
-base.__dict__['clear_name_cache'] = tools.clear_name_cache
-base.__dict__['get_unique_name'] = tools.get_unique_name
+base.__dict__['clear_name_cache'] = math.clear_name_cache
 
 
 from . import modes
@@ -252,4 +247,5 @@
 
 from brainpy._src import checking
 tools.__dict__['checking'] = checking
+tools.__dict__['clear_name_cache'] = math.clear_name_cache
 del checking
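The naming helpers re-exported above now come from `brainpy.math`. A small usage sketch (the LIF group is only an example), assuming the relocated `clear_name_cache` keeps its behaviour of resetting the automatic unique-name registry:

  import brainpy as bp
  import brainpy.math as bm

  group1 = bp.neurons.LIF(10)   # automatically given a unique name, e.g. 'LIF0'
  bm.clear_name_cache()         # reset the unique-name registry
  group2 = bp.neurons.LIF(10)   # naming starts over after the cache is cleared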

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 17 additions & 19 deletions
@@ -6,12 +6,12 @@
 
 import jax.numpy as jnp
 import numpy as np
-from jax import vmap
+import jax
 from jax.scipy.optimize import minimize
 from jax.tree_util import tree_flatten, tree_map
 
 import brainpy._src.math as bm
-from brainpy import optimizers as optim, losses
+from brainpy import optim, losses
 from brainpy._src.analysis import utils, base, constants
 from brainpy._src.dyn.base import DynamicalSystem
 from brainpy._src.dyn.runners import check_and_format_inputs, _f_ops
@@ -132,11 +132,11 @@ def __init__(
 
     # update function
     if target_vars is None:
-      self.target_vars = bm.DynVarCollector()
+      self.target_vars = bm.ArrayCollector()
     else:
       if not isinstance(target_vars, dict):
         raise TypeError(f'"target_vars" must be a dict but we got {type(target_vars)}')
-      self.target_vars = bm.DynVarCollector(target_vars)
+      self.target_vars = bm.ArrayCollector(target_vars)
     excluded_vars = () if excluded_vars is None else excluded_vars
     if isinstance(excluded_vars, dict):
       excluded_vars = tuple(excluded_vars.values())
@@ -337,7 +337,7 @@ def find_fps_with_gd_method(
     f_eval_loss = self._get_f_eval_loss()
 
     def f_loss():
-      return f_eval_loss(tree_map(lambda a: bm.as_device_array(a),
+      return f_eval_loss(tree_map(lambda a: bm.as_jax(a),
                                   fixed_points,
                                   is_leaf=lambda x: isinstance(x, bm.Array))).mean()
 
@@ -383,10 +383,10 @@ def batch_train(start_i, n_batch):
                 f'is below tolerance {tolerance:0.10f}.')
 
     self._opt_losses = jnp.concatenate(opt_losses)
-    self._losses = f_eval_loss(tree_map(lambda a: bm.as_device_array(a),
+    self._losses = f_eval_loss(tree_map(lambda a: bm.as_jax(a),
                                         fixed_points,
                                         is_leaf=lambda x: isinstance(x, bm.Array)))
-    self._fixed_points = tree_map(lambda a: bm.as_device_array(a),
+    self._fixed_points = tree_map(lambda a: bm.as_jax(a),
                                   fixed_points,
                                   is_leaf=lambda x: isinstance(x, bm.Array))
     self._selected_ids = jnp.arange(num_candidate)
@@ -424,9 +424,7 @@ def find_fps_with_opt_solver(
     print(f"Optimizing with {opt_solver} to find fixed points:")
 
     # optimizing
-    res = f_opt(tree_map(lambda a: bm.as_device_array(a),
-                         candidates,
-                         is_leaf=lambda a: isinstance(a, bm.Array)))
+    res = f_opt(tree_map(lambda a: bm.as_jax(a), candidates, is_leaf=lambda a: isinstance(a, bm.Array)))
 
     # results
     valid_ids = jnp.where(res.success)[0]
@@ -666,12 +664,12 @@ def _get_f_eval_loss(self, ):
   def _generate_f_eval_loss(self):
     # evaluate losses of a batch of inputs
    if self.f_type == constants.DISCRETE:
-      f_eval_loss = lambda h: self.f_loss(h, vmap(self.f_cell)(h), axis=1)
+      f_eval_loss = lambda h: self.f_loss(h, jax.vmap(self.f_cell)(h), axis=1)
    else:
-      f_eval_loss = lambda h: self.f_loss(vmap(self.f_cell)(h), axis=1)
+      f_eval_loss = lambda h: self.f_loss(jax.vmap(self.f_cell)(h), axis=1)
 
    if isinstance(self.target, DynamicalSystem):
-      @bm.jit
+      @jax.jit
      def loss_func(h):
        r = f_eval_loss(h)
        for k, v in self.excluded_vars.items():
@@ -682,7 +680,7 @@ def loss_func(h):
 
      return loss_func
    else:
-      return bm.jit(f_eval_loss)
+      return jax.jit(f_eval_loss)
 
  def _get_f_for_opt_solver(self, candidates, opt_method):
    # loss function
@@ -697,17 +695,17 @@ def _get_f_for_opt_solver(self, candidates, opt_method):
 
        def f_loss(h):
          h = {key: h[indices[i]: indices[i + 1]] for i, key in enumerate(keys)}
-          return bm.as_device_array(self.f_loss(h, self.f_cell(h)))
+          return bm.as_jax(self.f_loss(h, self.f_cell(h)))
      else:
        def f_loss(h):
-          return bm.as_device_array(self.f_loss(h, self.f_cell(h)))
+          return bm.as_jax(self.f_loss(h, self.f_cell(h)))
    else:
      # overall loss function for fixed points optimization
      def f_loss(h):
        return self.f_loss(self.f_cell(h))
 
-    @bm.jit
-    @vmap
+    @jax.jit
+    @jax.vmap
    def f_opt(x0):
      for k, v in self.target_vars.items():
        v.value = x0[k] if (v.batch_axis is None) else jnp.expand_dims(x0[k], axis=v.batch_axis)
@@ -785,7 +783,7 @@ def jacob(x0):
    else:
      jacob = self.f_cell
 
-    f_jac = bm.jit(vmap(bm.jacobian(jacob)))
+    f_jac = jax.jit(jax.vmap(bm.jacobian(jacob)))
 
    if isinstance(self.target, DynamicalSystem):
      def jacobian_func(x):
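The replacements above swap `bm.as_device_array` for `bm.as_jax` and call `jax.jit`/`jax.vmap` directly. A short sketch of what `as_jax` does (the variable below is illustrative): it unwraps a brainpy `Array`/`Variable` into the raw JAX array that plain JAX transforms accept.

  import jax
  import brainpy.math as bm

  v = bm.Variable(bm.arange(4.))
  raw = bm.as_jax(v)                        # the underlying jax.Array, without the brainpy wrapper
  doubled = jax.jit(lambda a: a * 2.)(raw)  # plain JAX transforms work on the unwrapped value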

brainpy/_src/analysis/utils/model.py

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 # -*- coding: utf-8 -*-
 
 
-from brainpy._src.math.ndarray import Variable
+from brainpy._src.math.object_transform import Variable
 from brainpy._src.math.environment import get_float
-from brainpy._src.math.arrayinterporate import as_jax
+from brainpy._src.math.interoperability import as_jax
 from brainpy._src.dyn.base import DynamicalSystem
 from brainpy._src.dyn.runners import DSRunner
 from brainpy._src.integrators.base import Integrator

brainpy/_src/base/naming.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 
-from brainpy._src.tools import naming
+from brainpy._src.math.object_transform import naming
 
 __all__ = [
   'check_name_uniqueness',

brainpy/_src/checkpoints/io.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 
 from brainpy import errors
 import brainpy.math as bm
-from brainpy._src.math.object_transform.base import BrainPyObject, DynVarCollector
+from brainpy._src.math.object_transform.base import BrainPyObject, ArrayCollector
 
 
 logger = logging.getLogger('brainpy.brainpy_object.io')
@@ -120,7 +120,7 @@ def _load(
 
 
 def _unique_and_duplicate(collector: dict):
-  gather = DynVarCollector()
+  gather = ArrayCollector()
   id2name = dict()
   duplicates = ([], [])
   for k, v in collector.items():
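`_unique_and_duplicate` above deduplicates a variable collection before saving or loading. For context, a sketch of where such a collection typically comes from (the LIF group is only an example): `.vars()` on any `BrainPyObject` returns an `ArrayCollector`-style mapping of its `Variable`s.

  import brainpy as bp

  group = bp.neurons.LIF(10)
  collected = group.vars().unique()   # ArrayCollector-like dict: name -> Variable
  print(list(collected.keys()))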

brainpy/_src/dyn/base.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 from brainpy._src.connect import TwoEndConnector, MatConn, IJConn, One2One, All2All
 from brainpy._src.initialize import Initializer, parameter, variable, Uniform, noise as init_noise
 from brainpy._src.integrators import odeint, sdeint
-from brainpy._src.math.ndarray import Variable, VariableView
+from brainpy._src.math.object_transform.variables import Variable, VariableView
 from brainpy._src.math.object_transform.base import BrainPyObject, Collector
 from brainpy.errors import NoImplementationError, UnsupportedError
 from brainpy.types import ArrayType, Shape

brainpy/_src/dyn/neurons/biological_models.py

Lines changed: 1 addition & 1 deletion
@@ -317,7 +317,7 @@ def update(self, x=None):
       x = self.input.value
     else:
       x = 0. if x is None else x
-    V, m, h, n = self.integral(self.V, self.m, self.h, self.n, t, x, dt)
+    V, m, h, n = self.integral(self.V.value, self.m.value, self.h.value, self.n.value, t, x, dt)
     self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
     self.V.value = V
     self.m.value = m
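The change above passes `.value` (the stored array) into the integral rather than the `Variable` wrappers themselves. A brief sketch of the distinction (the HH group is only an example):

  import brainpy as bp

  hh = bp.neurons.HH(1)
  raw_v = hh.V.value        # the plain array currently held by the membrane-potential Variable
  hh.V.value = raw_v + 1.   # writing back through .value updates the Variable in place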

brainpy/_src/dyn/runners.py

Lines changed: 8 additions & 0 deletions
@@ -301,6 +301,13 @@ class DSRunner(Runner):
     In order to be compatible with previous API, default is set to be ``False``.
 
     .. versionadded:: 2.3.1
+
+  memory_efficient: bool
+    Whether using the memory-efficient way to just-in-time compile the given target.
+    Default is False.
+
+    .. versionadded:: 2.3.8
+
   """
 
   target: DynamicalSystem
@@ -697,3 +704,4 @@ def __del__(self):
     for key in tuple(self._f_predict_compiled.keys()):
       self._f_predict_compiled.pop(key)
     super(DSRunner, self).__del__()
+
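A hypothetical usage sketch for the new `memory_efficient` flag documented above (the target model and monitored variable are placeholders, not taken from this commit):

  import brainpy as bp

  net = bp.neurons.LIF(100)                   # placeholder target model
  runner = bp.DSRunner(net,
                       monitors=['V'],
                       memory_efficient=True)  # opt into the memory-efficient JIT path
  runner.run(100.)                             # simulate 100 ms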

brainpy/_src/integrators/constants.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 
 # import brainpy.math as bm
-from brainpy._src.tools import naming
+from brainpy._src.math.object_transform import naming
 
 __all__ = [
   'DT',
