Commit cd98fc8

Author: Maarten Grootendorst
v0.9.4 (#335)
* Expose diversity parameter
* Improve stability of topic reduction
* Added property to c-TF-IDF that all IDF values are positive (#351)
* Improve stability of `.visualize_barchart()` and `.visualize_hierarchy()`
* Major documentation overhaul (including #330)
* Drop python 3.6 (#333)
* Relax plotly dependency (#88)
* Additional logging for `.transform` (#356)
1 parent 15ea0cd commit cd98fc8


52 files changed: +637 −439 lines

README.md

Lines changed: 9 additions & 9 deletions
@@ -1,4 +1,4 @@
-[![PyPI - Python](https://img.shields.io/badge/python-v3.6+-blue.svg)](https://pypi.org/project/bertopic/)
+[![PyPI - Python](https://img.shields.io/badge/python-v3.7+-blue.svg)](https://pypi.org/project/bertopic/)
 [![Build](https://img.shields.io/github/workflow/status/MaartenGr/BERTopic/Code%20Checks/master)](https://pypi.org/project/bertopic/)
 [![docs](https://img.shields.io/badge/docs-Passing-green.svg)](https://maartengr.github.io/BERTopic/)
 [![PyPI - PyPi](https://img.shields.io/pypi/v/BERTopic)](https://pypi.org/project/bertopic/)
@@ -13,9 +13,9 @@ BERTopic is a topic modeling technique that leverages 🤗 transformers and c-TF
 allowing for easily interpretable topics whilst keeping important words in the topic descriptions.

 BERTopic supports
-[**guided**](https://maartengr.github.io/BERTopic/tutorial/guided/guided.html),
-(semi-) [**supervised**](https://maartengr.github.io/BERTopic/tutorial/supervised/supervised.html),
-and [**dynamic**](https://maartengr.github.io/BERTopic/tutorial/topicsovertime/topicsovertime.html) topic modeling. It even supports visualizations similar to LDAvis!
+[**guided**](https://maartengr.github.io/BERTopic/getting_started/guided/guided.html),
+(semi-) [**supervised**](https://maartengr.github.io/BERTopic/getting_started/supervised/supervised.html),
+and [**dynamic**](https://maartengr.github.io/BERTopic/getting_started/topicsovertime/topicsovertime.html) topic modeling. It even supports visualizations similar to LDAvis!

 Corresponding medium posts can be found [here](https://towardsdatascience.com/topic-modeling-with-bert-779f7db187e6?source=friends_link&sk=0b5a470c006d1842ad4c8a3057063a99)
 and [here](https://towardsdatascience.com/interactive-topic-modeling-with-bertopic-1ea55e7d73d8?sk=03c2168e9e74b6bda2a1f3ed953427e4).
@@ -54,7 +54,7 @@ with one of the examples below:


 ## Quick Start
-We start by extracting topics from the well-known 20 newsgroups dataset which is comprised of english documents:
+We start by extracting topics from the well-known 20 newsgroups dataset containing English documents:

 ```python
 from bertopic import BERTopic
@@ -66,7 +66,7 @@ topic_model = BERTopic()
 topics, probs = topic_model.fit_transform(docs)
 ```

-After generating topics, we can access the frequent topics that were generated:
+After generating topics and their probabilities, we can access the frequent topics that were generated:

 ```python
 >>> topic_model.get_topic_info()
@@ -123,7 +123,7 @@ topic_model.visualize_barchart()


 Find all possible visualizations with interactive examples in the documentation
-[here](https://maartengr.github.io/BERTopic/tutorial/visualization/visualization.html).
+[here](https://maartengr.github.io/BERTopic/getting_started/visualization/visualization.html).

 ## Embedding Models
 BERTopic supports many embedding models that can be used to embed the documents and words:
@@ -151,7 +151,7 @@ roberta = TransformerDocumentEmbeddings('roberta-base')
 topic_model = BERTopic(embedding_model=roberta)
 ```

-Click [here](https://maartengr.github.io/BERTopic/tutorial/embeddings/embeddings.html)
+Click [here](https://maartengr.github.io/BERTopic/getting_started/embeddings/embeddings.html)
 for a full overview of all supported embedding models.

 ## Dynamic Topic Modeling
@@ -238,7 +238,7 @@ To cite BERTopic in your work, please use the following bibtex reference:
   title = {BERTopic: Leveraging BERT and c-TF-IDF to create easily interpretable topics.},
   year = 2020,
   publisher = {Zenodo},
-  version = {v0.9.2},
+  version = {v0.9.4},
   doi = {10.5281/zenodo.4381785},
   url = {https://doi.org/10.5281/zenodo.4381785}
 }

bertopic/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 from bertopic._bertopic import BERTopic

-__version__ = "0.9.3"
+__version__ = "0.9.4"

 __all__ = [
     "BERTopic",

bertopic/_bertopic.py

Lines changed: 36 additions & 30 deletions
@@ -78,6 +78,7 @@ def __init__(self,
                  nr_topics: Union[int, str] = None,
                  low_memory: bool = False,
                  calculate_probabilities: bool = False,
+                 diversity: float = None,
                  seed_topic_list: List[List[str]] = None,
                  embedding_model=None,
                  umap_model: UMAP = None,
@@ -105,8 +106,7 @@ def __init__(self,
                        number of topics to the value specified. This reduction can take
                        a while as each reduction in topics (-1) activates a c-TF-IDF
                        calculation. If this is set to None, no reduction is applied. Use
-                       "auto" to automatically reduce topics that have a similarity of at
-                       least 0.9, do not maps all others.
+                       "auto" to automatically reduce topics using HDBSCAN.
             low_memory: Sets UMAP low memory to True to make sure less memory is used.
             calculate_probabilities: Whether to calculate the probabilities of all topics
                                      per document instead of the probability of the assigned
@@ -116,6 +116,9 @@ def __init__(self,
                                      you do not mind more computation time.
                                      NOTE: If false you cannot use the corresponding
                                      visualization method `visualize_probabilities`.
+            diversity: Whether to use MMR to diversify the resulting topic representations.
+                       If set to None, MMR will not be used. Accepted values lie between
+                       0 and 1 with 0 being not at all diverse and 1 being very diverse.
             seed_topic_list: A list of seed words per topic to converge around
             verbose: Changes the verbosity of the model, Set to True if you want
                      to track the stages of the model.
@@ -141,6 +144,7 @@ def __init__(self,
         self.nr_topics = nr_topics
         self.low_memory = low_memory
         self.calculate_probabilities = calculate_probabilities
+        self.diversity = diversity
         self.verbose = verbose
         self.seed_topic_list = seed_topic_list

@@ -370,10 +374,14 @@ def transform(self,
                                                  verbose=self.verbose)

         umap_embeddings = self.umap_model.transform(embeddings)
+        logger.info("Reduced dimensionality with UMAP")
+
         predictions, probabilities = hdbscan.approximate_predict(self.hdbscan_model, umap_embeddings)
+        logger.info("Predicted clusters with HDBSCAN")

         if self.calculate_probabilities:
             probabilities = hdbscan.membership_vector(self.hdbscan_model, umap_embeddings)
+            logger.info("Calculated probabilities with HDBSCAN")
         else:
             probabilities = None

@@ -476,7 +484,7 @@ def topics_over_time(self,
             selection = documents.loc[documents.Timestamps == timestamp, :]
             documents_per_topic = selection.groupby(['Topic'], as_index=False).agg({'Document': ' '.join,
                                                                                    "Timestamps": "count"})
-            c_tf_idf, words = self._c_tf_idf(documents_per_topic, m=len(selection), fit=False)
+            c_tf_idf, words = self._c_tf_idf(documents_per_topic, fit=False)

             if global_tuning or evolution_tuning:
                 c_tf_idf = normalize(c_tf_idf, axis=1, norm='l1', copy=False)
@@ -569,7 +577,7 @@ def topics_per_class(self,
             selection = documents.loc[documents.Class == class_, :]
             documents_per_topic = selection.groupby(['Topic'], as_index=False).agg({'Document': ' '.join,
                                                                                    "Class": "count"})
-            c_tf_idf, words = self._c_tf_idf(documents_per_topic, m=len(selection), fit=False)
+            c_tf_idf, words = self._c_tf_idf(documents_per_topic, fit=False)

             # Fine-tune the timestamp c-TF-IDF representation based on the global c-TF-IDF representation
             # by simply taking the average of the two
@@ -1107,8 +1115,8 @@ def visualize_hierarchy(self,
                          Either 'left' or 'bottom'
             topics: A selection of topics to visualize
             top_n_topics: Only select the top n most frequent topics
-            width: The width of the figure.
-            height: The height of the figure.
+            width: The width of the figure. Only works if orientation is set to 'left'
+            height: The height of the figure. Only works if orientation is set to 'bottom'

         Returns:
             fig: A plotly figure
@@ -1185,18 +1193,18 @@ def visualize_heatmap(self,

     def visualize_barchart(self,
                            topics: List[int] = None,
-                           top_n_topics: int = 6,
+                           top_n_topics: int = 8,
                            n_words: int = 5,
-                           width: int = 800,
-                           height: int = 600) -> go.Figure:
+                           width: int = 250,
+                           height: int = 250) -> go.Figure:
         """ Visualize a barchart of selected topics

         Arguments:
             topics: A selection of topics to visualize.
             top_n_topics: Only select the top n most frequent topics.
             n_words: Number of words to show in a topic
-            width: The width of the figure.
-            height: The height of the figure.
+            width: The width of each figure.
+            height: The height of each figure.

         Returns:
             fig: A plotly figure
@@ -1447,7 +1455,7 @@ def _extract_topics(self, documents: pd.DataFrame):
             c_tf_idf: The resulting matrix giving a value (importance score) for each word per topic
         """
         documents_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
-        self.c_tf_idf, words = self._c_tf_idf(documents_per_topic, m=len(documents))
+        self.c_tf_idf, words = self._c_tf_idf(documents_per_topic)
         self.topics = self._extract_words_per_topic(words)
         self._create_topic_vectors()
         self.topic_names = {key: f"{key}_" + "_".join([word[0] for word in values[:4]])
@@ -1553,7 +1561,7 @@ def _create_topic_vectors(self):

         self.topic_embeddings = topic_embeddings

-    def _c_tf_idf(self, documents_per_topic: pd.DataFrame, m: int, fit: bool = True) -> Tuple[csr_matrix, List[str]]:
+    def _c_tf_idf(self, documents_per_topic: pd.DataFrame, fit: bool = True) -> Tuple[csr_matrix, List[str]]:
         """ Calculate a class-based TF-IDF where m is the number of total documents.

         Arguments:
@@ -1581,7 +1589,7 @@ def _c_tf_idf(self, documents_per_topic: pd.DataFrame, m: int, fit: bool = True)
             multiplier = None

         if fit:
-            self.transformer = ClassTFIDF().fit(X, n_samples=m, multiplier=multiplier)
+            self.transformer = ClassTFIDF().fit(X, multiplier=multiplier)

         c_tf_idf = self.transformer.transform(X)

@@ -1641,19 +1649,20 @@ def _extract_words_per_topic(self,

         # Extract word embeddings for the top 30 words per topic and compare it
         # with the topic embedding to keep only the words most similar to the topic embedding
-        if self.embedding_model is not None:
+        if self.diversity is not None:
+            if self.embedding_model is not None:

-            for topic, topic_words in topics.items():
-                words = [word[0] for word in topic_words]
-                word_embeddings = self._extract_embeddings(words,
-                                                           method="word",
-                                                           verbose=False)
-                topic_embedding = self._extract_embeddings(" ".join(words),
-                                                           method="word",
-                                                           verbose=False).reshape(1, -1)
-                topic_words = mmr(topic_embedding, word_embeddings, words,
-                                  top_n=self.top_n_words, diversity=0)
-                topics[topic] = [(word, value) for word, value in topics[topic] if word in topic_words]
+                for topic, topic_words in topics.items():
+                    words = [word[0] for word in topic_words]
+                    word_embeddings = self._extract_embeddings(words,
+                                                               method="word",
+                                                               verbose=False)
+                    topic_embedding = self._extract_embeddings(" ".join(words),
+                                                               method="word",
+                                                               verbose=False).reshape(1, -1)
+                    topic_words = mmr(topic_embedding, word_embeddings, words,
+                                      top_n=self.top_n_words, diversity=self.diversity)
+                    topics[topic] = [(word, value) for word, value in topics[topic] if word in topic_words]
         topics = {label: values[:self.top_n_words] for label, values in topics.items()}

         return topics
@@ -1694,10 +1703,7 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame) -> pd.DataFrame:
         self.merged_topics = []

         # Create topic similarity matrix
-        if self.topic_embeddings is not None:
-            similarities = cosine_similarity(np.array(self.topic_embeddings))
-        else:
-            similarities = cosine_similarity(self.c_tf_idf)
+        similarities = cosine_similarity(self.c_tf_idf)
         np.fill_diagonal(similarities, 0)

         # Find most similar topic to least common topic
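The switch from the hard-coded `diversity=0` to `diversity=self.diversity` in the `mmr(...)` call above is what exposes the new parameter. As a rough illustration of the trade-off that value controls, here is a toy maximal-marginal-relevance sketch in pure Python with made-up similarity scores (not the actual `bertopic._mmr` implementation):

```python
def mmr(topic_sim, word_sims, words, top_n=3, diversity=0.5):
    """Toy maximal marginal relevance: trade off relevance to the
    topic against similarity to already-selected words."""
    # Start with the word most similar to the topic
    selected = [max(range(len(words)), key=lambda i: topic_sim[i])]
    while len(selected) < top_n:
        best, best_score = None, float("-inf")
        for i in range(len(words)):
            if i in selected:
                continue
            # Redundancy = similarity to the closest already-selected word
            redundancy = max(word_sims[i][j] for j in selected)
            score = (1 - diversity) * topic_sim[i] - diversity * redundancy
            if score > best_score:
                best, best_score = i, score
        selected.append(best)
    return [words[i] for i in selected]

# Toy data: "car" and "auto" are near-duplicates
words = ["car", "auto", "engine"]
topic_sim = [0.9, 0.88, 0.5]          # similarity of each word to the topic
word_sims = [[1.0, 0.95, 0.4],        # pairwise word similarities
             [0.95, 1.0, 0.4],
             [0.4, 0.4, 1.0]]

print(mmr(topic_sim, word_sims, words, top_n=2, diversity=0.0))  # relevance only
print(mmr(topic_sim, word_sims, words, top_n=2, diversity=0.8))  # penalize redundancy
```

With `diversity=0`, MMR reduces to picking the words most similar to the topic, which is why the old hard-coded value effectively disabled diversification; higher values swap near-duplicates like "auto" for less redundant words.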

bertopic/_ctfidf.py

Lines changed: 14 additions & 4 deletions
@@ -21,12 +21,12 @@ class ClassTFIDF(TfidfTransformer):
     def __init__(self, *args, **kwargs):
         super(ClassTFIDF, self).__init__(*args, **kwargs)

-    def fit(self, X: sp.csr_matrix, n_samples: int, multiplier: np.ndarray = None):
+    def fit(self, X: sp.csr_matrix, multiplier: np.ndarray = None):
         """Learn the idf vector (global term weights).

         Arguments:
             X: A matrix of term/token counts.
-            n_samples: Number of total documents
+            multiplier: A multiplier for increasing/decreasing certain IDF scores
         """
         X = check_array(X, accept_sparse=('csr', 'csc'))
         if not sp.issparse(X):
@@ -35,19 +35,29 @@ def fit(self, X: sp.csr_matrix, n_samples: int, multiplier: np.ndarray = None):

         if self.use_idf:
             _, n_features = X.shape
+
+            # Calculate the frequency of words across all classes
             df = np.squeeze(np.asarray(X.sum(axis=0)))
+
+            # Calculate the average number of samples as regularization
             avg_nr_samples = int(X.sum(axis=1).mean())
-            idf = np.log(avg_nr_samples / df)
+
+            # Divide the average number of samples by the word frequency
+            # +1 is added to force values to be positive
+            idf = np.log((avg_nr_samples / df)+1)
+
+            # Multiplier to increase/decrease certain idf scores
             if multiplier is not None:
                 idf = idf * multiplier
+
             self._idf_diag = sp.diags(idf, offsets=0,
                                       shape=(n_features, n_features),
                                       format='csr',
                                       dtype=dtype)

         return self

-    def transform(self, X: sp.csr_matrix, copy=True):
+    def transform(self, X: sp.csr_matrix):
         """Transform a count-based matrix to c-TF-IDF

         Arguments:
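The `+1` added inside the logarithm is the substance of #351: with the old `np.log(avg_nr_samples / df)`, any word whose frequency `df` exceeds the average number of samples per class received a negative IDF value. A quick numeric check with illustrative, hypothetical numbers (plain `math`, not the `ClassTFIDF` code itself):

```python
import math

avg_nr_samples = 100   # hypothetical average number of words per class
df = 250               # a word that is more frequent than that average

old_idf = math.log(avg_nr_samples / df)      # log(0.4): negative whenever df > avg
new_idf = math.log(avg_nr_samples / df + 1)  # log(1.4): argument always > 1, so always positive

print(old_idf, new_idf)
```

Since `df > 0` and `avg_nr_samples > 0`, the argument `avg_nr_samples / df + 1` is strictly greater than 1, so the new IDF values can never go negative.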

bertopic/plotting/_barchart.py

Lines changed: 18 additions & 15 deletions
@@ -1,3 +1,4 @@
+import itertools
 import numpy as np
 from typing import List

@@ -7,19 +8,19 @@

 def visualize_barchart(topic_model,
                        topics: List[int] = None,
-                       top_n_topics: int = 6,
+                       top_n_topics: int = 8,
                        n_words: int = 5,
-                       width: int = 800,
-                       height: int = 600) -> go.Figure:
+                       width: int = 250,
+                       height: int = 250) -> go.Figure:
     """ Visualize a barchart of selected topics

     Arguments:
         topic_model: A fitted BERTopic instance.
         topics: A selection of topics to visualize.
         top_n_topics: Only select the top n most frequent topics.
         n_words: Number of words to show in a topic
-        width: The width of the figure.
-        height: The height of the figure.
+        width: The width of each figure.
+        height: The height of each figure.

     Returns:
         fig: A plotly figure
@@ -39,9 +40,11 @@ def visualize_barchart(topic_model,
     fig = topic_model.visualize_barchart()
     fig.write_html("path/to/file.html")
     ```
-    <iframe src="../../tutorial/visualization/bar_chart.html"
+    <iframe src="../../getting_started/visualization/bar_chart.html"
     style="width:1100px; height: 660px; border: 0px;""></iframe>
     """
+    colors = itertools.cycle(["#D55E00", "#0072B2", "#CC79A7", "#E69F00", "#56B4E9", "#009E73", "#F0E442"])
+
     # Select topics based on top_n and topics args
     if topics is not None:
         topics = list(topics)
@@ -52,13 +55,13 @@ def visualize_barchart(topic_model,

     # Initialize figure
     subplot_titles = [f"Topic {topic}" for topic in topics]
-    columns = 3
+    columns = 4
     rows = int(np.ceil(len(topics) / columns))
     fig = make_subplots(rows=rows,
                         cols=columns,
-                        shared_xaxes=True,
-                        horizontal_spacing=.15,
-                        vertical_spacing=.15,
+                        shared_xaxes=False,
+                        horizontal_spacing=.1,
+                        vertical_spacing=.4 / rows if rows > 1 else 0,
                         subplot_titles=subplot_titles)

     # Add barchart for each topic
@@ -71,7 +74,8 @@ def visualize_barchart(topic_model,
         fig.add_trace(
             go.Bar(x=scores,
                    y=words,
-                   orientation='h'),
+                   orientation='h',
+                   marker_color=next(colors)),
             row=row, col=column)

         if column == columns:
@@ -86,16 +90,15 @@ def visualize_barchart(topic_model,
         showlegend=False,
         title={
             'text': "<b>Topic Word Scores",
-            'y': .95,
-            'x': .15,
+            'x': .5,
             'xanchor': 'center',
             'yanchor': 'top',
             'font': dict(
                 size=22,
                 color="Black")
         },
-        width=width,
-        height=height,
+        width=width*4,
+        height=height*rows if rows > 1 else height * 1.3,
         hoverlabel=dict(
             bgcolor="white",
             font_size=16,
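The `itertools.cycle` palette introduced above simply repeats once more than seven bars are drawn, so `next(colors)` never runs out of `marker_color` values no matter how many topic subplots there are. The wrap-around behavior in isolation, with a shortened illustrative palette:

```python
import itertools

palette = ["#D55E00", "#0072B2", "#CC79A7"]  # shortened palette for illustration
colors = itertools.cycle(palette)

# Five subplots pull five colors; the cycle wraps after the third
assigned = [next(colors) for _ in range(5)]
print(assigned)
```

A plain list of seven colors would raise `IndexError` past the seventh topic; the cycle makes the default `top_n_topics=8` (and larger selections) safe.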

bertopic/plotting/_distribution.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def visualize_distribution(topic_model,
     fig = topic_model.visualize_distribution(probabilities[0])
     fig.write_html("path/to/file.html")
     ```
-    <iframe src="../../tutorial/visualization/probabilities.html"
+    <iframe src="../../getting_started/visualization/probabilities.html"
     style="width:1000px; height: 500px; border: 0px;""></iframe>
     """
     if len(probabilities.shape) != 1:
