@@ -6,12 +6,12 @@
 
 import jax.numpy as jnp
 import numpy as np
-from jax import vmap
+import jax
 from jax.scipy.optimize import minimize
 from jax.tree_util import tree_flatten, tree_map
 
 import brainpy._src.math as bm
-from brainpy import optimizers as optim, losses
+from brainpy import optim, losses
 from brainpy._src.analysis import utils, base, constants
 from brainpy._src.dyn.base import DynamicalSystem
 from brainpy._src.dyn.runners import check_and_format_inputs, _f_ops
@@ -132,11 +132,11 @@ def __init__(
 
     # update function
     if target_vars is None:
-      self.target_vars = bm.DynVarCollector()
+      self.target_vars = bm.ArrayCollector()
     else:
       if not isinstance(target_vars, dict):
         raise TypeError(f'"target_vars" must be a dict but we got {type(target_vars)}')
-      self.target_vars = bm.DynVarCollector(target_vars)
+      self.target_vars = bm.ArrayCollector(target_vars)
     excluded_vars = () if excluded_vars is None else excluded_vars
     if isinstance(excluded_vars, dict):
       excluded_vars = tuple(excluded_vars.values())
@@ -337,7 +337,7 @@ def find_fps_with_gd_method(
     f_eval_loss = self._get_f_eval_loss()
 
     def f_loss():
-      return f_eval_loss(tree_map(lambda a: bm.as_device_array(a),
+      return f_eval_loss(tree_map(lambda a: bm.as_jax(a),
                                   fixed_points,
                                   is_leaf=lambda x: isinstance(x, bm.Array))).mean()
 
@@ -383,10 +383,10 @@ def batch_train(start_i, n_batch):
                f'is below tolerance {tolerance:0.10f}.')
 
     self._opt_losses = jnp.concatenate(opt_losses)
-    self._losses = f_eval_loss(tree_map(lambda a: bm.as_device_array(a),
+    self._losses = f_eval_loss(tree_map(lambda a: bm.as_jax(a),
                                         fixed_points,
                                         is_leaf=lambda x: isinstance(x, bm.Array)))
-    self._fixed_points = tree_map(lambda a: bm.as_device_array(a),
+    self._fixed_points = tree_map(lambda a: bm.as_jax(a),
                                   fixed_points,
                                   is_leaf=lambda x: isinstance(x, bm.Array))
     self._selected_ids = jnp.arange(num_candidate)
@@ -424,9 +424,7 @@ def find_fps_with_opt_solver(
       print(f"Optimizing with {opt_solver} to find fixed points:")
 
     # optimizing
-    res = f_opt(tree_map(lambda a: bm.as_device_array(a),
-                         candidates,
-                         is_leaf=lambda a: isinstance(a, bm.Array)))
+    res = f_opt(tree_map(lambda a: bm.as_jax(a), candidates, is_leaf=lambda a: isinstance(a, bm.Array)))
 
     # results
     valid_ids = jnp.where(res.success)[0]
@@ -666,12 +664,12 @@ def _get_f_eval_loss(self, ):
   def _generate_f_eval_loss(self):
     # evaluate losses of a batch of inputs
     if self.f_type == constants.DISCRETE:
-      f_eval_loss = lambda h: self.f_loss(h, vmap(self.f_cell)(h), axis=1)
+      f_eval_loss = lambda h: self.f_loss(h, jax.vmap(self.f_cell)(h), axis=1)
     else:
-      f_eval_loss = lambda h: self.f_loss(vmap(self.f_cell)(h), axis=1)
+      f_eval_loss = lambda h: self.f_loss(jax.vmap(self.f_cell)(h), axis=1)
 
     if isinstance(self.target, DynamicalSystem):
-      @bm.jit
+      @jax.jit
       def loss_func(h):
         r = f_eval_loss(h)
         for k, v in self.excluded_vars.items():
@@ -682,7 +680,7 @@ def loss_func(h):
 
       return loss_func
     else:
-      return bm.jit(f_eval_loss)
+      return jax.jit(f_eval_loss)
 
   def _get_f_for_opt_solver(self, candidates, opt_method):
     # loss function
@@ -697,17 +695,17 @@ def _get_f_for_opt_solver(self, candidates, opt_method):
 
         def f_loss(h):
           h = {key: h[indices[i]: indices[i + 1]] for i, key in enumerate(keys)}
-          return bm.as_device_array(self.f_loss(h, self.f_cell(h)))
+          return bm.as_jax(self.f_loss(h, self.f_cell(h)))
       else:
         def f_loss(h):
-          return bm.as_device_array(self.f_loss(h, self.f_cell(h)))
+          return bm.as_jax(self.f_loss(h, self.f_cell(h)))
     else:
       # overall loss function for fixed points optimization
       def f_loss(h):
         return self.f_loss(self.f_cell(h))
 
-    @bm.jit
-    @vmap
+    @jax.jit
+    @jax.vmap
     def f_opt(x0):
       for k, v in self.target_vars.items():
         v.value = x0[k] if (v.batch_axis is None) else jnp.expand_dims(x0[k], axis=v.batch_axis)
@@ -785,7 +783,7 @@ def jacob(x0):
     else:
       jacob = self.f_cell
 
-    f_jac = bm.jit(vmap(bm.jacobian(jacob)))
+    f_jac = jax.jit(jax.vmap(bm.jacobian(jacob)))
 
     if isinstance(self.target, DynamicalSystem):
       def jacobian_func(x):