1- """JAX implementation of probability densities and parameter initialization
2- for the Dirichlet Multinomial Mixture Model."""
1+ """JAX implementation of probability densities and parameter initialization for
2+ the Dirichlet Multinomial Mixture Model."""
33from functools import partial
44
55import jax
@@ -22,12 +22,18 @@ def symmetric_dirichlet_multinomial_mean(alpha: float, n: int, K: int):
2222
2323
2424def init_parameters (
25- n_docs : int , n_vocab : int , n_components : int , alpha : float , beta : float
25+ n_docs : int ,
26+ n_vocab : int ,
27+ n_components : int ,
28+ alpha : float ,
29+ beta : float ,
2630) -> dict :
2731 """Initializes the parameters of the dmm to the mean of the prior."""
2832 return dict (
2933 weights = symmetric_dirichlet_multinomial_mean (
30- alpha , n_docs , n_components
34+ alpha ,
35+ n_docs ,
36+ n_components ,
3137 ),
3238 components = np .broadcast_to (
3339 scipy .stats .dirichlet .mean (np .full (n_vocab , beta )),
@@ -41,13 +47,15 @@ def sparse_multinomial_logpdf(
4147 unique_words ,
4248 unique_word_counts ,
4349):
44- """Calculates joint multinomial probability of a sparse representation"""
44-     """Calculates joint multinomial probability of a sparse representation"""
50+     """Calculates joint multinomial probability of a sparse representation."""
4551 unique_word_counts = jnp .float64 (unique_word_counts )
4652 n_words = jnp .sum (unique_word_counts )
4753 n_factorial = jax .lax .lgamma (n_words + 1 )
4854 word_count_factorial = jax .lax .lgamma (unique_word_counts + 1 )
4955 word_count_factorial = jnp .where (
50- unique_word_counts != 0 , word_count_factorial , 0
56+ unique_word_counts != 0 ,
57+ word_count_factorial ,
58+ 0 ,
5159 )
5260 denominator = jnp .sum (word_count_factorial )
5361 probs = component [unique_words ]
@@ -84,18 +92,18 @@ def symmetric_dirichlet_multinomial_logpdf(x, n, alpha):
8492
8593
8694def predict_doc (components , weights , unique_words , unique_word_counts ):
87- """Predicts likelihood of a document belonging to
88- each cluster based on given parameters."""
95+ """Predicts likelihood of a document belonging to each cluster based on
96+ given parameters."""
8997 component_logpdf = partial (
9098 sparse_multinomial_logpdf ,
9199 unique_words = unique_words ,
92100 unique_word_counts = unique_word_counts ,
93101 )
94102 component_logprobs = jax .lax .map (component_logpdf , components ) + jnp .log (
95- weights
103+ weights ,
96104 )
97105 norm_probs = jnp .exp (
98- component_logprobs - jax .scipy .special .logsumexp (component_logprobs )
106+ component_logprobs - jax .scipy .special .logsumexp (component_logprobs ),
99107 )
100108 return norm_probs
101109
@@ -106,24 +114,31 @@ def predict_one(unique_words, unique_word_counts, components, weights):
106114 predict_doc ,
107115 unique_words = unique_words ,
108116 unique_word_counts = unique_word_counts ,
109- )
117+ ),
110118 )(components , weights )
111119
112120
113121def posterior_predictive (
114- doc_unique_words , doc_unique_word_counts , components , weights
122+ doc_unique_words ,
123+ doc_unique_word_counts ,
124+ components ,
125+ weights ,
115126):
116- """Predicts probability of a document belonging to each component
117- for all posterior samples.
118- """
127+ """Predicts probability of a document belonging to each component for all
128+ posterior samples."""
119129 predict_all = jax .vmap (
120- partial (predict_one , components = components , weights = weights )
130+ partial (predict_one , components = components , weights = weights ),
121131 )
122132 return predict_all (doc_unique_words , doc_unique_word_counts )
123133
124134
125135def dmm_loglikelihood (
126- components , weights , doc_unique_words , doc_unique_word_counts , alpha , beta
136+ components ,
137+ weights ,
138+ doc_unique_words ,
139+ doc_unique_word_counts ,
140+ alpha ,
141+ beta ,
127142):
128143 docs = jnp .stack ((doc_unique_words , doc_unique_word_counts ), axis = 1 )
129144
@@ -135,7 +150,8 @@ def doc_likelihood(doc):
135150 unique_word_counts = unique_word_counts ,
136151 )
137152 component_logprobs = jax .lax .map (
138- component_logpdf , components
153+ component_logpdf ,
154+ components ,
139155 ) + jnp .log (weights )
140156 return jax .scipy .special .logsumexp (component_logprobs )
141157
@@ -146,17 +162,25 @@ def doc_likelihood(doc):
146162def dmm_logprior (components , weights , alpha , beta , n_docs ):
147163 components_prior = jnp .sum (
148164 jax .lax .map (
149- partial (symmetric_dirichlet_logpdf , alpha = alpha ), components
150- )
165+ partial (symmetric_dirichlet_logpdf , alpha = alpha ),
166+ components ,
167+ ),
151168 )
152169 weights_prior = symmetric_dirichlet_multinomial_logpdf (
153- weights , n = jnp .float64 (n_docs ), alpha = beta
170+ weights ,
171+ n = jnp .float64 (n_docs ),
172+ alpha = beta ,
154173 )
155174 return components_prior + weights_prior
156175
157176
158177def dmm_logpdf (
159- components , weights , doc_unique_words , doc_unique_word_counts , alpha , beta
178+ components ,
179+ weights ,
180+ doc_unique_words ,
181+ doc_unique_word_counts ,
182+ alpha ,
183+ beta ,
160184):
161185 """Calculates logdensity of the DMM at a given point in parameter space."""
162186 n_docs = doc_unique_words .shape [0 ]
0 commit comments