50 changes: 27 additions & 23 deletions pysages/ml/optimizers.py
@@ -22,23 +22,18 @@
 )
 from pysages.ml.utils import dispatch, pack, unpack
 from pysages.typing import Any, Callable, JaxArray, NamedTuple, Tuple, Union
-from pysages.utils import solve_pos_def
+from pysages.utils import solve_pos_def, try_import
+
+jopt = try_import("jax.example_libraries.optimizers", "jax.experimental.optimizers")
 
 # Optimizers parameters
 
 
-class AdamParams(NamedTuple):
+class JaxOptimizerParams(NamedTuple):
     """
-    Parameters for the ADAM optimizer.
+    Parameters for the jax.example_libraries optimizers.
     """
 
-    step_size: Union[float, Callable] = 1e-2
-    beta_1: float = 0.9
-    beta_2: float = 0.999
-    tol: float = 1e-8
+    step_size: Union[float, Callable] = 1e-3
+    kwargs: dict = {}
 
 
 class LevenbergMarquardtParams(NamedTuple):
@@ -64,6 +59,7 @@ class WrappedState(NamedTuple):
     """
 
     data: Tuple[JaxArray, JaxArray]
+    state: Any
     params: Any
     iters: int = 0
     improved: bool = True
@@ -105,17 +101,21 @@ class Optimizer:
 
 
 @dataclass
-class Adam(Optimizer):
+class JaxOptimizer(Optimizer):
     """
-    ADAM optimizer from stax.example_libraries.optimizers.
+    Setup class for jax.example_libraries.optimizers.
     """
 
-    params: AdamParams = AdamParams()
+    constructor: Callable
+    params: JaxOptimizerParams = JaxOptimizerParams()
     loss: Loss = SSE()
     reg: Regularizer = L2Regularization(0.0)
-    tol: float = 1e-4
+    tol: float = 1e-5
     max_iters: int = 10000
 
+    def __call__(self):
+        return self.constructor(self.params.step_size, **self.params.kwargs)
+
 
 @dataclass
 class LevenbergMarquardt(Optimizer):
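
With `Adam` generalized to `JaxOptimizer`, the optimizer-specific defaults that used to live in `AdamParams` (`beta_1`, `beta_2`, and the `tol` that was forwarded as Adam's epsilon) now travel through `JaxOptimizerParams.kwargs` as keyword arguments of the chosen constructor. A minimal usage sketch of the new wrapper; the keyword names `b1`, `b2`, and `eps` are the ones accepted by `jax.example_libraries.optimizers.adam`, and the values are illustrative only:

    from pysages.ml.optimizers import JaxOptimizer, JaxOptimizerParams
    from pysages.utils import try_import

    jopt = try_import("jax.example_libraries.optimizers", "jax.experimental.optimizers")

    # Recover the former Adam defaults through the generic wrapper.
    params = JaxOptimizerParams(step_size=1e-2, kwargs={"b1": 0.9, "b2": 0.999, "eps": 1e-8})
    optimizer = JaxOptimizer(jopt.adam, params)

    # Calling the wrapper returns jax's (init_fun, update_fun, get_params) triple.
    init_fun, update_fun, get_params = optimizer()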
@@ -155,27 +155,31 @@ def build(optimizer, model):  # pylint: disable=W0613
 
 
 @dispatch
-def build(optimizer: Adam, model):
+def build(optimizer: JaxOptimizer, model):
     # pylint: disable=C0116,E0102
-    _init, _update, repack = jopt.adam(*optimizer.params)
+    _init, _update, get_params = optimizer()
     objective = build_objective_function(model, optimizer.loss, optimizer.reg)
     gradient = jax.grad(objective)
     max_iters = optimizer.max_iters
     _, layout = unpack(model.parameters)
 
+    def flatten(params):
+        return unpack(params)[0]
+
     def initialize(params, x, y):
-        wrapped_params = _init(pack(params, layout))
-        return WrappedState((x, y), wrapped_params)
+        state = _init(pack(params, layout))
+        return WrappedState((x, y), state, flatten(get_params(state)))
 
     def keep_iterating(state):
         return state.improved & (state.iters < max_iters)
 
     def update(state):
-        data, params, iters, _ = state
-        dp = gradient(repack(params), *data)
-        params = _update(iters, dp, params)
-        improved = sum_squares(unpack(dp)[0]) > optimizer.tol
-        return WrappedState(data, params, iters + 1, improved)
+        data, opt_state, _, iters, _ = state
+        dp = gradient(get_params(opt_state), *data)
+        opt_state = _update(iters, dp, opt_state)
+        new_params = get_params(opt_state)
+        improved = sum_squares(flatten(dp)) > optimizer.tol
+        return WrappedState(data, opt_state, flatten(new_params), iters + 1, improved)
 
     return initialize, keep_iterating, update
31 changes: 30 additions & 1 deletion tests/test_ml.py
@@ -10,9 +10,12 @@
 from pysages.grids import Chebyshev, Grid
 from pysages.ml.models import MLP, Siren
 from pysages.ml.objectives import L2Regularization, Sobolev1SSE
-from pysages.ml.optimizers import LevenbergMarquardt
+from pysages.ml.optimizers import JaxOptimizer, LevenbergMarquardt
 from pysages.ml.training import build_fitting_function
 from pysages.ml.utils import pack, unpack
+from pysages.utils import try_import
+
+jopt = try_import("jax.example_libraries.optimizers", "jax.experimental.optimizers")
 
 
 # Test functions
@@ -95,3 +98,29 @@ def test_mlp_training():
     ax.plot(x_plot, model.apply(pack(params, layout), x_plot), linestyle="dashed")
     fig.savefig("y_mlp_fit.pdf")
     plt.close(fig)
+
+
+def test_adam_optimizer():
+    grid = Grid[Chebyshev](lower=(-1.0,), upper=(1.0,), shape=(128,))
+
+    x = compute_mesh(grid)
+
+    y = vmap(g)(x.flatten()).reshape(x.shape)
+
+    topology = (4, 4)
+    model = MLP(1, 1, topology)
+    optimizer = JaxOptimizer(jopt.adam, tol=1e-6)
+    fit = build_fitting_function(model, optimizer)
+
+    params, layout = unpack(model.parameters)
+    params = fit(params, x, y).params
+    y_model = model.apply(pack(params, layout), x)
+
+    assert np.linalg.norm(y - y_model).item() / x.size < 1e-2
+
+    x_plot = np.linspace(-1, 1, 512)
+    fig, ax = plt.subplots()
+    ax.plot(x_plot, vmap(g)(x_plot))
+    ax.plot(x_plot, model.apply(pack(params, layout), x_plot), linestyle="dashed")
+    fig.savefig("y_mlp_adam_fit.pdf")
+    plt.close(fig)
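
Since `JaxOptimizer` only forwards `step_size` and `kwargs` to whichever `constructor` it is given, the same test skeleton should extend to the other constructors in the module; for instance (hypothetical settings, not part of this PR):

    from pysages.ml.optimizers import JaxOptimizer, JaxOptimizerParams

    # Any jax.example_libraries optimizer constructor can be dropped in;
    # the step sizes and tolerances below are illustrative, not tested values.
    optimizer = JaxOptimizer(jopt.sgd, JaxOptimizerParams(step_size=1e-2), tol=1e-6)
    optimizer = JaxOptimizer(jopt.adamax, tol=1e-6)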