1- """Module for utility functions for fitting BTMs"""
1+ """Module for utility functions for fitting BTMs. """
22
33import random
44from typing import Dict , Tuple , TypeVar
1212
1313@njit
1414def doc_unique_biterms (
15- doc_unique_words : np .ndarray , doc_unique_word_counts : np .ndarray
15+ doc_unique_words : np .ndarray ,
16+ doc_unique_word_counts : np .ndarray ,
1617) -> Dict [Tuple [int , int ], int ]:
1718 (n_max_unique_words ,) = doc_unique_words .shape
1819 biterm_counts = dict ()
@@ -43,7 +44,7 @@ def doc_unique_biterms(
4344
4445@njit
4546def nb_add_counter (dest : Dict [T , int ], source : Dict [T , int ]):
46- """Adds one counter dict to another in place with Numba"""
47+ """Adds one counter dict to another in place with Numba. """
4748 for key in source :
4849 if key in dest :
4950 dest [key ] += source [key ]
@@ -53,25 +54,28 @@ def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
5354
@njit
def corpus_unique_biterms(
    doc_unique_words: np.ndarray,
    doc_unique_word_counts: np.ndarray,
) -> Dict[Tuple[int, int], int]:
    """Accumulate the per-document biterm counts of a whole corpus.

    Each row of the two input arrays is handed to ``doc_unique_biterms``
    and the resulting counters are summed into a single dictionary.

    Parameters
    ----------
    doc_unique_words : np.ndarray
        Two-dimensional array; row ``i`` is passed to
        ``doc_unique_biterms`` as the words of document ``i``.
    doc_unique_word_counts : np.ndarray
        Array indexed by the same row index as ``doc_unique_words``;
        row ``i`` is passed along as the counts of document ``i``.

    Returns
    -------
    Dict[Tuple[int, int], int]
        The counter of the first document with the counters of all
        remaining documents added to it.

    Notes
    -----
    Row 0 is read unconditionally, so the corpus must contain at least
    one document.
    """
    # The unpacking into exactly two names requires a 2-D array.
    n_docs, _ = doc_unique_words.shape
    # Seed the running total with the first document's counter; the
    # remaining documents are merged into it in place.
    totals = doc_unique_biterms(
        doc_unique_words[0],
        doc_unique_word_counts[0],
    )
    for row in range(1, n_docs):
        row_counts = doc_unique_biterms(
            doc_unique_words[row],
            doc_unique_word_counts[row],
        )
        nb_add_counter(totals, row_counts)
    return totals
7074
7175
@njit
def compute_biterm_set(
    biterm_counts: Dict[Tuple[int, int], int],
) -> np.ndarray:
    """Collect the keys of a biterm counter into a NumPy array.

    Parameters
    ----------
    biterm_counts : Dict[Tuple[int, int], int]
        Counter dictionary whose keys are the biterms; the counts
        themselves are ignored.

    Returns
    -------
    np.ndarray
        Array built from the list of keys of ``biterm_counts``.
    """
    unique_biterms = list(biterm_counts.keys())
    return np.array(unique_biterms)
7781
@@ -116,7 +120,12 @@ def add_biterm(
116120 topic_biterm_count : np .ndarray ,
117121) -> None :
118122 add_remove_biterm (
119- True , i_biterm , i_topic , biterms , topic_word_count , topic_biterm_count
123+ True ,
124+ i_biterm ,
125+ i_topic ,
126+ biterms ,
127+ topic_word_count ,
128+ topic_biterm_count ,
120129 )
121130
122131
@@ -129,7 +138,12 @@ def remove_biterm(
129138 topic_biterm_count : np .ndarray ,
130139) -> None :
131140 add_remove_biterm (
132- False , i_biterm , i_topic , biterms , topic_word_count , topic_biterm_count
141+ False ,
142+ i_biterm ,
143+ i_topic ,
144+ biterms ,
145+ topic_word_count ,
146+ topic_biterm_count ,
133147 )
134148
135149
@@ -147,7 +161,11 @@ def init_components(
147161 i_topic = random .randint (0 , n_components - 1 )
148162 biterm_topic_assignments [i_biterm ] = i_topic
149163 add_biterm (
150- i_biterm , i_topic , biterms , topic_word_count , topic_biterm_count
164+ i_biterm ,
165+ i_topic ,
166+ biterms ,
167+ topic_word_count ,
168+ topic_biterm_count ,
151169 )
152170 return biterm_topic_assignments , topic_word_count , topic_biterm_count
153171
@@ -448,7 +466,10 @@ def predict_docs(
448466 )
449467 biterms = doc_unique_biterms (words , word_counts )
450468 prob_topic_given_document (
451- pred , biterms , topic_distribution , topic_word_distribution
469+ pred ,
470+ biterms ,
471+ topic_distribution ,
472+ topic_word_distribution ,
452473 )
453474 predictions [i_doc , :] = pred
454475 return predictions
0 commit comments