Skip to content

Commit 3111bc0

Browse files
Add documentation
1 parent 8c95d89 commit 3111bc0

File tree

6 files changed

+79
-18
lines changed

6 files changed

+79
-18
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ classifiers = [
2727

2828
# HATCH_PYTHON=python3.10
2929
requires-python = ">=3.10"
30-
dependencies=["jax>=0.4.6"]
30+
dependencies=["jax>=0.4.23"]
3131

3232
[project.urls]
3333
Homepage = "https://github.com/gsd-authors/gsd"
@@ -46,7 +46,7 @@ include = [
4646
]
4747

4848
[tool.hatch.envs.default]
49-
dependencies=["jaxlib>=0.4.6"]
49+
dependencies=["jaxlib>=0.4.23"]
5050

5151
[project.optional-dependencies]
5252
experimental = [

src/gsd/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = '0.2.2dev'
1+
__version__ = '0.2.2'
22
from gsd.fit import GSDParams as GSDParams
33
from gsd.fit import fit_moments as fit_moments
44
from gsd.gsd import (log_prob as log_prob,

src/gsd/experimental/max_entropy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _explicit_log_probs(dist: 'MaxEntropyGSD'):
4242

4343
lgr = jax.tree_util.tree_map(jnp.asarray, (-0.01, -0.01, -0.01))
4444
sol = optx.root_find(_implicit_log_probs, solver, lgr, args=dist,
45-
max_steps=int(1e4), throw=False)
45+
max_steps=int(1e4), throw=True)
4646
return _lagrange_log_probs(sol.value, dist)
4747

4848

@@ -66,7 +66,6 @@ class MaxEntropyGSD(eqx.Module):
6666
sigma: Float[Array, ""] # std
6767
N: int = eqx.field(static=True)
6868

69-
7069
def log_prob(self, x: Int[Array, ""]):
7170
lp = _explicit_log_probs(self)
7271
return lp[x - 1]
@@ -106,7 +105,7 @@ def sample(self, key: PRNGKeyArray, axis=-1, shape=None):
106105
return jax.random.categorical(key, lp, axis, shape) + self.support[0]
107106

108107
@staticmethod
109-
def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':
108+
def from_gsd(theta: GSDParams, N: int) -> 'MaxEntropyGSD':
110109
"""Created maxentropy from GSD parameters.
111110
112111
:param theta: Parameters of a GSD distribution.
@@ -119,6 +118,7 @@ def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':
119118
N=N
120119
)
121120

121+
122122
MaxEntropyGSD.__init__.__doc__ = """Creates a MaxEntropyGSD
123123
124124
:param mean: Expectation value of the distribution.
@@ -127,6 +127,6 @@ def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':
127127
128128
.. note::
129129
An alternative way to construct this distribution is by use of
130-
:ref:`from_gsd`
130+
:meth:`from_gsd`
131131
132-
"""
132+
"""

src/gsd/gsd.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def sufficient_statistic(data: ArrayLike) -> Array:
143143
return c
144144

145145

146-
def softvmin_poly(x: Array, c:float, d: float) -> Array:
146+
def softvmin_poly(x: Array, c: float, d: float) -> Array:
147147
"""Smooths approximation to `vmin` function.
148148
149149
:param x: An argument, this would be psi
@@ -155,11 +155,19 @@ def softvmin_poly(x: Array, c:float, d: float) -> Array:
155155

156156
return (3 * d) / 8 - ((-3 + 4 * d) * sq1) / (4 * d) - sq2 / (8 * d ** 3)
157157

158-
def make_sofvmin(d:float)->Callable[[Array], Array]:
159-
def sofvmin(psi:ArrayLike):
158+
159+
def make_softvmin(d: float) -> Callable[[Array], Array]:
160+
"""Create a soft approximation to `vmin` function.
161+
162+
:param d: Cut point of approximation from `[0,0.5)`
163+
:return: A callable returning n approximated value `vmin` for `x`
164+
`abs(round(x)-x)<=d`
165+
"""
166+
def sofvmin(psi: ArrayLike):
160167
psi = jnp.asarray(psi)
161168
c = jax.lax.stop_gradient(jnp.round(psi))
162-
return jnp.where(jnp.abs(psi-c)<d, softvmin_poly(psi,c,d),
169+
return jnp.where(jnp.abs(psi - c) < d, softvmin_poly(psi, c, d),
163170
vmin(psi)
164171
)
165-
return sofvmin
172+
173+
return sofvmin

tests/experimental_test.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from jax import config
22

3+
from gsd.gsd import make_softvmin, vmax, vmin
4+
35
config.update("jax_enable_x64", True)
4-
from gsd.experimental.max_entropy import MaxEntropyGSD
56
import unittest # noqa: E402
67

7-
import jax
8-
import jax.numpy as jnp
98
import numpy as np
109

1110
import gsd
@@ -14,6 +13,14 @@
1413
from gsd.experimental.fit import GridEstimator
1514
from gsd.fit import log_pmax, pairs, pmax, GSDParams, fit_moments
1615

16+
import equinox as eqx
17+
import optimistix as optx
18+
19+
from gsd.experimental.max_entropy import MaxEntropyGSD, vmax
20+
21+
import jax
22+
import jax.numpy as jnp
23+
1724

1825
class FitTestCase(unittest.TestCase):
1926
def test_pairs(self):
@@ -117,3 +124,49 @@ def test_probs(self):
117124
lp = me.all_log_probs
118125
p = np.exp(lp)
119126
self.assertAlmostEqual(p.sum(), 1)
127+
128+
129+
def test_fit(self):
130+
def nll(d, x):
131+
m, s = d
132+
mean = 1.0 + 4.0 * jax.nn.sigmoid(m)
133+
svmin = make_softvmin(0.1)
134+
smin = jnp.sqrt(svmin(mean))
135+
smax = jnp.sqrt(vmax(mean, N=5))
136+
sigma = smin + (smax - smin) * jax.nn.sigmoid(s)
137+
d = MaxEntropyGSD(mean, sigma, N=5)
138+
return -jnp.mean(d.log_prob(x))
139+
140+
# x = jnp.asarray([2, 3, 2, 2, 3, 3, 4])
141+
x = jnp.asarray([2, 2, 2, 2, 2, 2, 2])
142+
143+
eqx.tree_pprint(jax.grad(nll)((0.01, 2.0), x), short_arrays=False)
144+
145+
def fit(x):
146+
solver = optx.BFGS(rtol=1e-2, atol=1e-4)
147+
148+
res = optx.minimise(nll, solver, (-0.0, .0),
149+
args=x,
150+
max_steps=int(1e6),
151+
throw=True)
152+
return res
153+
154+
res = jax.jit(fit)(x)
155+
eqx.tree_pprint(res.value, short_arrays=False)
156+
157+
m, s = res.value
158+
mean = 1.0 + 4.0 * jax.nn.sigmoid(m)
159+
smin = jnp.sqrt(vmin(mean))
160+
smax = jnp.sqrt(vmax(mean, N=5))
161+
sigma = smin + (smax - smin) * jax.nn.sigmoid(s)
162+
d = MaxEntropyGSD(mean, sigma, N=5)
163+
164+
self.assertAlmostEqual(d.mean,2., places=4)
165+
166+
eqx.tree_pprint(d, short_arrays=False)
167+
eqx.tree_pprint(MaxEntropyGSD(jnp.mean(x), jnp.std(x), N=5),
168+
short_arrays=False)
169+
170+
171+
172+

tests/ref_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
from jax import config
33

4-
from gsd.gsd import softvmin_poly, make_sofvmin, vmin
4+
from gsd.gsd import softvmin_poly, make_softvmin, vmin
55

66
config.update("jax_enable_x64", True)
77

@@ -116,7 +116,7 @@ def test_poly(self):
116116
self.assertAlmostEqual(v, 0.0529687)
117117

118118
def test_softvmin(self):
119-
svmin = make_sofvmin(0.1)
119+
svmin = make_softvmin(0.1)
120120
self.assertAlmostEqual(svmin(3.3), vmin(3.3))
121121

122122
for x in [1.5,1.9, 1.95, 2.05, 2.1, 2.2]:

0 commit comments

Comments
 (0)