Commit 485a9d2

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 327a529 commit 485a9d2

File tree: 8 files changed, +126 -61 lines


mess.py

Lines changed: 17 additions & 10 deletions
@@ -17,10 +17,14 @@
 from tqdm import trange

 from tweetopic._doc import init_doc_words
-from tweetopic.bayesian.dmm import (BayesianDMM, posterior_predictive,
-                                    predict_doc, sparse_multinomial_logpdf,
-                                    symmetric_dirichlet_logpdf,
-                                    symmetric_dirichlet_multinomial_logpdf)
+from tweetopic.bayesian.dmm import (
+    BayesianDMM,
+    posterior_predictive,
+    predict_doc,
+    sparse_multinomial_logpdf,
+    symmetric_dirichlet_logpdf,
+    symmetric_dirichlet_multinomial_logpdf,
+)
 from tweetopic.bayesian.sampling import batch_data, sample_nuts
 from tweetopic.func import spread

@@ -58,23 +62,26 @@ def logprior_fn(params):

 def loglikelihood_fn(params, data):
     doc_likelihood = jax.vmap(
-        partial(sparse_multinomial_logpdf, component=params["component"])
+        partial(sparse_multinomial_logpdf, component=params["component"]),
     )
     return jnp.sum(
         doc_likelihood(
             unique_words=data["doc_unique_words"],
             unique_word_counts=data["doc_unique_word_counts"],
-        )
+        ),
     )


 logdensity_fn(position)

 logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(
-    params, data
+    params,
+    data,
 )
 grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(
-    logprior_fn, loglikelihood_fn, data_size=n_documents
+    logprior_fn,
+    loglikelihood_fn,
+    data_size=n_documents,
 )
 rng_key = jax.random.PRNGKey(0)
 batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
@@ -88,8 +95,8 @@ def loglikelihood_fn(params, data):
 )
 position = dict(
     component=jnp.array(
-        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha)))
-    )
+        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha))),
+    ),
 )

 samples, states = sample_nuts(position, logdensity_fn)
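
Aside: the script above builds its target density by adding a prior term to a vmapped per-document likelihood and then hands the result to the NUTS sampler. A minimal, self-contained sketch of that composition pattern (a toy Gaussian model with made-up names, not the tweetopic code) could look like this:

import jax
import jax.numpy as jnp

def logprior_fn(params):
    # Standard normal prior on the parameter vector.
    return jnp.sum(jax.scipy.stats.norm.logpdf(params["component"]))

def loglikelihood_fn(params, data):
    # Evaluate every observation in one vectorized pass, then sum.
    per_obs = jax.vmap(
        lambda row: jnp.sum(jax.scipy.stats.norm.logpdf(row, loc=params["component"])),
    )(data)
    return jnp.sum(per_obs)

data = jnp.ones((10, 3))
logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(params, data)
print(logdensity_fn({"component": jnp.zeros(3)}))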

tweetopic/_btm.py

Lines changed: 32 additions & 11 deletions
@@ -1,4 +1,4 @@
-"""Module for utility functions for fitting BTMs"""
+"""Module for utility functions for fitting BTMs."""

 import random
 from typing import Dict, Tuple, TypeVar
@@ -12,7 +12,8 @@

 @njit
 def doc_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     (n_max_unique_words,) = doc_unique_words.shape
     biterm_counts = dict()
@@ -43,7 +44,7 @@ def doc_unique_biterms(

 @njit
 def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
-    """Adds one counter dict to another in place with Numba"""
+    """Adds one counter dict to another in place with Numba."""
     for key in source:
         if key in dest:
             dest[key] += source[key]
@@ -53,25 +54,28 @@ def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):

 @njit
 def corpus_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     n_documents, _ = doc_unique_words.shape
     biterm_counts = doc_unique_biterms(
-        doc_unique_words[0], doc_unique_word_counts[0]
+        doc_unique_words[0],
+        doc_unique_word_counts[0],
     )
     for i_doc in range(1, n_documents):
         doc_unique_words_i = doc_unique_words[i_doc]
         doc_unique_word_counts_i = doc_unique_word_counts[i_doc]
         doc_biterms = doc_unique_biterms(
-            doc_unique_words_i, doc_unique_word_counts_i
+            doc_unique_words_i,
+            doc_unique_word_counts_i,
         )
         nb_add_counter(biterm_counts, doc_biterms)
     return biterm_counts


 @njit
 def compute_biterm_set(
-    biterm_counts: Dict[Tuple[int, int], int]
+    biterm_counts: Dict[Tuple[int, int], int],
 ) -> np.ndarray:
     return np.array(list(biterm_counts.keys()))

@@ -116,7 +120,12 @@ def add_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        True, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        True,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )


@@ -129,7 +138,12 @@ def remove_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        False, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        False,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )


@@ -147,7 +161,11 @@ def init_components(
         i_topic = random.randint(0, n_components - 1)
         biterm_topic_assignments[i_biterm] = i_topic
         add_biterm(
-            i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+            i_biterm,
+            i_topic,
+            biterms,
+            topic_word_count,
+            topic_biterm_count,
         )
     return biterm_topic_assignments, topic_word_count, topic_biterm_count

@@ -448,7 +466,10 @@ def predict_docs(
         )
         biterms = doc_unique_biterms(words, word_counts)
         prob_topic_given_document(
-            pred, biterms, topic_distribution, topic_word_distribution
+            pred,
+            biterms,
+            topic_distribution,
+            topic_word_distribution,
         )
         predictions[i_doc, :] = pred
     return predictions
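
Aside: doc_unique_biterms and corpus_unique_biterms count biterms, i.e. unordered pairs of word occurrences within a document, which is the basic statistic a biterm topic model is fit on. A plain-Python sketch of the idea (no Numba, a dense word list rather than the module's padded unique-word/count arrays) might be:

from typing import Dict, List, Tuple

def count_biterms(doc_words: List[int]) -> Dict[Tuple[int, int], int]:
    # Every unordered pair of word occurrences in the document is one biterm.
    counts: Dict[Tuple[int, int], int] = {}
    for i in range(len(doc_words)):
        for j in range(i + 1, len(doc_words)):
            pair = tuple(sorted((doc_words[i], doc_words[j])))
            counts[pair] = counts.get(pair, 0) + 1
    return counts

print(count_biterms([1, 2, 1, 3]))
# {(1, 2): 2, (1, 1): 1, (1, 3): 2, (2, 3): 1}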

tweetopic/_dmm.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
-"""Module containing tools for fitting a Dirichlet Multinomial Mixture Model."""
+"""Module containing tools for fitting a Dirichlet Multinomial Mixture
+Model."""
 from __future__ import annotations

 from math import exp, log

tweetopic/_doc.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ def init_doc_words(
     n_docs, _ = doc_term_matrix.shape
     doc_unique_words = np.zeros((n_docs, max_unique_words)).astype(np.uint32)
     doc_unique_word_counts = np.zeros((n_docs, max_unique_words)).astype(
-        np.uint32
+        np.uint32,
    )
     for i_doc in range(n_docs):
         unique_words = doc_term_matrix[i_doc].rows[0]  # type: ignore
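
Aside: init_doc_words unpacks a sparse document-term matrix into two fixed-width uint32 arrays, one holding each document's unique word ids and one their counts, zero-padded up to max_unique_words. A small illustration of that layout (assuming a scipy.sparse.lil_matrix input, as the .rows access above suggests):

import numpy as np
from scipy.sparse import lil_matrix

# Two documents over a four-word vocabulary.
dtm = lil_matrix(np.array([[2, 0, 1, 0],
                           [0, 3, 0, 1]]))
max_unique_words = 3
doc_unique_words = np.zeros((2, max_unique_words), dtype=np.uint32)
doc_unique_word_counts = np.zeros((2, max_unique_words), dtype=np.uint32)
for i_doc in range(2):
    words = dtm[i_doc].rows[0]   # column ids of the nonzero entries
    counts = dtm[i_doc].data[0]  # the corresponding counts
    doc_unique_words[i_doc, : len(words)] = words
    doc_unique_word_counts[i_doc, : len(counts)] = counts
print(doc_unique_words)        # [[0 2 0] [1 3 0]]
print(doc_unique_word_counts)  # [[2 1 0] [3 1 0]]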

tweetopic/bayesian/dmm.py

Lines changed: 46 additions & 22 deletions
@@ -1,5 +1,5 @@
-"""JAX implementation of probability densities and parameter initialization
-for the Dirichlet Multinomial Mixture Model."""
+"""JAX implementation of probability densities and parameter initialization for
+the Dirichlet Multinomial Mixture Model."""
 from functools import partial

 import jax
@@ -22,12 +22,18 @@ def symmetric_dirichlet_multinomial_mean(alpha: float, n: int, K: int):


 def init_parameters(
-    n_docs: int, n_vocab: int, n_components: int, alpha: float, beta: float
+    n_docs: int,
+    n_vocab: int,
+    n_components: int,
+    alpha: float,
+    beta: float,
 ) -> dict:
     """Initializes the parameters of the dmm to the mean of the prior."""
     return dict(
         weights=symmetric_dirichlet_multinomial_mean(
-            alpha, n_docs, n_components
+            alpha,
+            n_docs,
+            n_components,
         ),
         components=np.broadcast_to(
             scipy.stats.dirichlet.mean(np.full(n_vocab, beta)),
@@ -41,13 +47,15 @@ def sparse_multinomial_logpdf(
     unique_words,
     unique_word_counts,
 ):
-    """Calculates joint multinomial probability of a sparse representation"""
+    """Calculates joint multinomial probability of a sparse representation."""
     unique_word_counts = jnp.float64(unique_word_counts)
     n_words = jnp.sum(unique_word_counts)
     n_factorial = jax.lax.lgamma(n_words + 1)
     word_count_factorial = jax.lax.lgamma(unique_word_counts + 1)
     word_count_factorial = jnp.where(
-        unique_word_counts != 0, word_count_factorial, 0
+        unique_word_counts != 0,
+        word_count_factorial,
+        0,
     )
     denominator = jnp.sum(word_count_factorial)
     probs = component[unique_words]
@@ -84,18 +92,18 @@ def symmetric_dirichlet_multinomial_logpdf(x, n, alpha):


 def predict_doc(components, weights, unique_words, unique_word_counts):
-    """Predicts likelihood of a document belonging to
-    each cluster based on given parameters."""
+    """Predicts likelihood of a document belonging to each cluster based on
+    given parameters."""
     component_logpdf = partial(
         sparse_multinomial_logpdf,
         unique_words=unique_words,
         unique_word_counts=unique_word_counts,
     )
     component_logprobs = jax.lax.map(component_logpdf, components) + jnp.log(
-        weights
+        weights,
     )
     norm_probs = jnp.exp(
-        component_logprobs - jax.scipy.special.logsumexp(component_logprobs)
+        component_logprobs - jax.scipy.special.logsumexp(component_logprobs),
     )
     return norm_probs

@@ -106,24 +114,31 @@ def predict_one(unique_words, unique_word_counts, components, weights):
            predict_doc,
            unique_words=unique_words,
            unique_word_counts=unique_word_counts,
-        )
+        ),
    )(components, weights)


 def posterior_predictive(
-    doc_unique_words, doc_unique_word_counts, components, weights
+    doc_unique_words,
+    doc_unique_word_counts,
+    components,
+    weights,
 ):
-    """Predicts probability of a document belonging to each component
-    for all posterior samples.
-    """
+    """Predicts probability of a document belonging to each component for all
+    posterior samples."""
     predict_all = jax.vmap(
-        partial(predict_one, components=components, weights=weights)
+        partial(predict_one, components=components, weights=weights),
     )
     return predict_all(doc_unique_words, doc_unique_word_counts)


 def dmm_loglikelihood(
-    components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
+    components,
+    weights,
+    doc_unique_words,
+    doc_unique_word_counts,
+    alpha,
+    beta,
 ):
     docs = jnp.stack((doc_unique_words, doc_unique_word_counts), axis=1)

@@ -135,7 +150,8 @@ def doc_likelihood(doc):
             unique_word_counts=unique_word_counts,
         )
         component_logprobs = jax.lax.map(
-            component_logpdf, components
+            component_logpdf,
+            components,
         ) + jnp.log(weights)
         return jax.scipy.special.logsumexp(component_logprobs)

@@ -146,17 +162,25 @@ def doc_likelihood(doc):
 def dmm_logprior(components, weights, alpha, beta, n_docs):
     components_prior = jnp.sum(
         jax.lax.map(
-            partial(symmetric_dirichlet_logpdf, alpha=alpha), components
-        )
+            partial(symmetric_dirichlet_logpdf, alpha=alpha),
+            components,
+        ),
     )
     weights_prior = symmetric_dirichlet_multinomial_logpdf(
-        weights, n=jnp.float64(n_docs), alpha=beta
+        weights,
+        n=jnp.float64(n_docs),
+        alpha=beta,
     )
     return components_prior + weights_prior


 def dmm_logpdf(
-    components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
+    components,
+    weights,
+    doc_unique_words,
+    doc_unique_word_counts,
+    alpha,
+    beta,
 ):
     """Calculates logdensity of the DMM at a given point in parameter space."""
     n_docs = doc_unique_words.shape[0]
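
Aside: sparse_multinomial_logpdf evaluates the multinomial log-density log n! - sum(log x_i!) + sum(x_i * log p_i) over only the nonzero counts of a document. A NumPy rendering of the same quantity for reference (the module's version uses jax.lax.lgamma, masks zero-count padding, and composes with vmap):

import numpy as np
from scipy.special import gammaln

def multinomial_logpdf_sparse(component, unique_words, unique_word_counts):
    # log n! - sum(log x_i!) + sum(x_i * log p_i), using only nonzero counts.
    counts = np.asarray(unique_word_counts, dtype=np.float64)
    n = counts.sum()
    log_coef = gammaln(n + 1) - gammaln(counts + 1).sum()
    probs = np.asarray(component)[np.asarray(unique_words)]
    return log_coef + np.sum(counts * np.log(probs))

component = np.array([0.5, 0.3, 0.2])
print(multinomial_logpdf_sparse(component, unique_words=[0, 2], unique_word_counts=[2, 1]))
# log(3!/(2!*1!)) + 2*log(0.5) + log(0.2) ≈ -1.9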
