Skip to content

Commit c3a5920

Browse files
committed
fixed errors in marginal distribution and saliency calculations
1 parent a2c9509 commit c3a5920

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

src/tmplot/_helpers.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -266,22 +266,24 @@ def calc_topics_marg_probs(
266266
"""
267267
if topic_id is not None:
268268
if isinstance(theta, ndarray):
269-
return theta[topic_id, :].sum()
269+
return theta[topic_id, :].mean()
270270
if isinstance(theta, DataFrame):
271-
return theta.iloc[topic_id, :].sum()
271+
return theta.iloc[topic_id, :].mean()
272272

273-
return theta.sum(axis=1)
273+
return theta.mean(axis=1)
274274

275275

276276
def calc_terms_marg_probs(
277-
phi: Union[ndarray, DataFrame], word_id: Optional[int] = None
277+
phi: Union[ndarray, DataFrame], pt: Union[ndarray, Series], word_id: Optional[int] = None
278278
) -> Union[ndarray, Series]:
279279
"""Calculate marginal terms probabilities.
280280
281281
Parameters
282282
----------
283283
phi : Union[numpy.ndarray, pandas.DataFrame]
284284
Words vs topics matrix.
285+
pt: Union[numpy.ndarray, pandas.Series]
286+
Topics marginal probabilities.
285287
word_id: Optional[int]
286288
Word index.
287289
@@ -292,24 +294,22 @@ def calc_terms_marg_probs(
292294
"""
293295
if word_id is not None:
294296
if isinstance(phi, ndarray):
295-
return phi[word_id, :].sum()
297+
return (phi[word_id, :] * pt).mean()
296298
if isinstance(phi, DataFrame):
297-
return phi.iloc[word_id, :].sum()
299+
return (phi.iloc[word_id, :] * pt).mean()
298300

299-
return phi.sum(axis=1)
301+
return (phi * pt).mean(axis=1)
300302

301303

302-
def get_salient_terms(terms_freqs: ndarray, phi: ndarray, theta: ndarray) -> ndarray:
304+
def get_salient_terms(phi: ndarray, theta: ndarray) -> ndarray:
303305
"""Get salient terms.
304306
305307
Calculated as:
306-
saliency(w) = frequency(w) * [sum_t p(t | w) * log(p(t | w)/p(t))],
308+
saliency(w) = p(w) * [sum_t p(t | w) * log(p(t | w)/p(t))],
307309
where ``w`` is a term index, ``t`` is a topic index.
308310
309311
Parameters
310312
----------
311-
terms_freqs : numpy.ndarray
312-
Words frequencies.
313313
phi : numpy.ndarray
314314
Words vs topics matrix.
315315
theta : numpy.ndarray
@@ -328,7 +328,7 @@ def _p_tw(phi, w, t):
328328

329329
saliency = array(
330330
(
331-
terms_freqs[w]
331+
p_w[w]
332332
* sum(
333333
(
334334
_p_tw(phi, w, t) * log(_p_tw(phi, w, t) / p_t[t])
@@ -338,7 +338,7 @@ def _p_tw(phi, w, t):
338338
for w in range(phi.shape[0])
339339
)
340340
)
341-
# saliency(term w) = frequency(w)
341+
# saliency(term w) = p(w)
342342
# * [sum_t p(t | w) * log(p(t | w)/p(t))] for topics t
343343
# p(t | w) = p(w | t) * p(t) / p(w)
344344
return saliency

tests/test_tmplot.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from altair import LayerChart
44
from tomotopy import LDAModel
55
from src import tmplot as tm
6-
from numpy import random, floating
6+
from numpy import random, floating, array
77
from ipywidgets import VBox
88
from pandas import Series
99

@@ -100,15 +100,17 @@ def test_calc_topics_marg_probs(self):
100100
self.assertGreater(topic_marg_prob, 0)
101101
topics_marg_probs = tm.calc_topics_marg_probs(self.theta)
102102
self.assertIsInstance(topics_marg_probs, Series)
103+
self.assertTrue(np.isclose(topics_marg_probs.sum(), 1))
103104
self.assertEqual(topics_marg_probs.size, self.tomotopy_model.k)
104105

105106
def test_calc_terms_marg_probs(self):
106-
term_marg_prob = tm.calc_terms_marg_probs(self.phi, 0)
107+
term_marg_prob = tm.calc_terms_marg_probs(self.phi, tm.calc_topics_marg_probs(self.theta), 0)
107108
self.assertIsInstance(term_marg_prob, floating)
108109
self.assertGreater(term_marg_prob, 0)
109-
terms_marg_probs = tm.calc_terms_marg_probs(self.phi)
110+
terms_marg_probs = tm.calc_terms_marg_probs(self.phi, tm.calc_topics_marg_probs(self.theta))
110111
self.assertIsInstance(terms_marg_probs, Series)
111112
self.assertEqual(terms_marg_probs.size, self.phi.index.size)
113+
self.assertTrue(np.isclose(terms_marg_probs.sum(), 1))
112114

113115
def test_plot_scatter_topics(self):
114116
topics_coords = tm.prepare_coords(self.tomotopy_model)
@@ -149,6 +151,10 @@ def test_entropy(self):
149151
self.assertGreater(entropy, 0)
150152
self.assertGreater(entropy2, 0)
151153

154+
def test_get_salient_terms(self):
155+
saliency = tm.get_salient_terms(self.phi, self.theta)
156+
self.assertEqual(saliency.size, self.phi.shape[0])
157+
152158

153159
if __name__ == '__main__':
154160
unittest.main()

0 commit comments

Comments
 (0)