-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathtest_ml.py
More file actions
126 lines (89 loc) · 3.83 KB
/
test_ml.py
File metadata and controls
126 lines (89 loc) · 3.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from functools import partial
import matplotlib.pyplot as plt
from jax import grad, jit
from jax import numpy as np
from jax import vmap
from pysages.approxfun import compute_mesh
from pysages.approxfun import scale as _scale
from pysages.grids import Chebyshev, Grid
from pysages.ml.models import MLP, Siren
from pysages.ml.objectives import L2Regularization, Sobolev1SSE
from pysages.ml.optimizers import JaxOptimizer, LevenbergMarquardt
from pysages.ml.training import build_fitting_function
from pysages.ml.utils import pack, unpack
from pysages.utils import try_import
# Handle to JAX's "example" optimizers module; `jopt.adam` is used by
# test_adam_optimizer below.
# NOTE(review): presumably try_import attempts these module names in order and
# returns the first that imports successfully (`jax.example_libraries` on newer
# JAX, `jax.experimental.optimizers` on older releases) — confirm against
# pysages.utils.try_import.
jopt = try_import("jax.example_libraries.optimizers", "jax.experimental.optimizers")
# Test functions
def gaussian(a, mu, sigma, x):
    """Return an unnormalized Gaussian bump of height ``a`` centred at ``mu``.

    Note that ``sigma`` divides the squared distance from ``mu`` directly
    (there is no squaring of ``sigma`` and no factor of 2), so it behaves as a
    variance-like width parameter rather than a standard deviation.
    """
    offset = x - mu
    exponent = -(offset**2) / sigma
    return a * np.exp(exponent)
def g(x):
    """Sum of three Gaussian bumps; target function for the MLP fitting tests."""
    narrow = gaussian(1.5, 0.2, 0.03, x)
    medium = gaussian(0.5, -0.5, 0.05, x)
    broad = gaussian(1.25, 0.9, 1.5, x)
    return narrow + medium + broad
def f(x):
    """Smooth 2*pi-periodic trigonometric target used by the SIREN Sobolev test."""
    cos_part = 0.35 * np.cos(5 * x)
    sin_part = 0.7 * np.sin(-2 * x)
    return cos_part + sin_part
def nngrad(model, params):
    """Return a jitted, vectorised derivative of ``model`` w.r.t. its input.

    Each sample is reshaped to a column (``reshape(-1, 1)``), passed through
    ``model.apply(params, ...)``, and the output is summed to a scalar so that
    ``grad`` can differentiate it. The derivative is then mapped over the
    leading axis of its argument with ``vmap`` and compiled with ``jit``.
    """

    def summed_output(sample):
        return model.apply(params, sample.reshape(-1, 1)).sum()

    return jit(vmap(grad(summed_output)))
def test_siren_sobolev_training():
    """Fit the periodic target ``f`` and its derivative with a SIREN model.

    ``f`` and ``grad(f)`` are sampled on a 64-point periodic grid over
    [-pi, pi] and fitted jointly with a Sobolev (value + derivative) loss via
    Levenberg-Marquardt. Both the fitted values and the fitted derivative must
    match within the tolerances below. Comparison plots of the fit and of its
    derivative are written to the current working directory.
    """
    grid = Grid(lower=(-np.pi,), upper=(np.pi,), shape=(64,), periodic=True)
    scale = partial(_scale, grid=grid)
    x_scaled = compute_mesh(grid)
    # NOTE(review): the mesh looks like it is in scaled coordinates (hence the
    # name), so it is stretched back to [-pi, pi] here — confirm.
    x = np.pi * x_scaled

    # Periodic function and its gradient
    y = vmap(f)(x.flatten()).reshape(x.shape)
    dy = vmap(grad(f))(x.flatten()).reshape(x.shape)

    topology = (4, 4)
    model = Siren(1, 1, topology, transform=scale)
    optimizer = LevenbergMarquardt(loss=Sobolev1SSE(), max_iters=200)
    fit = build_fitting_function(model, optimizer)

    ps, layout = unpack(model.parameters)
    params = fit(ps, x, (y, dy)).params
    params = jit(lambda ps: pack(ps, layout))(params)

    # Build the model-derivative function once and reuse it for both the
    # accuracy check and the derivative plot.
    dmodel = nngrad(model, params)

    y_err = np.linalg.norm(y - model.apply(params, x)).item() / x.size
    assert y_err < 5e-5, f"value fit error too large: {y_err:.3e}"
    dy_err = np.linalg.norm(dy - dmodel(x)).item() / x.size
    assert dy_err < 5e-4, f"derivative fit error too large: {dy_err:.3e}"

    x_plot = np.linspace(-np.pi, np.pi, 512)

    # Close each figure even if plotting or saving fails, so figures do not
    # accumulate in pyplot's global state across tests.
    fig, ax = plt.subplots()
    try:
        ax.plot(x_plot, vmap(f)(x_plot))
        ax.plot(x_plot, model.apply(params, x_plot), linestyle="dashed")
        fig.savefig("y_periodic_sirens_sobolev_fit.pdf")
    finally:
        plt.close(fig)

    fig, ax = plt.subplots()
    try:
        ax.plot(x_plot, vmap(grad(f))(x_plot))
        ax.plot(x_plot, dmodel(x_plot), linestyle="dashed")
        fig.savefig("dy_periodic_sirens_sobolev_fit.pdf")
    finally:
        plt.close(fig)
def test_mlp_training():
    """Fit the non-periodic target ``g`` with an MLP via Levenberg-Marquardt.

    ``g`` is sampled on a 64-point Chebyshev grid over [-1, 1]; the fitted
    model must reproduce it within the tolerance below. A comparison plot is
    written to ``y_mlp_fit.pdf`` in the current working directory.
    """
    grid = Grid[Chebyshev](lower=(-1.0,), upper=(1.0,), shape=(64,))
    x = compute_mesh(grid)
    y = vmap(g)(x.flatten()).reshape(x.shape)

    topology = (4, 4)
    model = MLP(1, 1, topology)
    # NOTE(review): presumably a zero-strength L2 penalty, i.e. effectively
    # unregularized — confirm against L2Regularization.
    optimizer = LevenbergMarquardt(reg=L2Regularization(0.0))
    fit = build_fitting_function(model, optimizer)

    params, layout = unpack(model.parameters)
    params = fit(params, x, y).params
    # Re-pack the fitted flat parameters once and reuse them below.
    fitted = pack(params, layout)

    y_model = model.apply(fitted, x)
    err = np.linalg.norm(y - y_model).item() / x.size
    assert err < 5e-4, f"fit error too large: {err:.3e}"

    x_plot = np.linspace(-1, 1, 512)
    # Close the figure even if plotting or saving fails.
    fig, ax = plt.subplots()
    try:
        ax.plot(x_plot, vmap(g)(x_plot))
        ax.plot(x_plot, model.apply(fitted, x_plot), linestyle="dashed")
        fig.savefig("y_mlp_fit.pdf")
    finally:
        plt.close(fig)
def test_adam_optimizer():
    """Fit the non-periodic target ``g`` with an MLP using JAX's Adam optimizer.

    ``g`` is sampled on a 128-point Chebyshev grid over [-1, 1] and fitted
    through ``JaxOptimizer(jopt.adam, tol=1e-6)``. The tolerance is looser
    than in the Levenberg-Marquardt test. A comparison plot is written to
    ``y_mlp_adam_fit.pdf`` in the current working directory.
    """
    grid = Grid[Chebyshev](lower=(-1.0,), upper=(1.0,), shape=(128,))
    x = compute_mesh(grid)
    y = vmap(g)(x.flatten()).reshape(x.shape)

    topology = (4, 4)
    model = MLP(1, 1, topology)
    optimizer = JaxOptimizer(jopt.adam, tol=1e-6)
    fit = build_fitting_function(model, optimizer)

    params, layout = unpack(model.parameters)
    params = fit(params, x, y).params
    # Re-pack the fitted flat parameters once and reuse them below.
    fitted = pack(params, layout)

    y_model = model.apply(fitted, x)
    err = np.linalg.norm(y - y_model).item() / x.size
    assert err < 1e-2, f"fit error too large: {err:.3e}"

    x_plot = np.linspace(-1, 1, 512)
    # Close the figure even if plotting or saving fails.
    fig, ax = plt.subplots()
    try:
        ax.plot(x_plot, vmap(g)(x_plot))
        ax.plot(x_plot, model.apply(fitted, x_plot), linestyle="dashed")
        fig.savefig("y_mlp_adam_fit.pdf")
    finally:
        plt.close(fig)