Skip to content

Commit f64b539

Browse files
Merge pull request #20 from gsd-authors/soft
Soft
2 parents 7ec8ae9 + 3111bc0 commit f64b539

File tree

8 files changed

+393
-12
lines changed

8 files changed

+393
-12
lines changed

discussion/softvmin.wl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
(* ExecuteFile["softvmin.wl"] *)

(* Derivation script for the soft `vmin` approximation used in
   src/gsd/gsd.py (softvmin_poly / make_softvmin).  `vmin` is piecewise
   parabolic with kinks at the integers; the two branches around x = 2
   are blended here with an even quartic polynomial in (x - 2). *)

Clear["Global`*"]

(* The two parabolic branches of vmin adjacent to x = 2. *)
fa[x_] := (2-x)(x-1)
fb[x_] := (3-x)(x-2)

Plot[{fa[x],fb[x]},{x,1,3}] //Export["fafb.pdf", # ]&

(* Even powers only: the blend appf is symmetric about x = 2. *)
pows = {0,2,4}
vars=Subscript[b,#]&/@pows

(* Ansatz: appf(x) = b0 + b2 (x-2)^2 + b4 (x-2)^4. *)
appf[x_] := Total[Subscript[b,#] (x-2)^# &/@pows]

(* Matching conditions at the cut points x = 2 +- d: equal values,
   first and second derivatives, plus a stationary point at x = 2.
   `/.` binds looser than `==`, so each substitution applies to the
   whole equation.  NOTE(review): 7 equations for 3 unknowns -- Solve
   relies on the consistency of the symmetric setup; confirm it returns
   a non-empty solution when rerunning. *)
eqs={
appf[2-d]==fa[2-d],
appf[2+d]==fb[2+d],
D[fa[x],x]==D[appf[x],x]/.{x->2-d},
D[fb[x],x]==D[appf[x],x]/.{x->2+d},
D[fa[x],{x,2}]==D[appf[x],{x,2}]/.{x->2-d},
D[fb[x],{x,2}]==D[appf[x],{x,2}]/.{x->2+d},
D[appf[x],x]==0/.{x->2}
}

(* Solve symbolically for the polynomial coefficients in terms of d. *)
sol=Solve[
eqs,
vars
]

(* sol = vars/.Solve[
eqs/.{d->0.1},
vars
] *)

(* Keep the first returned coefficient solution. *)
sol=sol[[1]]

(* Visual check of the blend against both branches near x = 2. *)
Plot[{(appf[x]/.sol)/.{d->1/50}, fa[x],fb[x]},{x,1.8,2.2}, PlotRange->{0,1/4}]//Export["appf.pdf", # ]&

(* Save the closed form of the blended polynomial. *)
Export["sol.txt",(appf[x]/.sol)]

(* Needs["CCodeGenerator`"]

CCodeGenerator[]

c = Compile[ {{x},{d}}, appf[x]/.sol];
file = CCodeStringGenerate[c, "fun"] *)

(* Test cases -- these reference values are reused in
   tests/ref_test.py::SoftTestCase.test_poly. *)

(appf[x]/.sol)/.{d->1/50, x->1.99}

(appf[x]/.sol)/.{d->1/10, x->2.05}

(* Variance of a two-point distribution on {x, x+1} with weight
   p = (n+1)/(n+2); used to choose a cut point d that matches it. *)
p = (n+1)/(n+2)
ep = p x + (1-p)(x+1)

v = Simplify[p (x-ep)^2 + (1-p)(x+1-ep)^2]

ExportString[v,"tex"]

(* Solve for the cut point d at which the blended minimum equals v. *)
sol = Solve[((appf[x]/.sol)/.{x->2})==v,d]

ExportString[sol,"tex"]

(* Numeric cut point for n = 24. *)
N[(d/.sol[[1]])/.{n->24}]
69+

examples/softvmin.ipynb

Lines changed: 202 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ classifiers = [
2727

2828
# HATCH_PYTHON=python3.10
2929
requires-python = ">=3.10"
30-
dependencies=["jax>=0.4.6"]
30+
dependencies=["jax>=0.4.23"]
3131

3232
[project.urls]
3333
Homepage = "https://github.com/gsd-authors/gsd"
@@ -46,7 +46,7 @@ include = [
4646
]
4747

4848
[tool.hatch.envs.default]
49-
dependencies=["jaxlib>=0.4.6"]
49+
dependencies=["jaxlib>=0.4.23"]
5050

5151
[project.optional-dependencies]
5252
experimental = [

src/gsd/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = '0.2.2dev'
1+
__version__ = '0.2.2'
22
from gsd.fit import GSDParams as GSDParams
33
from gsd.fit import fit_moments as fit_moments
44
from gsd.gsd import (log_prob as log_prob,

src/gsd/experimental/max_entropy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _explicit_log_probs(dist: 'MaxEntropyGSD'):
4242

4343
lgr = jax.tree_util.tree_map(jnp.asarray, (-0.01, -0.01, -0.01))
4444
sol = optx.root_find(_implicit_log_probs, solver, lgr, args=dist,
45-
max_steps=int(1e4), throw=False)
45+
max_steps=int(1e4), throw=True)
4646
return _lagrange_log_probs(sol.value, dist)
4747

4848

@@ -66,7 +66,6 @@ class MaxEntropyGSD(eqx.Module):
6666
sigma: Float[Array, ""] # std
6767
N: int = eqx.field(static=True)
6868

69-
7069
def log_prob(self, x: Int[Array, ""]):
7170
lp = _explicit_log_probs(self)
7271
return lp[x - 1]
@@ -106,7 +105,7 @@ def sample(self, key: PRNGKeyArray, axis=-1, shape=None):
106105
return jax.random.categorical(key, lp, axis, shape) + self.support[0]
107106

108107
@staticmethod
109-
def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':
108+
def from_gsd(theta: GSDParams, N: int) -> 'MaxEntropyGSD':
110109
"""Created maxentropy from GSD parameters.
111110
112111
:param theta: Parameters of a GSD distribution.
@@ -119,6 +118,7 @@ def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':
119118
N=N
120119
)
121120

121+
122122
MaxEntropyGSD.__init__.__doc__ = """Creates a MaxEntropyGSD
123123
124124
:param mean: Expectation value of the distribution.
@@ -127,6 +127,6 @@ def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':
127127
128128
.. note::
129129
An alternative way to construct this distribution is by use of
130-
:ref:`from_gsd`
130+
:meth:`from_gsd`
131131
132-
"""
132+
"""

src/gsd/gsd.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Sequence
1+
from typing import Sequence, Callable
22

33
import jax
44
import jax.numpy as jnp
@@ -141,3 +141,33 @@ def sufficient_statistic(data: ArrayLike) -> Array:
141141
bins = jnp.arange(0.5, N + 1.5, 1.)
142142
c, _ = jnp.histogram(jnp.asarray(data), bins=bins)
143143
return c
144+
145+
146+
def softvmin_poly(x: Array, c: float, d: float) -> Array:
    """Smooth polynomial approximation to the `vmin` function near `c`.

    Even quartic in ``(x - c)`` whose coefficients (derived symbolically in
    ``discussion/softvmin.wl``) match the value and the first two derivatives
    of the adjacent parabolic branches of `vmin` at ``c - d`` and ``c + d``.

    :param x: An argument, this would be psi
    :param c: Center of the approximation (the integer `x` is rounded to)
    :param d: Cut point of approximation from `[0,0.5)`
    :return: An approximated value of `vmin(x)`, valid for `abs(x-c)<=d`
    """
    sq1 = jnp.square(x - c)
    sq2 = jnp.square(sq1)

    return (3 * d) / 8 - ((-3 + 4 * d) * sq1) / (4 * d) - sq2 / (8 * d ** 3)
157+
158+
159+
def make_softvmin(d: float) -> Callable[[Array], Array]:
    """Create a soft approximation to `vmin` function.

    :param d: Cut point of approximation from `[0,0.5)`
    :return: A callable returning an approximated value of `vmin` for `x`
        such that `abs(round(x)-x)<=d`
    """

    def softvmin(psi: ArrayLike) -> Array:
        psi = jnp.asarray(psi)
        # Round to the nearest integer; stop_gradient keeps the piecewise
        # constant rounding out of the differentiation path.
        c = jax.lax.stop_gradient(jnp.round(psi))
        # Smooth polynomial inside the window around the integer, exact
        # `vmin` outside.  NOTE(review): `jnp.where` evaluates both
        # branches, so `vmin` must stay finite near integer psi — confirm.
        return jnp.where(jnp.abs(psi - c) < d,
                         softvmin_poly(psi, c, d),
                         vmin(psi)
                         )

    return softvmin

tests/experimental_test.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from jax import config
22

3+
from gsd.gsd import make_softvmin, vmax, vmin
4+
35
config.update("jax_enable_x64", True)
4-
from gsd.experimental.max_entropy import MaxEntropyGSD
56
import unittest # noqa: E402
67

7-
import jax
8-
import jax.numpy as jnp
98
import numpy as np
109

1110
import gsd
@@ -14,6 +13,14 @@
1413
from gsd.experimental.fit import GridEstimator
1514
from gsd.fit import log_pmax, pairs, pmax, GSDParams, fit_moments
1615

16+
import equinox as eqx
17+
import optimistix as optx
18+
19+
from gsd.experimental.max_entropy import MaxEntropyGSD, vmax
20+
21+
import jax
22+
import jax.numpy as jnp
23+
1724

1825
class FitTestCase(unittest.TestCase):
1926
def test_pairs(self):
@@ -117,3 +124,49 @@ def test_probs(self):
117124
lp = me.all_log_probs
118125
p = np.exp(lp)
119126
self.assertAlmostEqual(p.sum(), 1)
127+
128+
129+
    def test_fit(self):
        """End-to-end fit of MaxEntropyGSD by BFGS, using the soft `vmin`
        so the lower variance bound stays differentiable in the mean."""

        def nll(d, x):
            # d is the (m, s) pair of unconstrained parameters.
            m, s = d
            # Squash m into the valid mean range [1, 5] for N=5 scores.
            mean = 1.0 + 4.0 * jax.nn.sigmoid(m)
            # Soft minimum-variance bound: differentiable w.r.t. mean.
            svmin = make_softvmin(0.1)
            smin = jnp.sqrt(svmin(mean))
            smax = jnp.sqrt(vmax(mean, N=5))
            # Squash s into the feasible sigma range [smin, smax].
            sigma = smin + (smax - smin) * jax.nn.sigmoid(s)
            d = MaxEntropyGSD(mean, sigma, N=5)
            # Negative mean log-likelihood of the observed scores.
            return -jnp.mean(d.log_prob(x))

        # x = jnp.asarray([2, 3, 2, 2, 3, 3, 4])
        # Degenerate sample (all 2s): the fitted mean should approach 2,
        # which sits exactly on a kink of the hard `vmin`.
        x = jnp.asarray([2, 2, 2, 2, 2, 2, 2])

        # Sanity check: the gradient at an arbitrary starting point exists.
        eqx.tree_pprint(jax.grad(nll)((0.01, 2.0), x), short_arrays=False)

        def fit(x):
            solver = optx.BFGS(rtol=1e-2, atol=1e-4)

            # throw=True: fail the test if the optimizer does not converge.
            res = optx.minimise(nll, solver, (-0.0, .0),
                                args=x,
                                max_steps=int(1e6),
                                throw=True)
            return res

        res = jax.jit(fit)(x)
        eqx.tree_pprint(res.value, short_arrays=False)

        # Reconstruct the fitted distribution from the unconstrained
        # optimum; here the hard `vmin` is fine (no gradients needed).
        m, s = res.value
        mean = 1.0 + 4.0 * jax.nn.sigmoid(m)
        smin = jnp.sqrt(vmin(mean))
        smax = jnp.sqrt(vmax(mean, N=5))
        sigma = smin + (smax - smin) * jax.nn.sigmoid(s)
        d = MaxEntropyGSD(mean, sigma, N=5)

        # The MLE mean for an all-2 sample must be (close to) 2.
        self.assertAlmostEqual(d.mean,2., places=4)

        eqx.tree_pprint(d, short_arrays=False)
        eqx.tree_pprint(MaxEntropyGSD(jnp.mean(x), jnp.std(x), N=5),
                        short_arrays=False)
169+
170+
171+
172+

tests/ref_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import numpy as np
22
from jax import config
3+
4+
from gsd.gsd import softvmin_poly, make_softvmin, vmin
5+
36
config.update("jax_enable_x64", True)
47

58
import unittest
@@ -104,5 +107,29 @@ def test_sufficient_statistic4(self):
104107
# 1, 2 3 4 5
105108
self.assertTrue(np.allclose(ss,c))
106109

110+
111+
class SoftTestCase(unittest.TestCase):
    """Tests for the smooth (soft) approximation of the `vmin` function."""

    def test_poly(self):
        # Reference values computed symbolically in discussion/softvmin.wl.
        v = softvmin_poly(x=1.99, c=2., d=1 / 50.)
        self.assertAlmostEqual(v, 0.0109938)
        v = softvmin_poly(x=2.05, c=2, d=1 / 10.)
        self.assertAlmostEqual(v, 0.0529687)

    def test_softvmin(self):
        svmin = make_softvmin(0.1)
        # Far from any integer the soft version must agree with `vmin`.
        self.assertAlmostEqual(svmin(3.3), vmin(3.3))

        # Hoisted out of the loop: the grad transforms are loop-invariant,
        # so there is no need to rebuild them for every test point.
        gsvmin = jax.grad(svmin)
        ggsvmin = jax.grad(gsvmin)

        # Points both inside and outside the smoothing window around 2;
        # first and second derivatives must be defined everywhere.
        for x in [1.5, 1.9, 1.95, 2.05, 2.1, 2.2]:
            g = gsvmin(x)
            print(g)
            self.assertIsNotNone(g)

            gg = ggsvmin(x)
            print(gg)
            self.assertIsNotNone(gg)
132+
133+
107134
if __name__ == '__main__':
108135
unittest.main()

0 commit comments

Comments
 (0)