@@ -266,22 +266,24 @@ def calc_topics_marg_probs(
266266 """
267267 if topic_id is not None :
268268 if isinstance (theta , ndarray ):
269- return theta [topic_id , :].sum ()
269+ return theta [topic_id , :].mean ()
270270 if isinstance (theta , DataFrame ):
271- return theta .iloc [topic_id , :].sum ()
271+ return theta .iloc [topic_id , :].mean ()
272272
273- return theta .sum (axis = 1 )
273+ return theta .mean (axis = 1 )
274274
275275
276276def calc_terms_marg_probs (
277- phi : Union [ndarray , DataFrame ], word_id : Optional [int ] = None
277+ phi : Union [ndarray , DataFrame ], pt : Union [ ndarray , Series ], word_id : Optional [int ] = None
278278) -> Union [ndarray , Series ]:
279279 """Calculate marginal terms probabilities.
280280
281281 Parameters
282282 ----------
283283 phi : Union[numpy.ndarray, pandas.DataFrame]
284284 Words vs topics matrix.
285+ pt: Union[numpy.ndarray, pandas.Series]
286+ Topics marginal probabilities.
285287 word_id: Optional[int]
286288 Word index.
287289
@@ -292,24 +294,22 @@ def calc_terms_marg_probs(
292294 """
293295 if word_id is not None :
294296 if isinstance (phi , ndarray ):
295- return phi [word_id , :]. sum ()
297+ return ( phi [word_id , :] * pt ). mean ()
296298 if isinstance (phi , DataFrame ):
297- return phi .iloc [word_id , :]. sum ()
299+ return ( phi .iloc [word_id , :] * pt ). mean ()
298300
299- return phi . sum (axis = 1 )
301+ return ( phi * pt ). mean (axis = 1 )
300302
301303
302- def get_salient_terms (terms_freqs : ndarray , phi : ndarray , theta : ndarray ) -> ndarray :
304+ def get_salient_terms (phi : ndarray , theta : ndarray ) -> ndarray :
303305 """Get salient terms.
304306
305307 Calculated as:
306- saliency(w) = frequency (w) * [sum_t p(t | w) * log(p(t | w)/p(t))],
308+ saliency(w) = p (w) * [sum_t p(t | w) * log(p(t | w)/p(t))],
307309 where ``w`` is a term index, ``t`` is a topic index.
308310
309311 Parameters
310312 ----------
311- terms_freqs : numpy.ndarray
312- Words frequencies.
313313 phi : numpy.ndarray
314314 Words vs topics matrix.
315315 theta : numpy.ndarray
@@ -328,7 +328,7 @@ def _p_tw(phi, w, t):
328328
329329 saliency = array (
330330 (
331- terms_freqs [w ]
331+ p_w [w ]
332332 * sum (
333333 (
334334 _p_tw (phi , w , t ) * log (_p_tw (phi , w , t ) / p_t [t ])
@@ -338,7 +338,7 @@ def _p_tw(phi, w, t):
338338 for w in range (phi .shape [0 ])
339339 )
340340 )
341- # saliency(term w) = frequency (w)
341+ # saliency(term w) = p (w)
342342 # * [sum_t p(t | w) * log(p(t | w)/p(t))] for topics t
343343 # p(t | w) = p(w | t) * p(t) / p(w)
344344 return saliency
0 commit comments