
Commit f959dce

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 97af8e3 commit f959dce

File tree

mess.py
tweetopic/_btm.py
tweetopic/_dmm.py
tweetopic/_doc.py
tweetopic/btm.py
tweetopic/func.py

6 files changed: +67 −32 lines changed
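Most hunks in this commit apply a single formatting rule: once a call or signature is split across lines, every argument gets its own line and a trailing comma. A self-contained sketch of that rule (the `scale` function is invented for illustration; the commit does not show which hook enforces the style):

```python
# Hypothetical function, only to demonstrate the formatting rule.
def scale(values, factor, offset):
    return [v * factor + offset for v in values]

# Fits on one line: left alone, no trailing comma required.
short = scale([1, 2], 2, 1)

# Split across lines: one argument per line, trailing comma on the last,
# so adding or removing an argument later touches only a single line.
long = scale(
    [1, 2, 3, 4],
    2,
    1,
)
print(short, long)  # [3, 5] [3, 5, 7, 9]
```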

mess.py

Lines changed: 17 additions & 10 deletions

@@ -17,10 +17,14 @@
 from tqdm import trange

 from tweetopic._doc import init_doc_words
-from tweetopic.bayesian.dmm import (BayesianDMM, posterior_predictive,
-                                    predict_doc, sparse_multinomial_logpdf,
-                                    symmetric_dirichlet_logpdf,
-                                    symmetric_dirichlet_multinomial_logpdf)
+from tweetopic.bayesian.dmm import (
+    BayesianDMM,
+    posterior_predictive,
+    predict_doc,
+    sparse_multinomial_logpdf,
+    symmetric_dirichlet_logpdf,
+    symmetric_dirichlet_multinomial_logpdf,
+)
 from tweetopic.bayesian.sampling import batch_data, sample_nuts
 from tweetopic.func import spread
@@ -58,23 +62,26 @@ def logprior_fn(params):

 def loglikelihood_fn(params, data):
     doc_likelihood = jax.vmap(
-        partial(sparse_multinomial_logpdf, component=params["component"])
+        partial(sparse_multinomial_logpdf, component=params["component"]),
     )
     return jnp.sum(
         doc_likelihood(
             unique_words=data["doc_unique_words"],
             unique_word_counts=data["doc_unique_word_counts"],
-        )
+        ),
     )


 logdensity_fn(position)

 logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(
-    params, data
+    params,
+    data,
 )
 grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(
-    logprior_fn, loglikelihood_fn, data_size=n_documents
+    logprior_fn,
+    loglikelihood_fn,
+    data_size=n_documents,
 )
 rng_key = jax.random.PRNGKey(0)
 batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
@@ -88,8 +95,8 @@ def loglikelihood_fn(params, data):
 )
 position = dict(
     component=jnp.array(
-        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha)))
-    )
+        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha))),
+    ),
 )

 samples, states = sample_nuts(position, logdensity_fn)
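For orientation, the mess.py hunks reformat a common JAX pattern: a per-observation log-likelihood is vmapped over the batch, summed, and added to a log-prior to form the joint log-density handed to the sampler. A runnable toy version of that composition, with an invented Gaussian model standing in for tweetopic's Dirichlet-multinomial one:

```python
# Toy illustration of the logprior + vmapped loglikelihood composition.
# The Gaussian model here is made up; only the structure mirrors mess.py.
import jax
import jax.numpy as jnp

def logprior_fn(params):
    # Standard normal prior on the single parameter "mu".
    return -0.5 * jnp.sum(params["mu"] ** 2)

def loglikelihood_fn(params, data):
    # Per-observation log-likelihood, vmapped over the batch and summed.
    per_obs = jax.vmap(lambda x: -0.5 * (x - params["mu"]) ** 2)
    return jnp.sum(per_obs(data))

data = jnp.array([0.2, -0.1, 0.4])
logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(params, data)
print(logdensity_fn({"mu": jnp.array(0.0)}))  # scalar joint log-density
```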

tweetopic/_btm.py

Lines changed: 32 additions & 11 deletions

@@ -1,4 +1,4 @@
-"""Module for utility functions for fitting BTMs"""
+"""Module for utility functions for fitting BTMs."""

 import random
 from typing import Dict, Tuple, TypeVar
@@ -12,7 +12,8 @@

 @njit
 def doc_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     (n_max_unique_words,) = doc_unique_words.shape
     biterm_counts = dict()
@@ -43,7 +44,7 @@ def doc_unique_biterms(

 @njit
 def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
-    """Adds one counter dict to another in place with Numba"""
+    """Adds one counter dict to another in place with Numba."""
     for key in source:
         if key in dest:
             dest[key] += source[key]
@@ -53,25 +54,28 @@ def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):

 @njit
 def corpus_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     n_documents, _ = doc_unique_words.shape
     biterm_counts = doc_unique_biterms(
-        doc_unique_words[0], doc_unique_word_counts[0]
+        doc_unique_words[0],
+        doc_unique_word_counts[0],
     )
     for i_doc in range(1, n_documents):
         doc_unique_words_i = doc_unique_words[i_doc]
         doc_unique_word_counts_i = doc_unique_word_counts[i_doc]
         doc_biterms = doc_unique_biterms(
-            doc_unique_words_i, doc_unique_word_counts_i
+            doc_unique_words_i,
+            doc_unique_word_counts_i,
         )
         nb_add_counter(biterm_counts, doc_biterms)
     return biterm_counts


 @njit
 def compute_biterm_set(
-    biterm_counts: Dict[Tuple[int, int], int]
+    biterm_counts: Dict[Tuple[int, int], int],
 ) -> np.ndarray:
     return np.array(list(biterm_counts.keys()))
@@ -116,7 +120,12 @@ def add_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        True, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        True,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )
@@ -129,7 +138,12 @@ def remove_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        False, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        False,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )
@@ -147,7 +161,11 @@ def init_components(
         i_topic = random.randint(0, n_components - 1)
         biterm_topic_assignments[i_biterm] = i_topic
         add_biterm(
-            i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+            i_biterm,
+            i_topic,
+            biterms,
+            topic_word_count,
+            topic_biterm_count,
         )
     return biterm_topic_assignments, topic_word_count, topic_biterm_count
@@ -448,7 +466,10 @@ def predict_docs(
         )
         biterms = doc_unique_biterms(words, word_counts)
         prob_topic_given_document(
-            pred, biterms, topic_distribution, topic_word_distribution
+            pred,
+            biterms,
+            topic_distribution,
+            topic_word_distribution,
        )
        predictions[i_doc, :] = pred
    return predictions
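For readers unfamiliar with BTM internals, `doc_unique_biterms` maps a document to counts of its biterms, i.e. unordered pairs of distinct words, keyed as (word, word) tuples per the signature above. A plain-Python reference sketch; the counting rule used here (product of the two words' counts) is an assumption for illustration, so consult tweetopic/_btm.py for the real weights:

```python
# Reference sketch of biterm extraction for one document. Assumed
# counting rule: a pair's count is the product of its words' counts.
from typing import Dict, List, Tuple

def doc_unique_biterms_ref(
    doc_unique_words: List[int],
    doc_unique_word_counts: List[int],
) -> Dict[Tuple[int, int], int]:
    biterm_counts: Dict[Tuple[int, int], int] = {}
    n = len(doc_unique_words)
    for i in range(n):
        for j in range(i + 1, n):
            # Order the pair so (a, b) and (b, a) share one key.
            w1, w2 = sorted((doc_unique_words[i], doc_unique_words[j]))
            count = doc_unique_word_counts[i] * doc_unique_word_counts[j]
            biterm_counts[(w1, w2)] = biterm_counts.get((w1, w2), 0) + count
    return biterm_counts

print(doc_unique_biterms_ref([3, 7, 9], [2, 1, 1]))
# {(3, 7): 2, (3, 9): 2, (7, 9): 1}
```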

tweetopic/_dmm.py

Lines changed: 3 additions & 1 deletion

@@ -1,4 +1,6 @@
-"""Module containing tools for fitting a Dirichlet Multinomial Mixture Model."""
+"""Module containing tools for fitting a Dirichlet Multinomial Mixture
+Model."""
+
 from __future__ import annotations

 from math import exp, log

tweetopic/_doc.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ def init_doc_words(
     n_docs, _ = doc_term_matrix.shape
     doc_unique_words = np.zeros((n_docs, max_unique_words)).astype(np.uint32)
     doc_unique_word_counts = np.zeros((n_docs, max_unique_words)).astype(
-        np.uint32
+        np.uint32,
     )
     for i_doc in range(n_docs):
         unique_words = doc_term_matrix[i_doc].rows[0]  # type: ignore
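Judging from the visible lines, init_doc_words pads each document's unique term ids and their counts into fixed-width uint32 arrays, reading from a LIL sparse matrix. A runnable sketch of that conversion; the loop body past `.rows[0]` is not shown in the hunk, so the `.data[0]` handling below is an assumption:

```python
# Sketch of the doc-term-matrix -> padded-arrays conversion seen above.
import numpy as np
import scipy.sparse as spr

doc_term_matrix = spr.lil_matrix(np.array([[0, 2, 1], [3, 0, 0]]))
max_unique_words = 2
n_docs, _ = doc_term_matrix.shape
doc_unique_words = np.zeros((n_docs, max_unique_words)).astype(np.uint32)
doc_unique_word_counts = np.zeros((n_docs, max_unique_words)).astype(np.uint32)
for i_doc in range(n_docs):
    # LIL rows hold column indices of nonzero terms; data holds the counts.
    unique_words = doc_term_matrix[i_doc].rows[0]
    counts = doc_term_matrix[i_doc].data[0]  # assumed counterpart to .rows
    for i, (word, count) in enumerate(zip(unique_words, counts)):
        doc_unique_words[i_doc, i] = word
        doc_unique_word_counts[i_doc, i] = count
print(doc_unique_words)        # [[1 2], [0 0]] -- zero-padded word ids
print(doc_unique_word_counts)  # [[2 1], [3 0]] -- matching counts
```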

tweetopic/btm.py

Lines changed: 11 additions & 7 deletions

@@ -9,16 +9,19 @@
 import sklearn
 from numpy.typing import ArrayLike

-from tweetopic._btm import (compute_biterm_set, corpus_unique_biterms,
-                            fit_model, predict_docs)
+from tweetopic._btm import (
+    compute_biterm_set,
+    corpus_unique_biterms,
+    fit_model,
+    predict_docs,
+)
 from tweetopic._doc import init_doc_words
 from tweetopic.exceptions import NotFittedException
 from tweetopic.utils import set_numba_seed


 class BTM(sklearn.base.TransformerMixin, sklearn.base.BaseEstimator):
-    """Implementation of the Biterm Topic Model with Gibbs Sampling
-    solver.
+    """Implementation of the Biterm Topic Model with Gibbs Sampling solver.

     Parameters
     ----------
@@ -144,7 +147,9 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
             X.tolil(),
             max_unique_words=max_unique_words,
         )
-        biterms = corpus_unique_biterms(doc_unique_words, doc_unique_word_counts)
+        biterms = corpus_unique_biterms(
+            doc_unique_words, doc_unique_word_counts
+        )
         biterm_set = compute_biterm_set(biterms)
         self.topic_distribution, self.components_ = fit_model(
             n_iter=self.n_iterations,
@@ -159,8 +164,7 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
         # TODO: Something goes terribly wrong here, fix this

     def transform(self, X: Union[spr.spmatrix, ArrayLike]) -> np.ndarray:
-        """Predicts probabilities for each document belonging to each
-        topic.
+        """Predicts probabilities for each document belonging to each topic.

         Parameters
         ----------
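Since BTM subclasses sklearn's TransformerMixin and BaseEstimator, it follows the usual fit/transform protocol. A hedged usage sketch; the constructor arguments are inferred from names visible in the diff (`self.n_iterations` here, `n_components` in _btm.py), not from a verified signature:

```python
# Assumed usage of the BTM estimator shown above; parameter names are
# inferred from the diff, so check the class docstring before relying on them.
from sklearn.feature_extraction.text import CountVectorizer
from tweetopic.btm import BTM

texts = ["cats purr and sleep", "dogs bark at cats", "stocks fell sharply"]
X = CountVectorizer().fit_transform(texts)  # sparse document-term matrix

model = BTM(n_components=2, n_iterations=100)
model.fit(X)
doc_topic = model.transform(X)  # per-document topic probabilities
print(doc_topic.shape)  # expected (n_documents, n_components)
```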

tweetopic/func.py

Lines changed: 3 additions & 2 deletions

@@ -1,11 +1,12 @@
 """Utility functions for use in the library."""
+
 from functools import wraps
 from typing import Callable


 def spread(fn: Callable):
-    """Creates a new function from the given function so that it takes one
-    dict (PyTree) and spreads the arguments."""
+    """Creates a new function from the given function so that it takes one dict
+    (PyTree) and spreads the arguments."""

     @wraps(fn)
     def inner(kwargs):
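The docstring above fully specifies spread's contract: wrap a function so that it accepts a single dict and spreads it as keyword arguments. A standalone re-implementation of that contract for illustration (not tweetopic's exact code):

```python
# Minimal sketch of the spread contract described in the docstring above.
from functools import wraps
from typing import Callable

def spread(fn: Callable):
    @wraps(fn)
    def inner(kwargs):
        return fn(**kwargs)  # unpack the dict as keyword arguments
    return inner

@spread
def affine(x, w, b):
    return w * x + b

# The wrapped function now takes one dict (a PyTree leaf mapping).
print(affine({"x": 2.0, "w": 3.0, "b": 1.0}))  # 7.0
```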
