|
1 | 1 | from jax import config |
2 | 2 |
|
| 3 | +from gsd.gsd import make_softvmin, vmax, vmin |
| 4 | + |
3 | 5 | config.update("jax_enable_x64", True) |
4 | | -from gsd.experimental.max_entropy import MaxEntropyGSD |
5 | 6 | import unittest # noqa: E402 |
6 | 7 |
|
7 | | -import jax |
8 | | -import jax.numpy as jnp |
9 | 8 | import numpy as np |
10 | 9 |
|
11 | 10 | import gsd |
|
14 | 13 | from gsd.experimental.fit import GridEstimator |
15 | 14 | from gsd.fit import log_pmax, pairs, pmax, GSDParams, fit_moments |
16 | 15 |
|
| 16 | +import equinox as eqx |
| 17 | +import optimistix as optx |
| 18 | + |
| 19 | +from gsd.experimental.max_entropy import MaxEntropyGSD, vmax |
| 20 | + |
| 21 | +import jax |
| 22 | +import jax.numpy as jnp |
| 23 | + |
17 | 24 |
|
18 | 25 | class FitTestCase(unittest.TestCase): |
19 | 26 | def test_pairs(self): |
@@ -117,3 +124,49 @@ def test_probs(self): |
117 | 124 | lp = me.all_log_probs |
118 | 125 | p = np.exp(lp) |
119 | 126 | self.assertAlmostEqual(p.sum(), 1) |
| 127 | + |
| 128 | + |
| 129 | + def test_fit(self): |
| 130 | + def nll(d, x): |
| 131 | + m, s = d |
| 132 | + mean = 1.0 + 4.0 * jax.nn.sigmoid(m) |
| 133 | + svmin = make_softvmin(0.1) |
| 134 | + smin = jnp.sqrt(svmin(mean)) |
| 135 | + smax = jnp.sqrt(vmax(mean, N=5)) |
| 136 | + sigma = smin + (smax - smin) * jax.nn.sigmoid(s) |
| 137 | + d = MaxEntropyGSD(mean, sigma, N=5) |
| 138 | + return -jnp.mean(d.log_prob(x)) |
| 139 | + |
| 140 | + # x = jnp.asarray([2, 3, 2, 2, 3, 3, 4]) |
| 141 | + x = jnp.asarray([2, 2, 2, 2, 2, 2, 2]) |
| 142 | + |
| 143 | + eqx.tree_pprint(jax.grad(nll)((0.01, 2.0), x), short_arrays=False) |
| 144 | + |
| 145 | + def fit(x): |
| 146 | + solver = optx.BFGS(rtol=1e-2, atol=1e-4) |
| 147 | + |
| 148 | + res = optx.minimise(nll, solver, (-0.0, .0), |
| 149 | + args=x, |
| 150 | + max_steps=int(1e6), |
| 151 | + throw=True) |
| 152 | + return res |
| 153 | + |
| 154 | + res = jax.jit(fit)(x) |
| 155 | + eqx.tree_pprint(res.value, short_arrays=False) |
| 156 | + |
| 157 | + m, s = res.value |
| 158 | + mean = 1.0 + 4.0 * jax.nn.sigmoid(m) |
| 159 | + smin = jnp.sqrt(vmin(mean)) |
| 160 | + smax = jnp.sqrt(vmax(mean, N=5)) |
| 161 | + sigma = smin + (smax - smin) * jax.nn.sigmoid(s) |
| 162 | + d = MaxEntropyGSD(mean, sigma, N=5) |
| 163 | + |
| 164 | + self.assertAlmostEqual(d.mean,2., places=4) |
| 165 | + |
| 166 | + eqx.tree_pprint(d, short_arrays=False) |
| 167 | + eqx.tree_pprint(MaxEntropyGSD(jnp.mean(x), jnp.std(x), N=5), |
| 168 | + short_arrays=False) |
| 169 | + |
| 170 | + |
| 171 | + |
| 172 | + |
0 commit comments