Skip to content

Commit 40db0c8

Browse files
Merge pull request #3 from gsd-authors/fit
Fit
2 parents e2bf27f + 74e2916 commit 40db0c8

File tree

7 files changed

+236
-28
lines changed

7 files changed

+236
-28
lines changed

README.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,18 @@ $ hatch shell
4343

4444
### Running tests
4545

46-
Gsd uses [Pytest](https://pytest.org/) for testing. To run the tests, use the following command:
46+
Gsd uses unittest for testing. To run the tests, use the following command:
4747

4848
```
4949
$ hatch run test
5050
```
51+
52+
### Standalone estimator
53+
54+
You can quickly estimate GSD parameters from a command line interface
55+
56+
```shell
57+
python3 -m gsd 0 12 13 4 0
58+
```
59+
60+
GSDParams(psi=Array(2.6272388, dtype=float32), rho=Array(0.9041536, dtype=float32))

src/gsd/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
__version__ = '0.0.4'
1+
__version__ = '0.0.5'
22

3+
from gsd.fit import GSDParams
4+
from gsd.fit import fit_mle
5+
from gsd.fit import fit_moments
36
from gsd.gsd import (log_prob, sample, mean, variance)
47
from gsd.ref_prob import gsd_prob

src/gsd/__main__.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
1-
import gsd
2-
import jax
1+
import argparse
32
import jax.numpy as jnp
4-
3+
from gsd import fit_mle
54

65
if __name__ == '__main__':
    # Command-line interface: read five response counts and print the MLE.
    parser = argparse.ArgumentParser(description='GSD estimator')
    parser.add_argument("response", nargs=5, type=int,
                        metavar=("num1", "num2", "num3", "num4", "num5"),
                        help="List of 5 counts")
    args = parser.parse_args()

    # fit_mle expects a float array of counts.
    counts = jnp.asarray(args.response, dtype=jnp.float32)
    estimate, _ = fit_mle(data=counts)
    print(estimate)

src/gsd/fit.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from typing import NamedTuple
2+
3+
import jax
4+
import jax.numpy as jnp
5+
import jax.tree_util as jtu
6+
from jax import Array
7+
from jax.typing import ArrayLike
8+
9+
from .gsd import vmax, vmin, log_prob
10+
11+
12+
class GSDParams(NamedTuple):
    """Parameter pair of the GSD distribution.

    Bundling ``psi`` and ``rho`` in a NamedTuple keeps the two parameters
    together and, being a tuple of array leaves, the pair works directly
    with JAX pytree utilities (see its use with ``jax.tree_util`` in the
    fitting code).
    """
    # Mean parameter.
    psi: Array
    # Dispersion parameter.
    rho: Array
21+
22+
23+
@jax.jit
def fit_moments(data: ArrayLike) -> GSDParams:
    """Estimate GSD parameters with the method of moments.

    :param data: Array of 5 counts; ``data[i]`` is the number of responses
        equal to ``i + 1``.
    :return: ``GSDParams`` with the sample mean as ``psi`` and, as ``rho``,
        the position of the sample variance between the largest and the
        smallest variance attainable at that mean.
    """
    scores = jnp.arange(1, 6)
    total = jnp.sum(data)
    # First moment: sample mean of the responses.
    psi = jnp.dot(data, scores) / total
    # Sample variance via E[X^2] - (E[X])^2.
    V = jnp.dot(data, scores ** 2) / total - psi ** 2
    # rho = 0 corresponds to the maximal variance, rho = 1 to the minimal.
    rho = (vmax(psi) - V) / (vmax(psi) - vmin(psi))
    return GSDParams(psi=psi, rho=rho)
33+
34+
35+
class OptState(NamedTuple):
    """Loop state of the iterative MLE fitting procedure.

    :param params: Current parameter estimate.
    :param previous_params: Estimate from the preceding iteration; comparing
        it with ``params`` detects convergence (no change between steps).
    :param count: Number of iterations performed so far.
    """
    params: GSDParams
    previous_params: GSDParams
    count: int
50+
51+
52+
@jax.jit
def fit_mle(data: ArrayLike, max_iterations: int = 100, log_lr_min: ArrayLike = -15, log_lr_max: ArrayLike = 2.,
            num_lr: ArrayLike = 10) -> tuple[GSDParams, OptState]:
    """Find the maximum likelihood estimate of the GSD parameters.

    The algorithm is projected gradient ascent: each step takes the gradient
    of the mean log-likelihood, tries a fixed grid of learning rates along it
    (plus rate 0, i.e. "stay put"), clips every candidate to the feasible box
    (psi in [1, 5], rho in [0, 1]) and keeps the candidate with the highest
    likelihood. Iteration stops when the estimate no longer changes, becomes
    NaN, or ``max_iterations`` is exceeded.

    :param data: Array of 5 counts of each response (1..5).
    :param max_iterations: Maximum number of iterations.
    :param log_lr_min: Log2 of the smallest learning rate.
    :param log_lr_max: Log2 of the largest learning rate.
    :param num_lr: Number of learning rates to check during the line search.

    :return: A pair ``(params, opt_state)``: ``params`` is the estimated
        ``GSDParams`` and ``opt_state`` is the final ``OptState`` (its
        ``params`` field holds the same estimate).
    """

    def ll(theta: GSDParams) -> Array:
        # Mean log-likelihood of the observed counts under theta.
        logits = jax.vmap(log_prob, (None, None, 0), (0))(theta.psi, theta.rho, jnp.arange(1, 6))
        return jnp.dot(data, logits) / jnp.sum(data)

    grad_ll = jax.grad(ll)
    # The moment estimate is a cheap, feasible starting point.
    theta0 = fit_moments(data)
    # Learning-rate grid for the line search; the leading zero keeps the
    # current point among the candidates, so a step can never decrease ll.
    rate = jnp.concatenate([jnp.zeros((1,)), jnp.logspace(log_lr_min, log_lr_max, num_lr, base=2.)])

    def update(tg, t, lo, hi):
        '''Propose one candidate per learning rate for a single parameter leaf.

        :param tg: gradient
        :param t: theta parameters
        :param lo: lower bound
        :param hi: upper bound
        :return: updated params, clipped to [lo, hi]
        '''
        # NaN gradients (e.g. on the feasibility boundary) mean "no move".
        nt = t + rate * jnp.where(jnp.isnan(tg), 0., tg)
        _nt = jnp.where(nt < lo, lo, nt)
        _nt = jnp.where(_nt > hi, hi, _nt)
        return _nt

    lo = GSDParams(psi=1., rho=0.)
    hi = GSDParams(psi=5., rho=1.)

    def body_fun(state: OptState) -> OptState:
        t, _, count = state
        g = grad_ll(t)
        # Batched candidates: every leaf gets one value per learning rate.
        new_theta = jtu.tree_map(update, g, t, lo, hi)
        new_lls = jax.vmap(ll)(new_theta)
        # Pick the best candidate; NaN likelihoods can never win.
        max_idx = jnp.argmax(jnp.where(jnp.isnan(new_lls), -jnp.inf, new_lls))
        return OptState(params=jtu.tree_map(lambda t: t[max_idx], new_theta), previous_params=t, count=count + 1)

    def cond_fun(state: OptState) -> bool:
        tn, tnm1, c = state
        # Stop when converged (no change) or out of iteration budget ...
        should_stop = jnp.logical_or(jnp.all(jnp.array(tn) == jnp.array(tnm1)), c > max_iterations)
        # ... or when the estimate has degenerated to NaN.
        should_stop = jnp.logical_or(should_stop, jnp.any(jnp.isnan(jnp.array(tn))))
        return jnp.logical_not(should_stop)

    opt_state = jax.lax.while_loop(cond_fun, body_fun,
                                   OptState(params=theta0, previous_params=jtu.tree_map(lambda _: jnp.inf, theta0),
                                            count=0))
    return opt_state.params, opt_state
111+

src/gsd/gsd.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,38 @@
1-
from typing import Any, Sequence
1+
from typing import Sequence
22

33
import jax
44
import jax.numpy as jnp
5-
65
import numpy as np
7-
from jax.scipy.special import betaln
8-
from jax.typing import ArrayLike
96
from jax import Array
107
from jax.random import PRNGKeyArray
8+
from jax.scipy.special import betaln
9+
from jax.typing import ArrayLike
1110

1211
Shape = Sequence[int]
1312

1413
N = 5
1514

1615

1716
def logbinom(n: ArrayLike, k: ArrayLike) -> Array:
    """Numerically stable logarithm of the binomial coefficient C(n, k).

    Uses the identity C(n, k) = 1 / ((n + 1) * B(n - k + 1, k + 1)),
    evaluated in log space with ``betaln``.

    :param n: total number of items
    :param k: number of chosen items
    :return: log C(n, k)
    """
    log_beta = betaln(n - k + 1., k + 1.)
    return -jnp.log1p(n) - log_beta
2219

2320

2421
def vmin(psi: ArrayLike) -> Array:
    """Compute the minimal possible variance for a given mean.

    For a distribution on the integer scale, the minimum is attained by
    mass on the two integers adjacent to ``psi``; it is zero when ``psi``
    is itself an integer.

    :param psi: mean
    :return: minimal variance attainable at this mean
    """
    return (jnp.ceil(psi) - psi) * (psi - jnp.floor(psi))
2628

2729

2830
def vmax(psi: ArrayLike) -> Array:
    """Compute the maximal possible variance for a given mean.

    For a distribution supported on [1, 5] this is the Bhatia-Davis bound
    (psi - 1)(5 - psi), attained by mass on the extreme scores only.

    :param psi: mean
    :return: maximal variance attainable at this mean
    """
    return (psi - 1.) * (5 - psi)
3037

3138

@@ -34,6 +41,13 @@ def _C(Vmax: ArrayLike, Vmin: ArrayLike) -> Array:
3441

3542

3643
def log_prob(psi: ArrayLike, rho: ArrayLike, k: ArrayLike) -> Array:
44+
"""Compute log probability of the response k for given parameters.
45+
46+
:param psi: mean
47+
:param rho: dispersion
48+
:param k: response
49+
:return: log of the probability in GSD distribution
50+
"""
3751
index = jnp.arange(0, 6)
3852
almost_neg_inf = np.log(1e-10)
3953

@@ -48,33 +62,41 @@ def log_prob(psi: ArrayLike, rho: ArrayLike, k: ArrayLike) -> Array:
4862
b0 = jnp.log(jnp.zeros_like(index))
4963
b0 = b0.at[0].set(jnp.log((5. - psi) / 4.))
5064
b0 = b0.at[4].set(jnp.log((psi - 1.) / 4.))
51-
beta_bin_part = jnp.where(rho == 0.0, b0[k-1],beta_bin_part)
52-
65+
beta_bin_part = jnp.where(rho == 0.0, b0[k - 1], beta_bin_part)
5366

5467
min_var_part = jax.nn.relu(1. - jnp.abs(k - psi))
5568
log_min_var_part = jnp.where(rho < C, 0., jnp.log(rho - C)) - jnp.log1p(-C) + jnp.log(min_var_part)
5669
log_bin_part = jnp.log1p(-rho) - jnp.log1p(-C) + logbinom(4, k - 1.) + (k - 1) * (jnp.log(psi - 1) - jnp.log(4)) + (
5770
5 - k) * (jnp.log(5 - psi) - jnp.log(4))
5871

59-
6072
logmix = jnp.logaddexp(jnp.where(min_var_part == 0, almost_neg_inf, log_min_var_part), log_bin_part)
6173

62-
logmix = jnp.where(rho==1.0,jnp.log(min_var_part),logmix)
63-
#logmix = jnp.where(min_var_part == 0, log_bin_part, logmix)
74+
logmix = jnp.where(rho == 1.0, jnp.log(min_var_part), logmix)
75+
# logmix = jnp.where(min_var_part == 0, log_bin_part, logmix)
6476

6577
return jnp.where(rho < C, beta_bin_part, logmix)
6678

6779

6880
def mean(psi: ArrayLike, rho: ArrayLike) -> Array:
    """Expected value of the GSD distribution.

    The mean is the parameter ``psi`` itself; ``rho`` does not affect it
    (it enters the variance only).

    :param psi: mean parameter
    :param rho: dispersion parameter (ignored)
    :return: psi
    """
    del rho  # explicitly unused
    return psi
7184

7285

7386
def variance(psi: ArrayLike, rho: ArrayLike) -> Array:
    """Variance of the GSD distribution.

    Linear interpolation between the largest variance attainable at mean
    ``psi`` (when rho = 0) and the smallest (when rho = 1).

    :param psi: mean parameter
    :param rho: dispersion parameter
    :return: variance
    """
    return (1 - rho) * vmax(psi) + rho * vmin(psi)
7589

7690

7791
def sample(psi: ArrayLike, rho: ArrayLike, shape: Shape, key: PRNGKeyArray) -> Array:
    """Draw random responses from the GSD distribution.

    :param psi: mean parameter
    :param rho: dispersion parameter
    :param shape: shape of the returned sample
    :param key: JAX PRNG key
    :return: integer array of responses in 1..N with the requested shape
    """
    support = jnp.arange(1, N + 1)
    # Log-probability of every response; ``categorical`` draws 0-based
    # indices, hence the +1 shift back onto the 1..N response scale.
    logits = jax.vmap(log_prob, in_axes=(None, None, 0))(psi, rho, support)
    return jax.random.categorical(key, logits, shape=shape) + 1

tests/fit.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from jax import config

# 64-bit precision is needed for the tight quadrature tolerances below.
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
import numpy as np
from scipy.integrate import dblquad

from gsd import log_prob

if __name__ == '__main__':
    # Observed counts of responses 1..5 and the response values themselves.
    data = jnp.asarray([5, 12, 3, 0, 0])
    k = jnp.arange(1, 6)

    @jax.jit
    def posterior(psi, rho):
        # Unnormalised posterior density of (psi, rho) given the counts.
        log_posterior = jax.vmap(log_prob, in_axes=(None, None, 0))(psi, rho, k) @ data + 1. + 1 / 4.
        return jnp.exp(log_posterior)

    epsabs = 1e-14
    epsrel = 1e-11

    # Normalising constant over psi in [1, 5], rho in [0, 1].
    Z, Z_err = dblquad(posterior, a=0, b=1, gfun=lambda x: 1., hfun=lambda x: 5., epsabs=epsabs, epsrel=epsrel)

    # Posterior means of psi and rho.
    psi_hat, _ = dblquad(jax.jit(lambda psi, rho: psi * posterior(psi, rho)), a=0, b=1, gfun=lambda x: 1.,
                         hfun=lambda x: 5.,
                         epsabs=epsabs, epsrel=epsrel)
    psi_hat = psi_hat / Z
    rho_hat, _ = dblquad(jax.jit(lambda psi, rho: rho * posterior(psi, rho)), a=0, b=1, gfun=lambda x: 1.,
                         hfun=lambda x: 5.,
                         epsabs=epsabs, epsrel=epsrel)
    rho_hat = rho_hat / Z

    # Posterior standard deviations around the posterior means.
    psi_ci, _ = dblquad(jax.jit(lambda psi, rho: (psi_hat - psi) ** 2 * posterior(psi, rho)), a=0, b=1,
                        gfun=lambda x: 1., hfun=lambda x: 5.,
                        epsabs=epsabs, epsrel=epsrel)
    psi_ci = np.sqrt(psi_ci / Z)

    rho_ci, _ = dblquad(jax.jit(lambda psi, rho: (rho_hat - rho) ** 2 * posterior(psi, rho)), a=0, b=1,
                        gfun=lambda x: 1., hfun=lambda x: 5.,
                        epsabs=epsabs, epsrel=epsrel)
    rho_ci = np.sqrt(rho_ci / Z)

    # Sample mean of the responses, for comparison with psi_hat (value is
    # discarded, as in the original exploratory script).
    k @ data / data.sum()
    pass

tests/fit_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import unittest
2+
import jax.numpy as jnp
3+
4+
import gsd.fit
5+
6+
7+
class FitTestCase(unittest.TestCase):
    """Tests for the MLE fitting routine in ``gsd.fit``."""

    def test_mle(self):
        # Counts of responses on the 1..5 scale:
        #                      1    2    3   4   5
        data = jnp.asarray([0., 10., 10., 0., 0.])
        # A 50/50 split between the adjacent scores 2 and 3 has the minimal
        # possible variance at mean 2.5, i.e. psi = 2.5 and rho = 1.
        _, state = gsd.fit.fit_mle(data)
        self.assertAlmostEqual(state.params.psi, 2.5)
        self.assertAlmostEqual(state.params.rho, 1)
14+
15+
16+
17+
if __name__ == '__main__':
    # Allow running this test module directly: python tests/fit_test.py
    unittest.main()

0 commit comments

Comments
 (0)