|
| 1 | +from typing import NamedTuple |
| 2 | + |
| 3 | +import jax |
| 4 | +import jax.numpy as jnp |
| 5 | +import jax.tree_util as jtu |
| 6 | +from jax import Array |
| 7 | +from jax.typing import ArrayLike |
| 8 | + |
| 9 | +from .gsd import vmax, vmin, log_prob |
| 10 | + |
| 11 | + |
| 12 | +class GSDParams(NamedTuple): |
| 13 | + """NamedTuple representing parameters for the Generalized Structure Distribution (GSD). |
| 14 | +
|
| 15 | + This class is used to store the psi and rho parameters for the GSD. |
| 16 | + It provides a convenient way to group these parameters together for use in various |
| 17 | + statistical and modeling applications. |
| 18 | + """ |
| 19 | + psi: Array |
| 20 | + rho: Array |
| 21 | + |
| 22 | + |
| 23 | +@jax.jit |
| 24 | +def fit_moments(data: ArrayLike) -> GSDParams: |
| 25 | + """Fits GSD using moments estimator |
| 26 | +
|
| 27 | + :param data: A 5d Array of counts of each response. |
| 28 | + :return: GSD Parameters |
| 29 | + """ |
| 30 | + psi = jnp.dot(data, jnp.arange(1, 6)) / jnp.sum(data) |
| 31 | + V = jnp.dot(data, jnp.arange(1, 6) ** 2) / jnp.sum(data) - psi ** 2 |
| 32 | + return GSDParams(psi=psi, rho=(vmax(psi) - V) / (vmax(psi) - vmin(psi))) |
| 33 | + |
| 34 | + |
| 35 | +class OptState(NamedTuple): |
| 36 | + """A class representing the state of an optimization process. |
| 37 | +
|
| 38 | + Attributes: |
| 39 | + :param params (GSDParams): The current optimization parameters. |
| 40 | + :param previous_params (GSDParams): The previous optimization parameters. |
| 41 | + :param count (int): An integer count indicating the step or iteration of the optimization process. |
| 42 | +
|
| 43 | + This class is used to store and manage the state of an optimization algorithm, allowing you |
| 44 | + to keep track of the current parameters, previous parameters, and the step count. |
| 45 | +
|
| 46 | + """ |
| 47 | + params: GSDParams |
| 48 | + previous_params: GSDParams |
| 49 | + count: int |
| 50 | + |
| 51 | + |
| 52 | +@jax.jit |
| 53 | +def fit_mle(data: ArrayLike, max_iterations: int = 100, log_lr_min: ArrayLike = -15, log_lr_max: ArrayLike = 2., |
| 54 | + num_lr: ArrayLike = 10) -> tuple[GSDParams, OptState]: |
| 55 | + """Finds the maximum likelihood estimator of the GSD parameters. |
| 56 | + The algorithm used here is a simple gradient ascent. |
| 57 | + We use the concept of projected gradient to enforce constraints for parameters (psi in [1, 5], rho in [0, 1]) and exhaustive search for line search along the gradient. |
| 58 | +
|
| 59 | + :param data: 5D array of counts for each response. |
| 60 | + :param max_iterations: Maximum number of iterations. |
| 61 | + :param log_lr_min: Log2 of the smallest learning rate. |
| 62 | + :param log_lr_max: Log2 of the largest learning rate. |
| 63 | + :param num_lr: Number of learning rates to check during the line search. |
| 64 | +
|
| 65 | + :return: An opt state whore params filed contains estimated values of GSD Parameters |
| 66 | + """ |
| 67 | + |
| 68 | + def ll(theta: GSDParams) -> Array: |
| 69 | + logits = jax.vmap(log_prob, (None, None, 0), (0))(theta.psi, theta.rho, jnp.arange(1, 6)) |
| 70 | + return jnp.dot(data, logits) / jnp.sum(data) |
| 71 | + |
| 72 | + grad_ll = jax.grad(ll) |
| 73 | + theta0 = fit_moments(data) |
| 74 | + rate = jnp.concatenate([jnp.zeros((1,)), jnp.logspace(log_lr_min, log_lr_max, num_lr, base=2.)]) |
| 75 | + |
| 76 | + def update(tg, t, lo, hi): |
| 77 | + ''' |
| 78 | + :param tg: gradient |
| 79 | + :param t: theta parameters |
| 80 | + :param lo: lower bound |
| 81 | + :param hi: upper bound |
| 82 | + :return: updated params |
| 83 | + ''' |
| 84 | + nt = t + rate * jnp.where(jnp.isnan(tg), 0., tg) |
| 85 | + _nt = jnp.where(nt < lo, lo, nt) |
| 86 | + _nt = jnp.where(_nt > hi, hi, _nt) |
| 87 | + return _nt |
| 88 | + |
| 89 | + lo = GSDParams(psi=1., rho=0.) |
| 90 | + hi = GSDParams(psi=5., rho=1.) |
| 91 | + |
| 92 | + def body_fun(state: OptState) -> OptState: |
| 93 | + t, _, count = state |
| 94 | + g = grad_ll(t) |
| 95 | + new_theta = jtu.tree_map(update, g, t, lo, hi) |
| 96 | + new_lls = jax.vmap(ll)(new_theta) |
| 97 | + max_idx = jnp.argmax(jnp.where(jnp.isnan(new_lls), -jnp.inf, new_lls)) |
| 98 | + return OptState(params=jtu.tree_map(lambda t: t[max_idx], new_theta), previous_params=t, count=count + 1) |
| 99 | + |
| 100 | + def cond_fun(state: OptState) -> bool: |
| 101 | + tn, tnm1, c = state |
| 102 | + should_stop = jnp.logical_or(jnp.all(jnp.array(tn) == jnp.array(tnm1)), c > max_iterations) |
| 103 | + # stop on NaN |
| 104 | + should_stop = jnp.logical_or(should_stop, jnp.any(jnp.isnan(jnp.array(tn)))) |
| 105 | + return jnp.logical_not(should_stop) |
| 106 | + |
| 107 | + opt_state = jax.lax.while_loop(cond_fun, body_fun, |
| 108 | + OptState(params=theta0, previous_params=jtu.tree_map(lambda _: jnp.inf, theta0), |
| 109 | + count=0)) |
| 110 | + return opt_state.params, opt_state |
| 111 | + |
0 commit comments