33from altair import LayerChart
44from tomotopy import LDAModel
55from src import tmplot as tm
6- from numpy import random , floating
6+ from numpy import random , floating , ndarray
77from ipywidgets import VBox
8- from pandas import Series
98
109
1110class TestTmplot (unittest .TestCase ):
12-
1311 def setUp (self ):
14- self .tomotopy_model = LDAModel .load (' tests/models/tomotopyLDA.model' )
15- with open (' tests/models/gensimLDA.model' , 'rb' ) as file :
12+ self .tomotopy_model = LDAModel .load (" tests/models/tomotopyLDA.model" )
13+ with open (" tests/models/gensimLDA.model" , "rb" ) as file :
1614 self .gensim_model = pkl .load (file )
17- with open (' tests/models/gensimLDA.corpus' , 'rb' ) as file :
15+ with open (" tests/models/gensimLDA.corpus" , "rb" ) as file :
1816 self .gensim_corpus = pkl .load (file )
19- with open (' tests/models/btm_big.pickle' , 'rb' ) as file :
17+ with open (" tests/models/btm_big.pickle" , "rb" ) as file :
2018 self .btm_model_big = pkl .load (file )
21- with open (' tests/models/btm_small.pickle' , 'rb' ) as file :
19+ with open (" tests/models/btm_small.pickle" , "rb" ) as file :
2220 self .btm_model_small = pkl .load (file )
2321
2422 self .phi = tm .get_phi (self .tomotopy_model )
@@ -69,66 +67,75 @@ def test_prepare_coords(self):
6967 topics_coords = tm .prepare_coords (self .btm_model_big )
7068 self .assertTupleEqual (topics_coords .shape , (self .btm_model_big .topics_num_ , 5 ))
7169 topics_coords = tm .prepare_coords (self .btm_model_small )
72- self .assertTupleEqual (topics_coords .shape , (self .btm_model_small .topics_num_ , 5 ))
70+ self .assertTupleEqual (
71+ topics_coords .shape , (self .btm_model_small .topics_num_ , 5 )
72+ )
7373
def test_get_topics_scatter(self):
    """Check the shape of the scatter table for every projection method.

    Each value accepted here for ``method`` must produce a result whose
    shape is ``(tomotopy_model.k, 4)``.
    """
    distances = tm.get_topics_dist(self.phi)
    # Compute every projection up front, then verify them in a second pass.
    scatters = [
        tm.get_topics_scatter(distances, self.theta, method=name)
        for name in ("tsne", "sem", "mds", "lle", "ltsa", "isomap")
    ]
    for projection in scatters:
        self.assertTupleEqual(projection.shape, (self.tomotopy_model.k, 4))
8487
def test_get_topics_dist(self):
    """Check that every ``method`` option yields a ``k`` x ``k`` result."""
    method_names = ("klb", "jsd", "jef", "hel", "bhat", "tv", "jac")
    # All matrices are built first; the shape checks run afterwards.
    matrices = [
        tm.get_topics_dist(self.phi, method=name) for name in method_names
    ]
    for matrix in matrices:
        expected_shape = (self.tomotopy_model.k, self.tomotopy_model.k)
        self.assertTupleEqual(matrix.shape, expected_shape)
9697
def test_calc_topics_marg_probs(self):
    """Check marginal topic probabilities for one topic and for all topics.

    Called with a topic index, the helper must return a positive NumPy
    floating scalar. Called without one, it must return an ``ndarray``
    with ``tomotopy_model.k`` entries whose total is one.
    """
    single = tm.calc_topics_marg_probs(self.theta, 0)
    self.assertIsInstance(single, floating)
    self.assertGreater(single, 0)

    every = tm.calc_topics_marg_probs(self.theta)
    self.assertIsInstance(every, ndarray)
    self.assertEqual(every.size, self.tomotopy_model.k)
    # The total is a floating-point sum, so exact equality with 1 can fail
    # from rounding alone; compare within a tolerance instead.
    self.assertAlmostEqual(every.sum(), 1, places=6)
104106
def test_calc_terms_marg_probs(self):
    """Check marginal term probabilities for one term and for all terms.

    Called with a term index, the helper must return a positive NumPy
    floating scalar. Called without one, it must return an ``ndarray``
    with one entry per row label of ``phi``.
    """
    single = tm.calc_terms_marg_probs(self.phi, 0)
    self.assertIsInstance(single, floating)
    self.assertGreater(single, 0)

    every = tm.calc_terms_marg_probs(self.phi)
    self.assertIsInstance(every, ndarray)
    self.assertEqual(every.size, self.phi.index.size)
112114
def test_plot_scatter_topics(self):
    """Check that plotting prepared topic coordinates gives a ``LayerChart``."""
    coords = tm.prepare_coords(self.tomotopy_model)
    column_options = {"size_col": "size", "label_col": "label"}
    rendered = tm.plot_scatter_topics(coords, **column_options)
    self.assertIsInstance(rendered, LayerChart)
118121
119122 def test_get_stable_topics (self ):
120123 models = [
121- self .tomotopy_model , self .tomotopy_model , self .tomotopy_model ,
122- self .tomotopy_model ]
124+ self .tomotopy_model ,
125+ self .tomotopy_model ,
126+ self .tomotopy_model ,
127+ self .tomotopy_model ,
128+ ]
123129 closest_topics , dists = tm .get_closest_topics (models )
124130 dists = random .normal (0 , 0.10 , dists .shape ).__abs__ ()
125131 stable_topics , stable_dists = tm .get_stable_topics (
126- closest_topics , dists , norm = False )
132+ closest_topics , dists , norm = False
133+ )
127134
128135 self .assertTupleEqual (
129- closest_topics .shape , (self .tomotopy_model .k , len (models )))
130- self . assertTupleEqual (
131- dists .shape , (self .tomotopy_model .k , len (models )))
136+ closest_topics .shape , (self .tomotopy_model .k , len (models ))
137+ )
138+ self . assertTupleEqual ( dists .shape , (self .tomotopy_model .k , len (models )))
132139 self .assertLessEqual (stable_topics .shape [0 ], self .tomotopy_model .k )
133140 self .assertLessEqual (stable_dists .shape [0 ], self .tomotopy_model .k )
134141 self .assertGreaterEqual (stable_topics .shape [0 ], 0 )
@@ -138,9 +145,8 @@ def test_get_stable_topics(self):
138145
def test_report(self):
    """Check that a report built from the tomotopy model is a ``VBox``."""
    model = self.tomotopy_model
    documents = tm.get_docs(model)
    widget = tm.report(model, docs=documents, width=250)
    self.assertIsInstance(widget, VBox)
145151
146152 def test_entropy (self ):
@@ -149,6 +155,10 @@ def test_entropy(self):
149155 self .assertGreater (entropy , 0 )
150156 self .assertGreater (entropy2 , 0 )
151157
def test_get_salient_terms(self):
    """Check that the saliency result has one entry per row of ``phi``."""
    salient = tm.get_salient_terms(self.phi, self.theta)
    actual_count = salient.size
    expected_count = self.phi.shape[0]
    self.assertEqual(actual_count, expected_count)
161+
152162
# Run the whole suite when this file is executed directly as a script.
# The guard string must be exactly "__main__"; with a leading space it can
# never equal __name__, and the runner would silently never start.
if __name__ == "__main__":
    unittest.main()
0 commit comments