
Commit 3a41997

Improve support for jax.example_libraries.optimizers (#350)
2 parents 3089166 + d1de535 commit 3a41997

File tree: 2 files changed (+57 −24 lines)


pysages/ml/optimizers.py

Lines changed: 27 additions & 23 deletions
@@ -22,23 +22,18 @@
 )
 from pysages.ml.utils import dispatch, pack, unpack
 from pysages.typing import Any, Callable, JaxArray, NamedTuple, Tuple, Union
-from pysages.utils import solve_pos_def, try_import
-
-jopt = try_import("jax.example_libraries.optimizers", "jax.experimental.optimizers")
-
+from pysages.utils import solve_pos_def

 # Optimizers parameters


-class AdamParams(NamedTuple):
+class JaxOptimizerParams(NamedTuple):
     """
-    Parameters for the ADAM optimizer.
+    Parameters for the jax.example_libraries optimizers.
     """

-    step_size: Union[float, Callable] = 1e-2
-    beta_1: float = 0.9
-    beta_2: float = 0.999
-    tol: float = 1e-8
+    step_size: Union[float, Callable] = 1e-3
+    kwargs: dict = {}


 class LevenbergMarquardtParams(NamedTuple):
@@ -64,6 +59,7 @@ class WrappedState(NamedTuple):
     """

     data: Tuple[JaxArray, JaxArray]
+    state: Any
     params: Any
     iters: int = 0
     improved: bool = True
@@ -105,17 +101,21 @@ class Optimizer:


 @dataclass
-class Adam(Optimizer):
+class JaxOptimizer(Optimizer):
     """
-    ADAM optimizer from stax.example_libraries.optimizers.
+    Setup class for stax.example_libraries.optimizers.
     """

-    params: AdamParams = AdamParams()
+    constructor: Callable
+    params: JaxOptimizerParams = JaxOptimizerParams()
     loss: Loss = SSE()
     reg: Regularizer = L2Regularization(0.0)
-    tol: float = 1e-4
+    tol: float = 1e-5
     max_iters: int = 10000

+    def __call__(self):
+        return self.constructor(self.params.step_size, **self.params.kwargs)
+

 @dataclass
 class LevenbergMarquardt(Optimizer):
@@ -155,27 +155,31 @@ def build(optimizer, model): # pylint: disable=W0613


 @dispatch
-def build(optimizer: Adam, model):
+def build(optimizer: JaxOptimizer, model):
     # pylint: disable=C0116,E0102
-    _init, _update, repack = jopt.adam(*optimizer.params)
+    _init, _update, get_params = optimizer()
     objective = build_objective_function(model, optimizer.loss, optimizer.reg)
     gradient = jax.grad(objective)
     max_iters = optimizer.max_iters
     _, layout = unpack(model.parameters)

+    def flatten(params):
+        return unpack(params)[0]
+
     def initialize(params, x, y):
-        wrapped_params = _init(pack(params, layout))
-        return WrappedState((x, y), wrapped_params)
+        state = _init(pack(params, layout))
+        return WrappedState((x, y), state, flatten(get_params(state)))

     def keep_iterating(state):
         return state.improved & (state.iters < max_iters)

     def update(state):
-        data, params, iters, _ = state
-        dp = gradient(repack(params), *data)
-        params = _update(iters, dp, params)
-        improved = sum_squares(unpack(dp)[0]) > optimizer.tol
-        return WrappedState(data, params, iters + 1, improved)
+        data, opt_state, _, iters, _ = state
+        dp = gradient(get_params(opt_state), *data)
+        opt_state = _update(iters, dp, opt_state)
+        new_params = get_params(opt_state)
+        improved = sum_squares(flatten(dp)) > optimizer.tol
+        return WrappedState(data, opt_state, flatten(new_params), iters + 1, improved)

     return initialize, keep_iterating, update
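
For reference, every optimizer constructor in jax.example_libraries.optimizers (adam, sgd, rmsprop, ...) returns an (init_fun, update_fun, get_params) triple; that triple is what JaxOptimizer.__call__ produces above and what the new build dispatch unpacks. A minimal standalone sketch of that pattern follows (the toy objective, data, step size, and iteration count are illustrative, not taken from this commit):

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers as jopt


def loss(w, x, y):
    # Toy objective: least squares for y = w * x with a single weight.
    return jnp.sum((w * x - y) ** 2)


# Each constructor returns (init_fun, update_fun, get_params).
init_fun, update_fun, get_params = jopt.adam(step_size=1e-1)

x = jnp.linspace(-1.0, 1.0, 32)
y = 2.0 * x

state = init_fun(jnp.array(0.0))
for i in range(500):
    grads = jax.grad(loss)(get_params(state), x, y)
    state = update_fun(i, grads, state)

print(get_params(state))  # converges toward the true weight, 2.0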

tests/test_ml.py

Lines changed: 30 additions & 1 deletion
@@ -10,9 +10,12 @@
 from pysages.grids import Chebyshev, Grid
 from pysages.ml.models import MLP, Siren
 from pysages.ml.objectives import L2Regularization, Sobolev1SSE
-from pysages.ml.optimizers import LevenbergMarquardt
+from pysages.ml.optimizers import JaxOptimizer, LevenbergMarquardt
 from pysages.ml.training import build_fitting_function
 from pysages.ml.utils import pack, unpack
+from pysages.utils import try_import
+
+jopt = try_import("jax.example_libraries.optimizers", "jax.experimental.optimizers")


 # Test functions
@@ -95,3 +98,29 @@ def test_mlp_training():
     ax.plot(x_plot, model.apply(pack(params, layout), x_plot), linestyle="dashed")
     fig.savefig("y_mlp_fit.pdf")
     plt.close(fig)
+
+
+def test_adam_optimizer():
+    grid = Grid[Chebyshev](lower=(-1.0,), upper=(1.0,), shape=(128,))
+
+    x = compute_mesh(grid)
+
+    y = vmap(g)(x.flatten()).reshape(x.shape)
+
+    topology = (4, 4)
+    model = MLP(1, 1, topology)
+    optimizer = JaxOptimizer(jopt.adam, tol=1e-6)
+    fit = build_fitting_function(model, optimizer)
+
+    params, layout = unpack(model.parameters)
+    params = fit(params, x, y).params
+    y_model = model.apply(pack(params, layout), x)
+
+    assert np.linalg.norm(y - y_model).item() / x.size < 1e-2
+
+    x_plot = np.linspace(-1, 1, 512)
+    fig, ax = plt.subplots()
+    ax.plot(x_plot, vmap(g)(x_plot))
+    ax.plot(x_plot, model.apply(pack(params, layout), x_plot), linestyle="dashed")
+    fig.savefig("y_mlp_adam_fit.pdf")
+    plt.close(fig)
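
Beyond the adam case exercised by the test, the new JaxOptimizer accepts any constructor from jax.example_libraries.optimizers, with extra constructor arguments forwarded through the kwargs field of JaxOptimizerParams. A hedged usage sketch (field order follows the dataclass definition in the diff above; the particular optimizers and keyword values are illustrative, not part of the commit):

from pysages.ml.optimizers import JaxOptimizer, JaxOptimizerParams
from pysages.utils import try_import

jopt = try_import("jax.example_libraries.optimizers", "jax.experimental.optimizers")

# adam with non-default decay rates, forwarded via JaxOptimizerParams.kwargs
adam_opt = JaxOptimizer(
    jopt.adam, JaxOptimizerParams(step_size=1e-3, kwargs={"b1": 0.95, "b2": 0.999})
)

# plain SGD, keeping the default JaxOptimizerParams and tightening the tolerance
sgd_opt = JaxOptimizer(jopt.sgd, tol=1e-6)

Either object can then be passed to build_fitting_function exactly as in the test above.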
