Skip to content

Commit 25fd150

Browse files
Add maximum entropy distribution
1 parent 79a9867 commit 25fd150

File tree

7 files changed

+187
-11
lines changed

7 files changed

+187
-11
lines changed

docs/api.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,24 @@ Besides the high-level API one can use optimizers from `scipy` or `tensorflow_pr
7878

7979
::: gsd.experimental.OptState
8080

81+
### Maximum entropy
82+
83+
The GSD distribution can be considered as the whole family of distributions
84+
with the following properties:
85+
86+
1. It is a distribution over $[1,N]$
87+
2. The first parameter represents the expectation value
88+
3. It covers all possible variances
89+
90+
Another distribution that has similar properties and can be considered a member
91+
of the GSD family is the maximum entropy distribution.
92+
93+
::: gsd.experimental.MaxEntropyGSD
94+
:docstring:
95+
:members: __init__
96+
97+
98+
8199

82100

83101

mkdocs.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
site_name: GSD
22
site_description: The documentation for the reference implementation of generalised score distribution in python.
3+
repo_url: https://github.com/gsd-authors/gsd
4+
repo_name: gsd-authors/gsd
35

46
theme:
57
name: "material"

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ include = [
4848
[tool.hatch.envs.default]
4949
dependencies=["jaxlib>=0.4.6"]
5050

51+
[project.optional-dependencies]
52+
experimental = [
53+
"optimistix>=0.0.6",
54+
]
55+
5156
[tool.hatch.envs.default.scripts]
5257
test = "python -m unittest discover -p '*test.py'"
5358

src/gsd/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = '0.2.1dev'
1+
__version__ = '0.2.1'
22
from gsd.fit import GSDParams as GSDParams
33
from gsd.fit import fit_moments as fit_moments
44
from gsd.gsd import (log_prob as log_prob,

src/gsd/experimental/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@
88
from .fit import OptState as OptState
99
from .fit import fit_mle as fit_mle
1010
from .fit import fit_mle_grid as fit_mle_grid
11+
12+
from .max_entropy import MaxEntropyGSD as MaxEntropyGSD
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import equinox as eqx
2+
import jax
3+
import jax.numpy as jnp
4+
import optimistix as optx
5+
from jaxtyping import Array, Float, Int, PRNGKeyArray
6+
7+
import gsd
8+
from gsd import GSDParams
9+
from gsd.gsd import vmin
10+
11+
12+
@jax.jit
13+
def vmax(mean: Array, N: Int) -> Array:
14+
"""
15+
Computes maximal variance for categorical distribution supported on Z[1,N]
16+
:param mean:
17+
:param N:
18+
:return:
19+
"""
20+
return (mean - 1.0) * (N - mean)
21+
22+
23+
def _lagrange_log_probs(lagrage: tuple, dist: 'MaxEntropyGSD'):
24+
lamda1, lamdam, lamdas = lagrage
25+
lp = lamda1 + dist.support * lamdam + lamdas * dist.squred_diff - 1.0
26+
return lp
27+
28+
29+
def _implicit_log_probs(lagrage: tuple, d: 'MaxEntropyGSD'):
30+
lp = _lagrange_log_probs(lagrage, d)
31+
p = jnp.exp(lp)
32+
return (jnp.sum(p) - 1.0, # jax.nn.logsumexp(lp),
33+
jnp.dot(p, d.support) - d.mean,
34+
# jax.nn.logsumexp(a=lp, b=d.support) - jnp.log(d.mean),
35+
jnp.dot(p, d.squred_diff) - d.sigma ** 2,
36+
# jax.nn.logsumexp(a=lp, b=d.squred_diff) - 2 * jnp.log(d.sigma)
37+
)
38+
39+
40+
def _explicit_log_probs(dist: 'MaxEntropyGSD'):
41+
solver = optx.Newton(rtol=1e-8, atol=1e-8, )
42+
43+
lgr = jax.tree_util.tree_map(jnp.asarray, (-0.01, -0.01, -0.01))
44+
sol = optx.root_find(_implicit_log_probs, solver, lgr, args=dist,
45+
max_steps=int(1e4), throw=False)
46+
return _lagrange_log_probs(sol.value, dist)
47+
48+
49+
class MaxEntropyGSD(eqx.Module):
50+
r"""
51+
Maximum entropy distribution supported on `Z[1,N]`
52+
53+
This distribution is defined to fulfill the following conditions on $p_i$
54+
55+
* Maximize $H= -\sum_i p_i\log(p_i)$ wrt.
56+
* $\sum p_i=1$
57+
* $\sum i p_i= \mu$
58+
* $\sum (i-\mu)^2 p_i= \sigma^2$
59+
60+
:param mean: Expectation value of the distribution.
61+
:param sigma: Standard deviation of the distribution.
62+
:param N: Number of responses
63+
64+
"""
65+
mean: Float[Array, ""]
66+
sigma: Float[Array, ""] # std
67+
N: int = eqx.field(static=True)
68+
69+
70+
def log_prob(self, x: Int[Array, ""]):
71+
lp = _explicit_log_probs(self)
72+
return lp[x - 1]
73+
74+
def prob(self, x: Int[Array, ""]):
75+
return jnp.exp(self.log_prob(x))
76+
77+
@property
78+
def support(self):
79+
return jnp.arange(1, self.N + 1)
80+
81+
@property
82+
def squred_diff(self):
83+
return jnp.square((self.support - self.mean))
84+
85+
def stddev(self):
86+
return jnp.sqrt(self.variance())
87+
88+
def vmax(self):
89+
return (self.mean - 1.0) * (self.N - self.mean)
90+
91+
def vmin(self):
92+
return vmin(self.mean)
93+
94+
@property
95+
def all_log_probs(self):
96+
lp = _explicit_log_probs(self)
97+
return lp
98+
99+
@jax.jit
100+
def entropy(self):
101+
lp = self.all_log_probs
102+
return -jnp.dot(lp, jnp.exp(lp))
103+
104+
def sample(self, key: PRNGKeyArray, axis=-1, shape=None):
105+
lp = self.all_log_probs
106+
return jax.random.categorical(key, lp, axis, shape) + self.support[0]
107+
108+
@staticmethod
109+
def from_gsd(theta:GSDParams, N:int) -> 'MaxEntropyGSD':
110+
"""Created maxentropy from GSD parameters.
111+
112+
:param theta: Parameters of a GSD distribution.
113+
:param N: Support size
114+
:return: A distribution object
115+
"""
116+
return MaxEntropyGSD(
117+
mean=gsd.mean(theta.psi, theta.rho),
118+
sigma=jnp.sqrt(gsd.variance(theta.psi, theta.rho)),
119+
N=N
120+
)
121+
122+
MaxEntropyGSD.__init__.__doc__ = """Creates a MaxEntropyGSD
123+
124+
:param mean: Expectation value of the distribution.
125+
:param sigma: Standard deviation of the distribution.
126+
:param N: Number of responses
127+
128+
.. note::
129+
An alternative way to construct this distribution is by use of
130+
:ref:`from_gsd`
131+
132+
"""

tests/experimental_test.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from jax import config
2-
config.update("jax_enable_x64", True)
32

3+
config.update("jax_enable_x64", True)
4+
from gsd.experimental.max_entropy import MaxEntropyGSD
45
import unittest # noqa: E402
56

67
import jax
@@ -45,7 +46,8 @@ def test_fit_grid3(self):
4546
data = jnp.asarray([7, 25., 0, 0, 0])
4647
hat = est(data)
4748
theta = fit_mle_grid(data, num, False)
48-
jax.tree_util.tree_map(lambda a,b: self.assertAlmostEqual(a,b,2), hat, theta)
49+
jax.tree_util.tree_map(lambda a, b: self.assertAlmostEqual(a, b, 2),
50+
hat, theta)
4951

5052
...
5153

@@ -68,7 +70,7 @@ def test_sample_fit(self):
6870
k = jax.random.key(12)
6971
th = GSDParams(psi=4.2, rho=.92)
7072
th = jax.tree_util.tree_map(jnp.asarray, th)
71-
s = gsd.sample(th.psi, th.rho, (100000,),k)
73+
s = gsd.sample(th.psi, th.rho, (100000,), k)
7274
data = gsd.sufficient_statistic(s)
7375
num = GSDParams(512, 128)
7476
grid = GridEstimator.make(num)
@@ -79,24 +81,39 @@ def test_sample_fit(self):
7981

8082
def test_g_test(self):
8183
# https://github.com/Qub3k/gsd-acm-mm/blob/master/Data_Analysis/G-test_results/G_test_on_real_data_chunk000_of_872.csv
82-
data = jnp.asarray([0,0,1,10,13.])
84+
data = jnp.asarray([0, 0, 1, 10, 13.])
8385
num = GSDParams(512, 128)
8486
grid = GridEstimator.make(num)
8587

86-
8788
hat = grid(data)
8889
self.assertTrue(np.allclose(hat.psi, 4.5, 0.001))
8990
self.assertTrue(np.allclose(hat.rho, 0.935, 0.01))
9091

9192
p = bootstrap.prob(hat)
9293
# 0.09459716927725387
93-
t = bootstrap.t_statistic(data,p)
94-
self.assertAlmostEqual(t,0.09459716927725387,2)
94+
t = bootstrap.t_statistic(data, p)
95+
self.assertAlmostEqual(t, 0.09459716927725387, 2)
9596

9697
# 0.4957
97-
pv = bootstrap.pp_plot_data(data,lambda x: grid(x) ,jax.random.key(44),9999)
98+
pv = bootstrap.pp_plot_data(data, lambda x: grid(x),
99+
jax.random.key(44), 9999)
100+
101+
self.assertAlmostEqual(pv, 0.4957, 1)
102+
103+
...
104+
98105

99-
self.assertAlmostEqual(pv,0.4957,1)
106+
class MaxEntropyTestCase(unittest.TestCase):
107+
def test_maxentropy(self):
108+
me = MaxEntropyGSD(mean=3.2, sigma=0.2, N=5)
109+
self.assertAlmostEqual(me.mean, 3.2)
100110

111+
s = me.sample(jax.random.key(44))
112+
s2 = me.sample(jax.random.key(44), shape=(5,))
113+
self.assertAlmostEqual(s2.shape[0], 5)
101114

102-
...
115+
def test_probs(self):
116+
me = MaxEntropyGSD.from_gsd(GSDParams(psi=3.2, rho=0.9), 5)
117+
lp = me.all_log_probs
118+
p = np.exp(lp)
119+
self.assertAlmostEqual(p.sum(), 1)

0 commit comments

Comments
 (0)