Skip to content

Commit 71c8381

Browse files
committed
improved marginal distribution and saliency calculation
1 parent a2c9509 commit 71c8381

File tree

2 files changed

+59
-54
lines changed

2 files changed

+59
-54
lines changed

src/tmplot/_helpers.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def _select_docs(docs, theta, topic_id: int):
249249

250250
def calc_topics_marg_probs(
251251
theta: Union[DataFrame, ndarray], topic_id: Optional[int] = None
252-
) -> Union[DataFrame, ndarray]:
252+
) -> ndarray:
253253
"""Calculate marginal topics probabilities.
254254
255255
Parameters
@@ -264,18 +264,18 @@ def calc_topics_marg_probs(
264264
Union[pandas.DataFrame, numpy.ndarray]
265265
Marginal topics probabilities.
266266
"""
267+
p_t = array(theta).sum(axis=1)
268+
p_t /= p_t.sum()
267269
if topic_id is not None:
268-
if isinstance(theta, ndarray):
269-
return theta[topic_id, :].sum()
270-
if isinstance(theta, DataFrame):
271-
return theta.iloc[topic_id, :].sum()
272-
273-
return theta.sum(axis=1)
270+
return p_t[topic_id]
271+
return p_t
274272

275273

276274
def calc_terms_marg_probs(
277-
phi: Union[ndarray, DataFrame], word_id: Optional[int] = None
278-
) -> Union[ndarray, Series]:
275+
phi: Union[ndarray, DataFrame],
276+
p_t: Union[ndarray, Series],
277+
word_id: Optional[int] = None,
278+
) -> ndarray:
279279
"""Calculate marginal terms probabilities.
280280
281281
Parameters
@@ -290,16 +290,13 @@ def calc_terms_marg_probs(
290290
Union[numpy.ndarray, pandas.Series]
291291
Marginal terms probabilities.
292292
"""
293+
p_w = (array(phi) * array(p_t)).sum(axis=1)
293294
if word_id is not None:
294-
if isinstance(phi, ndarray):
295-
return phi[word_id, :].sum()
296-
if isinstance(phi, DataFrame):
297-
return phi.iloc[word_id, :].sum()
298-
299-
return phi.sum(axis=1)
295+
return p_w[word_id]
296+
return p_w
300297

301298

302-
def get_salient_terms(terms_freqs: ndarray, phi: ndarray, theta: ndarray) -> ndarray:
299+
def get_salient_terms(phi: ndarray, theta: ndarray) -> ndarray:
303300
"""Get salient terms.
304301
305302
Calculated as:
@@ -308,8 +305,6 @@ def get_salient_terms(terms_freqs: ndarray, phi: ndarray, theta: ndarray) -> nda
308305
309306
Parameters
310307
----------
311-
terms_freqs : numpy.ndarray
312-
Words frequencies.
313308
phi : numpy.ndarray
314309
Words vs topics matrix.
315310
theta : numpy.ndarray
@@ -320,15 +315,15 @@ def get_salient_terms(terms_freqs: ndarray, phi: ndarray, theta: ndarray) -> nda
320315
numpy.ndarray
321316
Terms saliency values.
322317
"""
323-
p_t = array(calc_topics_marg_probs(theta))
324-
p_w = array(calc_terms_marg_probs(phi))
318+
p_t = calc_topics_marg_probs(theta)
319+
p_w = calc_terms_marg_probs(phi, p_t)
325320

326321
def _p_tw(phi, w, t):
327322
return phi[w, t] * p_t[t] / p_w[w]
328323

329324
saliency = array(
330325
(
331-
terms_freqs[w]
326+
p_w[w]
332327
* sum(
333328
(
334329
_p_tw(phi, w, t) * log(_p_tw(phi, w, t) / p_t[t])

tests/test_tmplot.py

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,20 @@
33
from altair import LayerChart
44
from tomotopy import LDAModel
55
from src import tmplot as tm
6-
from numpy import random, floating
6+
from numpy import random, floating, ndarray
77
from ipywidgets import VBox
8-
from pandas import Series
98

109

1110
class TestTmplot(unittest.TestCase):
12-
1311
def setUp(self):
14-
self.tomotopy_model = LDAModel.load('tests/models/tomotopyLDA.model')
15-
with open('tests/models/gensimLDA.model', 'rb') as file:
12+
self.tomotopy_model = LDAModel.load("tests/models/tomotopyLDA.model")
13+
with open("tests/models/gensimLDA.model", "rb") as file:
1614
self.gensim_model = pkl.load(file)
17-
with open('tests/models/gensimLDA.corpus', 'rb') as file:
15+
with open("tests/models/gensimLDA.corpus", "rb") as file:
1816
self.gensim_corpus = pkl.load(file)
19-
with open('tests/models/btm_big.pickle', 'rb') as file:
17+
with open("tests/models/btm_big.pickle", "rb") as file:
2018
self.btm_model_big = pkl.load(file)
21-
with open('tests/models/btm_small.pickle', 'rb') as file:
19+
with open("tests/models/btm_small.pickle", "rb") as file:
2220
self.btm_model_small = pkl.load(file)
2321

2422
self.phi = tm.get_phi(self.tomotopy_model)
@@ -69,66 +67,75 @@ def test_prepare_coords(self):
6967
topics_coords = tm.prepare_coords(self.btm_model_big)
7068
self.assertTupleEqual(topics_coords.shape, (self.btm_model_big.topics_num_, 5))
7169
topics_coords = tm.prepare_coords(self.btm_model_small)
72-
self.assertTupleEqual(topics_coords.shape, (self.btm_model_small.topics_num_, 5))
70+
self.assertTupleEqual(
71+
topics_coords.shape, (self.btm_model_small.topics_num_, 5)
72+
)
7373

7474
def test_get_topics_scatter(self):
7575
topics_dists = tm.get_topics_dist(self.phi)
76-
methods = ['tsne', 'sem', 'mds', 'lle', 'ltsa', 'isomap']
77-
topics_scatters = list(map(
78-
lambda method:
79-
tm.get_topics_scatter(topics_dists, self.theta, method=method),
80-
methods
81-
))
76+
methods = ["tsne", "sem", "mds", "lle", "ltsa", "isomap"]
77+
topics_scatters = list(
78+
map(
79+
lambda method: tm.get_topics_scatter(
80+
topics_dists, self.theta, method=method
81+
),
82+
methods,
83+
)
84+
)
8285
for scatter in topics_scatters:
8386
self.assertTupleEqual(scatter.shape, (self.tomotopy_model.k, 4))
8487

8588
def test_get_topics_dist(self):
8689
methods = ["klb", "jsd", "jef", "hel", "bhat", "tv", "jac"]
8790
topics_dists = list(
88-
map(
89-
lambda method: tm.get_topics_dist(self.phi, method=method),
90-
methods)
91+
map(lambda method: tm.get_topics_dist(self.phi, method=method), methods)
9192
)
9293
for dist in topics_dists:
9394
self.assertTupleEqual(
94-
dist.shape,
95-
(self.tomotopy_model.k, self.tomotopy_model.k))
95+
dist.shape, (self.tomotopy_model.k, self.tomotopy_model.k)
96+
)
9697

9798
def test_calc_topics_marg_probs(self):
9899
topic_marg_prob = tm.calc_topics_marg_probs(self.theta, 0)
99100
self.assertIsInstance(topic_marg_prob, floating)
100101
self.assertGreater(topic_marg_prob, 0)
101102
topics_marg_probs = tm.calc_topics_marg_probs(self.theta)
102-
self.assertIsInstance(topics_marg_probs, Series)
103+
self.assertIsInstance(topics_marg_probs, ndarray)
103104
self.assertEqual(topics_marg_probs.size, self.tomotopy_model.k)
105+
self.assertEqual(topics_marg_probs.sum(), 1)
104106

105107
def test_calc_terms_marg_probs(self):
106108
term_marg_prob = tm.calc_terms_marg_probs(self.phi, 0)
107109
self.assertIsInstance(term_marg_prob, floating)
108110
self.assertGreater(term_marg_prob, 0)
109111
terms_marg_probs = tm.calc_terms_marg_probs(self.phi)
110-
self.assertIsInstance(terms_marg_probs, Series)
112+
self.assertIsInstance(terms_marg_probs, ndarray)
111113
self.assertEqual(terms_marg_probs.size, self.phi.index.size)
112114

113115
def test_plot_scatter_topics(self):
114116
topics_coords = tm.prepare_coords(self.tomotopy_model)
115117
chart = tm.plot_scatter_topics(
116-
topics_coords, size_col='size', label_col='label')
118+
topics_coords, size_col="size", label_col="label"
119+
)
117120
self.assertIsInstance(chart, LayerChart)
118121

119122
def test_get_stable_topics(self):
120123
models = [
121-
self.tomotopy_model, self.tomotopy_model, self.tomotopy_model,
122-
self.tomotopy_model]
124+
self.tomotopy_model,
125+
self.tomotopy_model,
126+
self.tomotopy_model,
127+
self.tomotopy_model,
128+
]
123129
closest_topics, dists = tm.get_closest_topics(models)
124130
dists = random.normal(0, 0.10, dists.shape).__abs__()
125131
stable_topics, stable_dists = tm.get_stable_topics(
126-
closest_topics, dists, norm=False)
132+
closest_topics, dists, norm=False
133+
)
127134

128135
self.assertTupleEqual(
129-
closest_topics.shape, (self.tomotopy_model.k, len(models)))
130-
self.assertTupleEqual(
131-
dists.shape, (self.tomotopy_model.k, len(models)))
136+
closest_topics.shape, (self.tomotopy_model.k, len(models))
137+
)
138+
self.assertTupleEqual(dists.shape, (self.tomotopy_model.k, len(models)))
132139
self.assertLessEqual(stable_topics.shape[0], self.tomotopy_model.k)
133140
self.assertLessEqual(stable_dists.shape[0], self.tomotopy_model.k)
134141
self.assertGreaterEqual(stable_topics.shape[0], 0)
@@ -138,9 +145,8 @@ def test_get_stable_topics(self):
138145

139146
def test_report(self):
140147
report = tm.report(
141-
self.tomotopy_model,
142-
docs=tm.get_docs(self.tomotopy_model),
143-
width=250)
148+
self.tomotopy_model, docs=tm.get_docs(self.tomotopy_model), width=250
149+
)
144150
self.assertIsInstance(report, VBox)
145151

146152
def test_entropy(self):
@@ -149,6 +155,10 @@ def test_entropy(self):
149155
self.assertGreater(entropy, 0)
150156
self.assertGreater(entropy2, 0)
151157

158+
def test_get_salient_terms(self):
159+
saliency = tm.get_salient_terms(self.phi, self.theta)
160+
self.assertEqual(saliency.size, self.phi.shape[0])
161+
152162

153-
if __name__ == '__main__':
163+
if __name__ == "__main__":
154164
unittest.main()

0 commit comments

Comments
 (0)