Skip to content

Commit b9da040

Browse files
committed
changes: remove brainpy.math.vmap
1 parent 079eb8f commit b9da040

File tree

11 files changed

+63
-55
lines changed

11 files changed

+63
-55
lines changed

brainpy/analysis/highdim/slow_points.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import warnings
55
from functools import partial
66

7+
from jax import vmap
78
import jax.numpy
89
import numpy as np
910
from jax.scipy.optimize import minimize
@@ -56,15 +57,15 @@ def __init__(self, f_cell, f_type='continuous', f_loss_batch=None, verbose=True)
5657
if f_loss_batch is None:
5758
if f_type == 'discrete':
5859
self.f_loss = bm.jit(lambda h: bm.mean((h - f_cell(h)) ** 2))
59-
self.f_loss_batch = bm.jit(lambda h: bm.mean((h - bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1))
60+
self.f_loss_batch = bm.jit(lambda h: bm.mean((h - vmap(f_cell)(h)) ** 2, axis=1))
6061
if f_type == 'continuous':
6162
self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
62-
self.f_loss_batch = bm.jit(lambda h: bm.mean((bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1))
63+
self.f_loss_batch = bm.jit(lambda h: bm.mean((vmap(f_cell)(h)) ** 2, axis=1))
6364

6465
else:
6566
self.f_loss_batch = f_loss_batch
6667
self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
67-
self.f_jacob_batch = bm.jit(bm.vmap(bm.jacobian(f_cell)))
68+
self.f_jacob_batch = bm.jit(vmap(bm.jacobian(f_cell)))
6869

6970
# essential variables
7071
self._losses = None
@@ -208,7 +209,7 @@ def find_fps_with_opt_solver(self, candidates, opt_method=None):
208209
opt_method = lambda f, x0: minimize(f, x0, method='BFGS')
209210
if self.verbose:
210211
print(f"Optimizing to find fixed points:")
211-
f_opt = bm.jit(bm.vmap(lambda x0: opt_method(self.f_loss, x0)))
212+
f_opt = bm.jit(vmap(lambda x0: opt_method(self.f_loss, x0)))
212213
res = f_opt(bm.as_device_array(candidates))
213214
valid_ids = jax.numpy.where(res.success)[0]
214215
self._fixed_points = np.asarray(res.x[valid_ids])

brainpy/analysis/lowdim/lowdim_analyzer.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from functools import partial
44

55
import numpy as np
6+
from jax import vmap
67
from jax import numpy as jnp
78
from jax.scipy.optimize import minimize
89

@@ -262,7 +263,7 @@ def F_fx(self):
262263
@property
263264
def F_vmap_fx(self):
264265
if C.F_vmap_fx not in self.analyzed_results:
265-
self.analyzed_results[C.F_vmap_fx] = bm.jit(bm.vmap(self.F_fx), device=self.jit_device)
266+
self.analyzed_results[C.F_vmap_fx] = bm.jit(vmap(self.F_fx), device=self.jit_device)
266267
return self.analyzed_results[C.F_vmap_fx]
267268

268269
@property
@@ -289,7 +290,7 @@ def F_vmap_fp_aux(self):
289290
# ---
290291
# "X": a two-dimensional matrix: (num_batch, num_var)
291292
# "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
292-
self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(bm.vmap(self.F_fixed_point_aux))
293+
self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(vmap(self.F_fixed_point_aux))
293294
return self.analyzed_results[C.F_vmap_fp_aux]
294295

295296
@property
@@ -308,7 +309,7 @@ def F_vmap_fp_opt(self):
308309
# ---
309310
# "X": a two-dimensional matrix: (num_batch, num_var)
310311
# "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
311-
self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(bm.vmap(self.F_fixed_point_opt))
312+
self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(vmap(self.F_fixed_point_opt))
312313
return self.analyzed_results[C.F_vmap_fp_opt]
313314

314315
def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_screen=None):
@@ -501,7 +502,7 @@ def F_y_by_x_in_fy(self):
501502
@property
502503
def F_vmap_fy(self):
503504
if C.F_vmap_fy not in self.analyzed_results:
504-
self.analyzed_results[C.F_vmap_fy] = bm.jit(bm.vmap(self.F_fy), device=self.jit_device)
505+
self.analyzed_results[C.F_vmap_fy] = bm.jit(vmap(self.F_fy), device=self.jit_device)
505506
return self.analyzed_results[C.F_vmap_fy]
506507

507508
@property
@@ -663,7 +664,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
663664

664665
if self.F_x_by_y_in_fx is not None:
665666
utils.output("I am evaluating fx-nullcline by F_x_by_y_in_fx ...")
666-
vmap_f = bm.jit(bm.vmap(self.F_x_by_y_in_fx), device=self.jit_device)
667+
vmap_f = bm.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device)
667668
for j, pars in enumerate(par_seg):
668669
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
669670
mesh_values = jnp.meshgrid(*((ys,) + pars))
@@ -679,7 +680,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
679680

680681
elif self.F_y_by_x_in_fx is not None:
681682
utils.output("I am evaluating fx-nullcline by F_y_by_x_in_fx ...")
682-
vmap_f = bm.jit(bm.vmap(self.F_y_by_x_in_fx), device=self.jit_device)
683+
vmap_f = bm.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device)
683684
for j, pars in enumerate(par_seg):
684685
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
685686
mesh_values = jnp.meshgrid(*((xs,) + pars))
@@ -697,9 +698,9 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
697698
utils.output("I am evaluating fx-nullcline by optimization ...")
698699
# auxiliary functions
699700
f2 = lambda y, x, *pars: self.F_fx(x, y, *pars)
700-
vmap_f2 = bm.jit(bm.vmap(f2), device=self.jit_device)
701-
vmap_brentq_f2 = bm.jit(bm.vmap(utils.jax_brentq(f2)), device=self.jit_device)
702-
vmap_brentq_f1 = bm.jit(bm.vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)
701+
vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
702+
vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
703+
vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)
703704

704705
# num segments
705706
for _j, Ps in enumerate(par_seg):
@@ -756,7 +757,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
756757

757758
if self.F_x_by_y_in_fy is not None:
758759
utils.output("I am evaluating fy-nullcline by F_x_by_y_in_fy ...")
759-
vmap_f = bm.jit(bm.vmap(self.F_x_by_y_in_fy), device=self.jit_device)
760+
vmap_f = bm.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device)
760761
for j, pars in enumerate(par_seg):
761762
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
762763
mesh_values = jnp.meshgrid(*((ys,) + pars))
@@ -772,7 +773,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
772773

773774
elif self.F_y_by_x_in_fy is not None:
774775
utils.output("I am evaluating fy-nullcline by F_y_by_x_in_fy ...")
775-
vmap_f = bm.jit(bm.vmap(self.F_y_by_x_in_fy), device=self.jit_device)
776+
vmap_f = bm.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device)
776777
for j, pars in enumerate(par_seg):
777778
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
778779
mesh_values = jnp.meshgrid(*((xs,) + pars))
@@ -791,9 +792,9 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
791792

792793
# auxiliary functions
793794
f2 = lambda y, x, *pars: self.F_fy(x, y, *pars)
794-
vmap_f2 = bm.jit(bm.vmap(f2), device=self.jit_device)
795-
vmap_brentq_f2 = bm.jit(bm.vmap(utils.jax_brentq(f2)), device=self.jit_device)
796-
vmap_brentq_f1 = bm.jit(bm.vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)
795+
vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
796+
vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
797+
vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)
797798

798799
for j, Ps in enumerate(par_seg):
799800
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
@@ -841,7 +842,7 @@ def _get_fp_candidates_by_aux_rank(self, num_segments=1, num_rank=100):
841842
xs = self.resolutions[self.x_var].value
842843
ys = self.resolutions[self.y_var].value
843844
P = tuple(self.resolutions[p].value for p in self.target_par_names)
844-
f_select = bm.jit(bm.vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))
845+
f_select = bm.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))
845846

846847
# num seguments
847848
if isinstance(num_segments, int):
@@ -921,10 +922,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
921922

922923
if self.convert_type() == C.x_by_y:
923924
num_seg = len(self.resolutions[self.y_var])
924-
f_vmap = bm.jit(bm.vmap(self.F_y_convert[1]))
925+
f_vmap = bm.jit(vmap(self.F_y_convert[1]))
925926
else:
926927
num_seg = len(self.resolutions[self.x_var])
927-
f_vmap = bm.jit(bm.vmap(self.F_x_convert[1]))
928+
f_vmap = bm.jit(vmap(self.F_x_convert[1]))
928929
# get the signs
929930
signs = jnp.sign(f_vmap(candidates, *args))
930931
signs = signs.reshape((num_seg, -1))
@@ -954,10 +955,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
954955
# get another value
955956
if self.convert_type() == C.x_by_y:
956957
y_values = fps
957-
x_values = bm.jit(bm.vmap(self.F_y_convert[0]))(y_values, *args)
958+
x_values = bm.jit(vmap(self.F_y_convert[0]))(y_values, *args)
958959
else:
959960
x_values = fps
960-
y_values = bm.jit(bm.vmap(self.F_x_convert[0]))(x_values, *args)
961+
y_values = bm.jit(vmap(self.F_x_convert[0]))(x_values, *args)
961962
fps = jnp.stack([x_values, y_values]).T
962963
return fps, selected_ids, args
963964

brainpy/analysis/lowdim/lowdim_bifurcation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from functools import partial
44

55
import jax.numpy as jnp
6+
from jax import vmap
67
import numpy as np
78

89
import brainpy.math as bm
@@ -42,7 +43,7 @@ def __init__(self, model, target_pars, target_vars, fixed_vars=None,
4243
@property
4344
def F_vmap_dfxdx(self):
4445
if C.F_vmap_dfxdx not in self.analyzed_results:
45-
f = bm.jit(bm.vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device)
46+
f = bm.jit(vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device)
4647
self.analyzed_results[C.F_vmap_dfxdx] = f
4748
return self.analyzed_results[C.F_vmap_dfxdx]
4849

@@ -159,7 +160,7 @@ def F_vmap_jacobian(self):
159160
if C.F_vmap_jacobian not in self.analyzed_results:
160161
f1 = lambda xy, *args: jnp.array([self.F_fx(xy[0], xy[1], *args),
161162
self.F_fy(xy[0], xy[1], *args)])
162-
f2 = bm.jit(bm.vmap(bm.jacobian(f1)), device=self.jit_device)
163+
f2 = bm.jit(vmap(bm.jacobian(f1)), device=self.jit_device)
163164
self.analyzed_results[C.F_vmap_jacobian] = f2
164165
return self.analyzed_results[C.F_vmap_jacobian]
165166

brainpy/analysis/lowdim/lowdim_phase_plane.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import jax.numpy as jnp
44
import numpy as np
5+
from jax import vmap
56

67
import brainpy.math as bm
78
from brainpy import errors, math
@@ -158,7 +159,7 @@ def __init__(self,
158159
@property
159160
def F_vmap_brentq_fy(self):
160161
if C.F_vmap_brentq_fy not in self.analyzed_results:
161-
f_opt = bm.jit(bm.vmap(utils.jax_brentq(self.F_fy)))
162+
f_opt = bm.jit(vmap(utils.jax_brentq(self.F_fy)))
162163
self.analyzed_results[C.F_vmap_brentq_fy] = f_opt
163164
return self.analyzed_results[C.F_vmap_brentq_fy]
164165

brainpy/analysis/utils/optimization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import jax.lax
55
import jax.numpy as jnp
66
import numpy as np
7-
from jax import grad, jit
7+
from jax import grad, jit, vmap
88
from jax.flatten_util import ravel_pytree
99

1010
import brainpy.math as bm
@@ -197,7 +197,7 @@ def brentq_candidates(vmap_f, *values, args=()):
197197

198198
def brentq_roots(f, starts, ends, *vmap_args, args=()):
199199
in_axes = (0, 0, tuple([0] * len(vmap_args)) + tuple([None] * len(args)))
200-
vmap_f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=in_axes))
200+
vmap_f_opt = bm.jit(vmap(jax_brentq(f), in_axes=in_axes))
201201
all_args = vmap_args + args
202202
if len(all_args):
203203
res = vmap_f_opt(starts, ends, all_args)
@@ -397,7 +397,7 @@ def roots_of_1d_by_x(f, candidates, args=()):
397397
return fps
398398
starts = candidates[candidate_ids]
399399
ends = candidates[candidate_ids + 1]
400-
f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=(0, 0, None)))
400+
f_opt = bm.jit(vmap(jax_brentq(f), in_axes=(0, 0, None)))
401401
res = f_opt(starts, ends, args)
402402
valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
403403
fps2 = res['root'][valid_idx]
@@ -406,7 +406,7 @@ def roots_of_1d_by_x(f, candidates, args=()):
406406

407407
def roots_of_1d_by_xy(f, starts, ends, args):
408408
f = f_without_jaxarray_return(f)
409-
f_opt = bm.jit(bm.vmap(jax_brentq(f)))
409+
f_opt = bm.jit(vmap(jax_brentq(f)))
410410
res = f_opt(starts, ends, (args,))
411411
valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
412412
xs = res['root'][valid_idx]

brainpy/analysis/utils/others.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22

33
import jax.numpy as jnp
4+
from jax import vmap
45
import numpy as np
56

67
import brainpy.math as bm
@@ -76,7 +77,7 @@ def get_sign(f, xs, ys):
7677

7778
def get_sign2(f, *xyz, args=()):
7879
in_axes = tuple(range(len(xyz))) + tuple([None] * len(args))
79-
f = bm.jit(bm.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
80+
f = bm.jit(vmap(f_without_jaxarray_return(f), in_axes=in_axes))
8081
xyz = tuple((v.value if isinstance(v, bm.JaxArray) else v) for v in xyz)
8182
XYZ = jnp.meshgrid(*xyz)
8283
XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ)

brainpy/dyn/synapses/delay_coupling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def update(self, _t, _dt):
193193
variable = getattr(self.pre, var)
194194

195195
# delay function
196-
f = bm.vmap(lambda i: delay_var(self.delay_mat[i], bm.arange(self.pre.num))) # (pre.num,)
196+
f = vmap(lambda i: delay_var(self.delay_mat[i], bm.arange(self.pre.num))) # (pre.num,)
197197
delays = f(bm.arange(self.post.num)) # (post.num, pre.num)
198198
additive = (self.conn_mat * delays).sum(axis=1)
199199

brainpy/errors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def __init__(self, variables=None):
101101
else:
102102
raise ValueError
103103

104-
msg += 'While there are changed variables which are not wrapped into "dyn_vars". Please check!'
104+
# msg += 'While there are changed variables which are not wrapped into "dyn_vars". Please check!'
105+
msg = 'While there are changed variables which are not wrapped into "dyn_vars". Please check!'
105106

106107
super(JaxTracerError, self).__init__(msg)
107108

brainpy/math/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from .autograd import *
4747
from .controls import *
4848
from .jit import *
49-
from .parallels import *
49+
# from .parallels import *
5050

5151
# settings
5252
from . import setting

brainpy/math/parallels.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,29 +36,31 @@
3636
]
3737

3838

39-
def _make_vmap(func, dyn_vars, rand_vars, in_axes, out_axes,
40-
batch_idx, axis_name, reduce_func, f_name=None):
39+
def _make_vmap(func, nonbatched_vars, batched_vars, in_axes, out_axes,
40+
batch_idx, axis_name, f_name=None):
4141
@functools.partial(jax.vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name)
42-
def vmapped_func(dyn_data, rand_data, *args, **kwargs):
43-
dyn_vars.assign(dyn_data)
44-
rand_vars.assign(rand_data)
42+
def vmapped_func(nonbatched_data, batched_data, *args, **kwargs):
43+
nonbatched_vars.assign(nonbatched_data)
44+
batched_vars.assign(batched_data)
4545
out = func(*args, **kwargs)
46-
dyn_changes = dyn_vars.dict()
47-
rand_changes = rand_vars.dict()
48-
return out, dyn_changes, rand_changes
46+
nonbatched_changes = nonbatched_vars.dict()
47+
batched_changes = batched_vars.dict()
48+
return nonbatched_changes, batched_changes, out
4949

5050
def call(*args, **kwargs):
51-
dyn_data = dyn_vars.dict()
5251
n = args[batch_idx[0]].shape[batch_idx[1]]
53-
rand_data = {key: val.split_keys(n) for key, val in rand_vars.items()}
52+
nonbatched_data = nonbatched_vars.dict()
53+
batched_data = {key: val.split_keys(n) for key, val in batched_vars.items()}
5454
try:
55-
out, dyn_changes, rand_changes = vmapped_func(dyn_data, rand_data, *args, **kwargs)
55+
out, dyn_changes, rand_changes = vmapped_func(nonbatched_data, batched_data, *args, **kwargs)
5656
except UnexpectedTracerError as e:
57-
dyn_vars.assign(dyn_data)
58-
rand_vars.assign(rand_data)
59-
raise errors.JaxTracerError(variables=dyn_vars) from e
60-
for key, v in dyn_changes.items(): dyn_vars[key] = reduce_func(v)
61-
for key, v in rand_changes.items(): rand_vars[key] = reduce_func(v)
57+
nonbatched_vars.assign(nonbatched_data)
58+
batched_vars.assign(batched_data)
59+
raise errors.JaxTracerError() from e
60+
# for key, v in dyn_changes.items():
61+
# dyn_vars[key] = reduce_func(v)
62+
# for key, v in rand_changes.items():
63+
# rand_vars[key] = reduce_func(v)
6264
return out
6365

6466
return change_func_name(name=f_name, f=call) if f_name else call
@@ -256,13 +258,12 @@ def vmap(func, dyn_vars=None, batched_vars=None,
256258

257259
# jit function
258260
return _make_vmap(func=func,
259-
dyn_vars=_dyn_vars,
260-
rand_vars=_rand_vars,
261+
nonbatched_vars=_dyn_vars,
262+
batched_vars=_rand_vars,
261263
in_axes=in_axes,
262264
out_axes=out_axes,
263265
axis_name=axis_name,
264-
batch_idx=batch_idx,
265-
reduce_func=reduce_func)
266+
batch_idx=batch_idx)
266267

267268
else:
268269
raise errors.BrainPyError(f'Only support instance of {Base.__name__}, or a callable '

0 commit comments

Comments
 (0)