1+ import equinox as eqx
2+ import jax
3+ import jax .numpy as jnp
4+ import optimistix as optx
5+ from jaxtyping import Array , Float , Int , PRNGKeyArray
6+
7+ import gsd
8+ from gsd import GSDParams
9+ from gsd .gsd import vmin
10+
11+
@jax.jit
def vmax(mean: Array, N: int) -> Array:
    """Compute the maximal variance of a categorical distribution on ``Z[1,N]``.

    By the Bhatia-Davis inequality, any distribution supported on ``[1, N]``
    with expectation ``mean`` has variance at most ``(mean - 1) * (N - mean)``.

    :param mean: Expectation value; assumed to lie in ``[1, N]``.
    :param N: Number of response categories (support size).
    :return: The maximal attainable variance.
    """
    return (mean - 1.0) * (N - mean)
21+
22+
23+ def _lagrange_log_probs (lagrage : tuple , dist : 'MaxEntropyGSD' ):
24+ lamda1 , lamdam , lamdas = lagrage
25+ lp = lamda1 + dist .support * lamdam + lamdas * dist .squred_diff - 1.0
26+ return lp
27+
28+
def _implicit_log_probs(lagrage: tuple, d: 'MaxEntropyGSD'):
    """Residuals of the three moment constraints for candidate multipliers.

    Returns a tuple that is ``(0, 0, 0)`` exactly when the probabilities
    implied by ``lagrage`` sum to one and reproduce the requested mean and
    variance of ``d``.
    """
    probs = jnp.exp(_lagrange_log_probs(lagrage, d))
    normalization = jnp.sum(probs) - 1.0
    mean_residual = jnp.dot(probs, d.support) - d.mean
    variance_residual = jnp.dot(probs, d.squred_diff) - d.sigma ** 2
    return normalization, mean_residual, variance_residual
38+
39+
def _explicit_log_probs(dist: 'MaxEntropyGSD'):
    """Solve for the Lagrange multipliers and return all log-probabilities.

    Runs a Newton root-finder on the moment-constraint residuals, starting
    from a small negative value for every multiplier.  ``throw=False`` means
    a non-converged solve yields its best iterate instead of raising.
    """
    solver = optx.Newton(rtol=1e-8, atol=1e-8)
    initial_multipliers = jax.tree_util.tree_map(
        jnp.asarray, (-0.01, -0.01, -0.01))
    sol = optx.root_find(_implicit_log_probs, solver, initial_multipliers,
                         args=dist, max_steps=int(1e4), throw=False)
    return _lagrange_log_probs(sol.value, dist)
47+
48+
class MaxEntropyGSD(eqx.Module):
    r"""
    Maximum entropy distribution supported on `Z[1,N]`

    This distribution is defined to fulfill the following conditions on $p_i$

    * Maximize $H= -\sum_i p_i\log(p_i)$ wrt.
    * $\sum p_i=1$
    * $\sum i p_i= \mu$
    * $\sum (i-\mu)^2 p_i= \sigma^2$

    :param mean: Expectation value of the distribution.
    :param sigma: Standard deviation of the distribution.
    :param N: Number of responses

    """
    mean: Float[Array, ""]   # expectation value (mu)
    sigma: Float[Array, ""]  # standard deviation (not the variance)
    # static so jit treats the support size as a compile-time constant
    N: int = eqx.field(static=True)

    def log_prob(self, x: Int[Array, ""]):
        """Log-probability of response ``x`` (1-based, in ``Z[1,N]``)."""
        lp = _explicit_log_probs(self)
        # support starts at 1, hence the shift to a 0-based index
        return lp[x - 1]

    def prob(self, x: Int[Array, ""]):
        """Probability of response ``x``."""
        return jnp.exp(self.log_prob(x))

    @property
    def support(self):
        """Support of the distribution: the integers ``1 .. N``."""
        return jnp.arange(1, self.N + 1)

    @property
    def squred_diff(self):
        # NOTE: misspelled name kept for backward compatibility (the module
        # helpers read it); prefer :attr:`squared_diff` in new code.
        return jnp.square((self.support - self.mean))

    @property
    def squared_diff(self):
        """Squared deviations of the support from the mean, ``(i - mu)**2``."""
        return self.squred_diff

    def variance(self):
        """Variance of the distribution, ``sigma**2``."""
        return jnp.square(self.sigma)

    def stddev(self):
        """Standard deviation of the distribution."""
        # Fix: previously called ``self.variance()`` which did not exist and
        # raised AttributeError; ``variance`` is now defined above.
        return jnp.sqrt(self.variance())

    def vmax(self):
        """Maximal variance of any distribution on ``Z[1,N]`` with this mean."""
        return (self.mean - 1.0) * (self.N - self.mean)

    def vmin(self):
        """Minimal variance of any distribution with this mean."""
        return vmin(self.mean)

    @property
    def all_log_probs(self):
        """Vector of log-probabilities for every point of the support."""
        return _explicit_log_probs(self)

    @jax.jit
    def entropy(self):
        """Shannon entropy ``H = -sum_i p_i log(p_i)``."""
        lp = self.all_log_probs
        return -jnp.dot(lp, jnp.exp(lp))

    def sample(self, key: PRNGKeyArray, axis=-1, shape=None):
        """Draw categorical samples; values lie in ``Z[1,N]``.

        :param key: PRNG key.
        :param axis: Axis of the logits along which to sample.
        :param shape: Optional output shape.
        """
        lp = self.all_log_probs
        # categorical returns 0-based indices; shift onto the 1-based support
        return jax.random.categorical(key, lp, axis, shape) + self.support[0]

    @staticmethod
    def from_gsd(theta: GSDParams, N: int) -> 'MaxEntropyGSD':
        """Create a MaxEntropyGSD matching the moments of a GSD distribution.

        :param theta: Parameters of a GSD distribution.
        :param N: Support size
        :return: A distribution object
        """
        return MaxEntropyGSD(
            mean=gsd.mean(theta.psi, theta.rho),
            sigma=jnp.sqrt(gsd.variance(theta.psi, theta.rho)),
            N=N,
        )
121+
# eqx.Module auto-generates __init__ from the dataclass-style fields, so the
# constructor docstring must be attached after the class body is evaluated.
MaxEntropyGSD .__init__ .__doc__ = """Creates a MaxEntropyGSD

:param mean: Expectation value of the distribution.
:param sigma: Standard deviation of the distribution.
:param N: Number of responses

.. note::
    An alternative way to construct this distribution is by use of
    :ref:`from_gsd`

"""