-
Notifications
You must be signed in to change notification settings - Fork 885
Expand file tree
/
Copy pathtest_bertopic.py
More file actions
164 lines (127 loc) · 5.79 KB
/
test_bertopic.py
File metadata and controls
164 lines (127 loc) · 5.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import copy
import importlib.util

import pytest

from bertopic import BERTopic
def cuml_available() -> bool:
    """Return ``True`` when the RAPIDS ``cuml`` package is installed.

    Uses :func:`importlib.util.find_spec` so the availability probe does not
    actually import the (GPU-backed, import-heavy) package at test collection
    time — the original ``try: import cuml`` form paid that cost just to learn
    whether the module exists.
    """
    return importlib.util.find_spec("cuml") is not None
@pytest.mark.parametrize(
    'model',
    [
        ("base_topic_model"),
        ('kmeans_pca_topic_model'),
        ('custom_topic_model'),
        ('merged_topic_model'),
        ('reduced_topic_model'),
        ('online_topic_model'),
        ('supervised_topic_model'),
        ('representation_topic_model'),
        ('zeroshot_topic_model'),
        # The cuML variant only runs when the RAPIDS stack is importable.
        pytest.param(
            "cuml_base_topic_model", marks=pytest.mark.skipif(not cuml_available(), reason="cuML not available")
        ),
    ])
def test_full_model(model, documents, request):
    """Run the entire BERTopic pipeline end-to-end as a sanity check.

    Each parametrized fixture name is resolved via ``request.getfixturevalue``
    and deep-copied so the in-place mutations below (topic reduction, merging,
    outlier reduction, ...) cannot leak into other tests sharing the fixture.

    NOTE: This does not cover all cases but merely combines it all together.
    The steps are strictly order-dependent: later assertions rely on the model
    state left behind by earlier mutating calls.
    """
    topic_model = copy.deepcopy(request.getfixturevalue(model))
    # For the base model, additionally round-trip through save/load so the
    # rest of the pipeline also exercises a deserialized model.
    if model == "base_topic_model":
        topic_model.save("model_dir", serialization="pytorch", save_ctfidf=True, save_embedding_model="sentence-transformers/all-MiniLM-L6-v2")
        topic_model = BERTopic.load("model_dir")

    # The cuML fixture must actually be backed by cuML sub-models.
    if model == "cuml_base_topic_model":
        assert "cuml" in str(type(topic_model.umap_model)).lower()
        assert "cuml" in str(type(topic_model.hdbscan_model)).lower()

    topics = topic_model.topics_

    # Every fitted topic should expose (at least) its top-10 word tuples,
    # whether iterated from the raw assignments or from the frequency table.
    for topic in set(topics):
        words = topic_model.get_topic(topic)[:10]
        assert len(words) == 10

    for topic in topic_model.get_topic_freq().Topic:
        words = topic_model.get_topic(topic)[:10]
        assert len(words) == 10

    assert len(topic_model.get_topic_freq()) > 2
    assert len(topic_model.get_topics()) == len(topic_model.get_topic_freq())

    # Test extraction of document info
    document_info = topic_model.get_document_info(documents)
    assert len(document_info) == len(documents)

    # Test transform
    doc = "This is a new document to predict."
    topics_test, probs_test = topic_model.transform([doc, doc])

    assert len(topics_test) == 2

    # Test topics over time: synthetic timestamps cycling 0..9 so every bin
    # is populated; total frequency must account for every document.
    timestamps = [i % 10 for i in range(len(documents))]
    topics_over_time = topic_model.topics_over_time(documents, timestamps)

    assert topics_over_time.Frequency.sum() == len(documents)
    assert len(topics_over_time.Topic.unique()) == len(set(topics))

    # Test hierarchical topics: parent IDs are newly minted, so they must all
    # be larger than any existing topic ID.
    hier_topics = topic_model.hierarchical_topics(documents)

    assert len(hier_topics) > 0
    assert hier_topics.Parent_ID.astype(int).min() > max(topics)

    # Test creation of topic tree
    tree = topic_model.get_topic_tree(hier_topics, tight_layout=False)
    assert isinstance(tree, str)
    assert len(tree) > 10

    # Test find topic
    similar_topics, similarity = topic_model.find_topics("query", top_n=2)
    assert len(similar_topics) == 2
    assert len(similarity) == 2
    assert max(similarity) <= 1

    # Test topic reduction: target one fewer topic than currently fitted
    # (floored at 2). Reassignment keeps one topic per document.
    nr_topics = len(set(topics))
    nr_topics = 2 if nr_topics < 2 else nr_topics - 1
    topic_model.reduce_topics(documents, nr_topics=nr_topics)

    assert len(topic_model.get_topic_freq()) == nr_topics
    assert len(topic_model.topics_) == len(topics)

    # Test update topics: switching to bigrams must change topic 1's words;
    # restoring the original vectorizer re-runs representation, so the words
    # only differ from the originals when a representation model is attached.
    topic = topic_model.get_topic(1)[:10]
    vectorizer_model = topic_model.vectorizer_model
    topic_model.update_topics(documents, n_gram_range=(2, 2))

    updated_topic = topic_model.get_topic(1)[:10]

    topic_model.update_topics(documents, vectorizer_model=vectorizer_model)
    original_topic = topic_model.get_topic(1)[:10]

    assert topic != updated_topic
    if topic_model.representation_model is not None:
        assert topic != original_topic

    # Test updating topic labels
    topic_labels = topic_model.generate_topic_labels(nr_words=3, topic_prefix=False, word_length=10, separator=", ")
    assert len(topic_labels) == len(set(topic_model.topics_))

    # Test setting topic labels
    topic_model.set_topic_labels(topic_labels)
    assert topic_model.custom_labels_ == topic_labels

    # Test merging topics: folding topic 1 into topic 0 must grow topic 0.
    freq = topic_model.get_topic_freq(0)
    topics_to_merge = [0, 1]
    topic_model.merge_topics(documents, topics_to_merge)
    assert freq < topic_model.get_topic_freq(0)

    # Test reduction of outliers (only meaningful when -1 assignments exist).
    if -1 in topics:
        new_topics = topic_model.reduce_outliers(documents, topics, threshold=0.0)
        nr_outliers_topic_model = sum([1 for topic in topic_model.topics_ if topic == -1])
        nr_outliers_new_topics = sum([1 for topic in new_topics if topic == -1])

        if topic_model._outliers == 1:
            assert nr_outliers_topic_model > nr_outliers_new_topics

    # Combine models: merging with a freshly loaded copy must yield strictly
    # more topics than either input on its own.
    topic_model1 = BERTopic.load("model_dir")
    merged_model = BERTopic.merge_models([topic_model, topic_model1])
    assert len(merged_model.get_topic_info()) > len(topic_model1.get_topic_info())
    assert len(merged_model.get_topic_info()) > len(topic_model.get_topic_info())
def test_transform_flexibility(documents, document_embeddings, request):
    """Verify ``transform`` accepts both single and batched inputs.

    A single document with a single embedding, and a list of documents with a
    matching embedding matrix, must both be accepted; mismatched combinations
    (one document with two embeddings, or vice versa) and malformed embeddings
    must raise ``ValueError``.

    Fix: removed a leftover ``print(document_embeddings[0].shape)`` debug
    statement that polluted test output and asserted nothing.
    """
    topic_model = copy.deepcopy(request.getfixturevalue('base_topic_model'))

    # Valid shapes must not raise.
    try:
        topic_model.transform(documents[0], document_embeddings[0])
    except ValueError:
        pytest.fail('Error thrown for transform with single document and embeddings')
    try:
        topic_model.transform(documents[0:2], document_embeddings[0:2])
    except ValueError:
        pytest.fail('Error thrown for transform with multiple documents and embeddings')

    # Mismatched document/embedding counts, or a non-embedding sequence,
    # must be rejected.
    with pytest.raises(ValueError):
        topic_model.transform(documents[0], document_embeddings[0:2])
    with pytest.raises(ValueError):
        topic_model.transform(documents[0:2], document_embeddings[0])
    with pytest.raises(ValueError):
        topic_model.transform(documents[0], [1, 2, 3])