 )
 from pysages.ml.utils import dispatch, pack, unpack
 from pysages.typing import Any, Callable, JaxArray, NamedTuple, Tuple, Union
-from pysages.utils import solve_pos_def, try_import
-
-jopt = try_import("jax.example_libraries.optimizers", "jax.experimental.optimizers")
-
+from pysages.utils import solve_pos_def
 
 # Optimizers parameters
 
 
-class AdamParams(NamedTuple):
+class JaxOptimizerParams(NamedTuple):
     """
-    Parameters for the ADAM optimizer.
+    Parameters for the jax.example_libraries optimizers.
     """
 
-    step_size: Union[float, Callable] = 1e-2
-    beta_1: float = 0.9
-    beta_2: float = 0.999
-    tol: float = 1e-8
+    step_size: Union[float, Callable] = 1e-3
+    kwargs: dict = {}
 
 
 class LevenbergMarquardtParams(NamedTuple):
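For orientation, the generic JaxOptimizerParams above replaces the ADAM-specific fields: optimizer-specific options now travel through the kwargs dict. A minimal sketch, assuming the class is importable from pysages.ml.optimizers; b1, b2, and eps are adam's keyword arguments in jax.example_libraries.optimizers, mirroring how the removed AdamParams used to be unpacked positionally into jopt.adam:

from pysages.ml.optimizers import JaxOptimizerParams  # assumed import path

# Defaults roughly equivalent to the removed AdamParams, expressed in the new form:
# the old beta_1/beta_2/tol fields become adam's b1/b2/eps keyword arguments.
adam_like = JaxOptimizerParams(step_size=1e-2, kwargs={"b1": 0.9, "b2": 0.999, "eps": 1e-8})

# Constructors that only need a step size require no extra keyword arguments.
plain_sgd = JaxOptimizerParams(step_size=1e-2)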
@@ -64,6 +59,7 @@ class WrappedState(NamedTuple):
6459 """
6560
6661 data : Tuple [JaxArray , JaxArray ]
62+ state : Any
6763 params : Any
6864 iters : int = 0
6965 improved : bool = True
@@ -105,17 +101,21 @@ class Optimizer:
 
 
 @dataclass
-class Adam(Optimizer):
+class JaxOptimizer(Optimizer):
     """
-    ADAM optimizer from stax.example_libraries.optimizers.
+    Setup class for stax.example_libraries.optimizers.
     """
 
-    params: AdamParams = AdamParams()
+    constructor: Callable
+    params: JaxOptimizerParams = JaxOptimizerParams()
     loss: Loss = SSE()
     reg: Regularizer = L2Regularization(0.0)
-    tol: float = 1e-4
+    tol: float = 1e-5
     max_iters: int = 10000
 
+    def __call__(self):
+        return self.constructor(self.params.step_size, **self.params.kwargs)
+
 
 @dataclass
 class LevenbergMarquardt(Optimizer):
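A short usage sketch for the new JaxOptimizer wrapper, again assuming the import path; any constructor from jax.example_libraries.optimizers that returns an (init, update, get_params) triple can be passed in:

from jax.example_libraries import optimizers as jopt

from pysages.ml.optimizers import JaxOptimizer, JaxOptimizerParams  # assumed import path

# Wrap adam; __call__ forwards step_size and kwargs to the chosen constructor.
optimizer = JaxOptimizer(jopt.adam, JaxOptimizerParams(step_size=1e-3))
init_fn, update_fn, get_params = optimizer()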
@@ -155,27 +155,31 @@ def build(optimizer, model): # pylint: disable=W0613
 
 
 @dispatch
-def build(optimizer: Adam, model):
+def build(optimizer: JaxOptimizer, model):
     # pylint: disable=C0116,E0102
-    _init, _update, repack = jopt.adam(*optimizer.params)
+    _init, _update, get_params = optimizer()
     objective = build_objective_function(model, optimizer.loss, optimizer.reg)
     gradient = jax.grad(objective)
     max_iters = optimizer.max_iters
     _, layout = unpack(model.parameters)
 
+    def flatten(params):
+        return unpack(params)[0]
+
     def initialize(params, x, y):
-        wrapped_params = _init(pack(params, layout))
-        return WrappedState((x, y), wrapped_params)
+        state = _init(pack(params, layout))
+        return WrappedState((x, y), state, flatten(get_params(state)))
 
     def keep_iterating(state):
         return state.improved & (state.iters < max_iters)
 
     def update(state):
-        data, params, iters, _ = state
-        dp = gradient(repack(params), *data)
-        params = _update(iters, dp, params)
-        improved = sum_squares(unpack(dp)[0]) > optimizer.tol
-        return WrappedState(data, params, iters + 1, improved)
+        data, opt_state, _, iters, _ = state
+        dp = gradient(get_params(opt_state), *data)
+        opt_state = _update(iters, dp, opt_state)
+        new_params = get_params(opt_state)
+        improved = sum_squares(flatten(dp)) > optimizer.tol
+        return WrappedState(data, opt_state, flatten(new_params), iters + 1, improved)
 
     return initialize, keep_iterating, update
 
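For reference, a sketch of how the (initialize, keep_iterating, update) triple returned by this build method can be driven. The jax.lax.while_loop driver and the fit helper below are illustrative assumptions, not part of this change; unpack comes from pysages.ml.utils as imported above.

from jax import lax

from pysages.ml.optimizers import build  # assumed import path
from pysages.ml.utils import unpack


def fit(optimizer, model, x, y):
    # Build the three pure functions for this optimizer/model pair.
    initialize, keep_iterating, update = build(optimizer, model)
    # `initialize` expects the flattened parameter vector; it re-packs it internally.
    flat_params = unpack(model.parameters)[0]
    state = initialize(flat_params, x, y)
    # Step until the squared gradient norm drops below `tol` or `max_iters` is reached.
    state = lax.while_loop(keep_iterating, update, state)
    return state.params  # flattened, optimized parameters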