@@ -50,25 +50,25 @@ def test_kitty(data_dir):
 
     kt.assigned_classes = {0: "nature", 3: "shop/offices", 4: "sport"}
 
-    topic = kt.predict(["test sentence"])
+    tn = kt.transform(['beautiful sea in the ocean'], labels=['nature', 'shop/offices'])
 
-    assert topic[0] in kt.assigned_classes.values()
+    kt.predict(['beautiful sea in the ocean'], 5)
 
-    kt.pretty_print_word_classes()
+    kt.predict_topic(['beautiful sea in the ocean'], 5)
 
+    assert len(tn) == 1
 
-def test_custom_embeddings(data_dir):
 
-    with open(data_dir + "/custom_embeddings/sample_text.txt") as filino:
-        training = filino.read().splitlines()
+def test_preprocessing():
 
-    embeddings = np.load(data_dir + "/custom_embeddings/sample_embeddings.npy")
+    testing_data = [" this is some documents \t ", " test "]
 
-    turkish_stopwords = nltk.corpus.stopwords.words('turkish')
+    sp = WhiteSpacePreprocessing(testing_data, stopwords_language="english")
+    preprocessed_documents, unpreprocessed_corpus, vocab = sp.preprocess()
 
-    kt = Kitty()
-    kt.train(training, custom_embeddings=embeddings, topics=5, epochs=1,
-             stopwords_list=turkish_stopwords, hidden_sizes=(200, 200))
+    assert len(preprocessed_documents) == 2
+    assert len(unpreprocessed_corpus) == 2
+    assert len(vocab) >= 2
 
 
 def test_validation_set(data_dir):
@@ -81,10 +81,8 @@ def test_validation_set(data_dir):
     training_dataset = tp.fit(data[:100], data[:100])
     validation_dataset = tp.transform(data[100:105], data[100:105])
 
-    ctm = CombinedTM(reduce_on_plateau=True, solver='sgd', batch_size=2, bow_size=len(tp.vocab), contextual_size=512, num_epochs=1, n_components=5)
-    ctm.fit(training_dataset, validation_dataset=validation_dataset, patience=5, save_dir=data_dir + 'test_checkpoint')
-
-    assert os.path.exists(data_dir + "test_checkpoint")
+    ctm = ZeroShotTM(bow_size=len(tp.vocab), contextual_size=512, num_epochs=1, n_components=5, batch_size=2)
+    ctm.fit(training_dataset, validation_dataset)
 
 
 def test_training_all_classes_ctm(data_dir):
@@ -96,45 +94,58 @@ def test_training_all_classes_ctm(data_dir):
 
     training_dataset = tp.fit(data, data)
     ctm = ZeroShotTM(bow_size=len(tp.vocab), contextual_size=512, num_epochs=1, n_components=5, batch_size=2)
-    ctm.fit(training_dataset)  # run the model
+    ctm.fit(training_dataset)
 
-    testing_dataset = tp.transform(data)
-    predictions = ctm.get_doc_topic_distribution(testing_dataset, n_samples=2)
+    assert len(ctm.get_topics()) == 5
 
-    assert len(predictions) == len(testing_dataset)
+    ctm.get_topic_lists(25)
 
-    topics = ctm.get_topic_lists(2)
-    assert len(topics) == 5
+    thetas = ctm.get_doc_topic_distribution(training_dataset, n_samples=5)
 
-    training_dataset = tp.fit(data, data)
-    ctm = CombinedTM(bow_size=len(tp.vocab), contextual_size=512, num_epochs=1, n_components=5, batch_size=2)
-    ctm.fit(training_dataset)  # run the model
+    assert len(thetas) == len(data)
+
+    predicted_topics = ctm.get_doc_topic_distribution(training_dataset, n_samples=5)
+
+    assert len(predicted_topics) == len(data)
+
+    ctm = CTM(bow_size=len(tp.vocab), contextual_size=512, num_epochs=1, n_components=5, batch_size=2)
+    ctm.fit(training_dataset)
+
+    assert len(ctm.get_topics()) == 5
+
+    ctm.get_topic_lists(25)
 
-    topics = ctm.get_topic_lists(2)
-    assert len(topics) == 5
+    thetas = ctm.get_doc_topic_distribution(training_dataset, n_samples=5)
 
-    ctm = CombinedTM(bow_size=len(tp.vocab), contextual_size=512, num_epochs=1, n_components=5, loss_weights={"beta": 10}, batch_size=2)
-    ctm.fit(training_dataset)  # run the model
-    assert ctm.weights == {"beta": 10}
+    assert len(thetas) == len(data)
 
-    topics = ctm.get_topic_lists(2)
-    assert len(topics) == 5
+    predicted_topics = ctm.get_doc_topic_distribution(training_dataset, n_samples=5)
 
-    testing_dataset = tp.transform(data, data)
-    predictions = ctm.get_doc_topic_distribution(testing_dataset, n_samples=2)
+    assert len(predicted_topics) == len(data)
 
-    assert len(predictions) == len(testing_dataset)
 
+def test_training_ctm_combined_labels(data_dir):
+
+    with open(data_dir + '/gnews/GoogleNews.txt') as filino:
+        data = filino.readlines()
+    with open(data_dir + '/gnews/GoogleNews_LABEL.txt') as filino:
+        labels = filino.readlines()
+
+    tp = TopicModelDataPreparation("paraphrase-distilroberta-base-v2")
+
+    training_dataset = tp.fit(data[:100], data[:100], labels=labels[:100])
+
+    ctm = CombinedTM(bow_size=len(tp.vocab), contextual_size=768, num_epochs=1, n_components=5, batch_size=2,
+                     label_size=len(set(labels[:100])))
+    ctm.fit(training_dataset)
+
+    assert len(ctm.get_topics()) == 5
 
-def test_preprocessing(data_dir):
-    docs = [line.strip() for line in open(data_dir + "gnews/GoogleNews.txt", 'r').readlines()]
-    sp = WhiteSpacePreprocessing(docs, "english")
-    prep_corpus, unprepr_corpus, vocab, retained_indices = sp.preprocess()
+    ctm.get_topic_lists(25)
 
-    assert len(prep_corpus) == len(unprepr_corpus)  # prep docs must have the same size as the unprep docs
-    assert len(prep_corpus) <= len(docs)  # preprocessed docs must be less than or equal the original docs
+    thetas = ctm.get_doc_topic_distribution(training_dataset, n_samples=5)
 
-    assert len(vocab) <= sp.vocabulary_size  # check vocabulary size
+    assert len(thetas) == len(data[:100])
 
 
 