Skip to content

Commit 980d14e

Browse files
authored
Add Model2Vec as an embedding backend (#2245)
1 parent 998ac83 commit 980d14e

File tree

5 files changed

+240
-15
lines changed

5 files changed

+240
-15
lines changed

bertopic/backend/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,20 @@
2424
msg = "`pip install bertopic[vision]` \n\n"
2525
MultiModalBackend = NotInstalled("Vision", "Vision", custom_msg=msg)
2626

27+
# Model2Vec Embeddings
28+
try:
29+
from bertopic.backend._model2vec import Model2VecBackend
30+
except ModuleNotFoundError:
31+
msg = "`pip install model2vec` \n\n"
32+
Model2VecBackend = NotInstalled("Model2Vec", "Model2Vec", custom_msg=msg)
33+
2734

2835
__all__ = [
2936
"BaseEmbedder",
3037
"WordDocEmbedder",
3138
"OpenAIBackend",
3239
"CohereBackend",
40+
"Model2VecBackend",
3341
"MultiModalBackend",
3442
"languages",
3543
]

bertopic/backend/_model2vec.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import numpy as np
2+
from typing import List, Union
3+
from model2vec import StaticModel
4+
from sklearn.feature_extraction.text import CountVectorizer
5+
6+
from bertopic.backend import BaseEmbedder
7+
8+
9+
class Model2VecBackend(BaseEmbedder):
    """Model2Vec embedding model.

    Arguments:
        embedding_model: Either a model2vec model or a
                         string pointing to a model2vec model
        distill: Indicates whether to distill a sentence-transformers compatible model.
                 The distillation happens on the first call to `embed`.
                 NOTE: Only works if `embedding_model` is a string.
        distill_kwargs: Keyword arguments to pass to the distillation process
                        of `model2vec.distill.distill`. The dictionary is copied,
                        so the one you pass in is never modified.
        distill_vectorizer: A CountVectorizer used for creating a custom vocabulary
                            based on the same documents used for topic modeling.
                            Defaults to `CountVectorizer()` when distilling.
                            NOTE: If "vocabulary" is in `distill_kwargs`, this will be ignored.

    Raises:
        ValueError: If `distill=True` and `embedding_model` is not a string, or if
                    `embedding_model` is neither a string nor a `StaticModel`.
        ImportError: If `distill=True` and `model2vec.distill` cannot be imported.

    Examples:
    To create a model, you can load in a string pointing to a
    model2vec model:

    ```python
    from bertopic.backend import Model2VecBackend

    sentence_model = Model2VecBackend("minishlab/potion-base-8M")
    ```

    or you can instantiate a model yourself:

    ```python
    from bertopic.backend import Model2VecBackend
    from model2vec import StaticModel

    embedding_model = StaticModel.from_pretrained("minishlab/potion-base-8M")
    sentence_model = Model2VecBackend(embedding_model)
    ```

    If you want to distill a sentence-transformers model with the vocabulary of the documents,
    run the following:

    ```python
    from bertopic.backend import Model2VecBackend

    sentence_model = Model2VecBackend("sentence-transformers/all-MiniLM-L6-v2", distill=True)
    ```
    """

    def __init__(
        self,
        embedding_model: Union[str, StaticModel],
        distill: bool = False,
        distill_kwargs: Union[dict, None] = None,
        distill_vectorizer: Union[CountVectorizer, None] = None,
    ):
        super().__init__()

        self.distill = distill
        # `embed` stores the computed vocabulary in this dict, so work on a private
        # copy: a shared default (or the caller's own dict) must never be mutated,
        # otherwise the vocabulary of one backend leaks into the next one.
        self.distill_kwargs = dict(distill_kwargs) if distill_kwargs else {}
        self.distill_vectorizer = distill_vectorizer
        self._has_distilled = False

        # When we distill, we need a string pointing to a sentence-transformer model
        if self.distill:
            self._check_model2vec_installation()
            if self.distill_vectorizer is None:
                self.distill_vectorizer = CountVectorizer()
            if isinstance(embedding_model, str):
                # Keep the name only; the actual model is created in `embed`
                self.embedding_model = embedding_model
            else:
                raise ValueError("Please pass a string pointing to a sentence-transformer model when distilling.")

        # If we don't distill, we can pass a model2vec model directly or load from a string
        elif isinstance(embedding_model, StaticModel):
            self.embedding_model = embedding_model
        elif isinstance(embedding_model, str):
            self.embedding_model = StaticModel.from_pretrained(embedding_model)
        else:
            raise ValueError(
                "Please select a correct Model2Vec model: \n"
                "`from model2vec import StaticModel` \n"
                "`model = StaticModel.from_pretrained('minishlab/potion-base-8M')`"
            )

    def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
        """Embed a list of n documents/words into an (n, m) matrix of embeddings.

        If `distill=True`, the first call distills the model before embedding and
        replaces `self.embedding_model` with the distilled model. Unless
        `distill_kwargs` already holds a non-empty "vocabulary", the vocabulary is
        built from `documents` of that first call, ordered by descending word count.

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        # Distill the model
        if self.distill and not self._has_distilled:
            from model2vec.distill import distill

            # Distill with the vocabulary of the documents
            if not self.distill_kwargs.get("vocabulary"):
                X = self.distill_vectorizer.fit_transform(documents)
                word_counts = np.array(X.sum(axis=0)).flatten()
                words = self.distill_vectorizer.get_feature_names_out()
                # Most frequent words first
                vocabulary = [word for word, _ in sorted(zip(words, word_counts), key=lambda x: x[1], reverse=True)]
                self.distill_kwargs["vocabulary"] = vocabulary

            # Distill the model
            self.embedding_model = distill(self.embedding_model, **self.distill_kwargs)

            # Distillation should happen only once and not for every embed call
            # The distillation should only happen the first time on the entire vocabulary
            self._has_distilled = True

        # Embed the documents
        embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose)
        return embeddings

    def _check_model2vec_installation(self):
        """Raise a helpful ImportError if `model2vec.distill` cannot be imported."""
        try:
            from model2vec.distill import distill  # noqa: F401
        except ImportError as err:
            # Chain the original error so the underlying missing module stays visible
            raise ImportError(
                "To distill a model using model2vec, you need to run `pip install model2vec[distill]`"
            ) from err

bertopic/backend/_sentencetransformers.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
from typing import List, Union
33
from sentence_transformers import SentenceTransformer
4+
from sentence_transformers.models import StaticEmbedding
45

56
from bertopic.backend import BaseEmbedder
67

@@ -13,6 +14,9 @@ class SentenceTransformerBackend(BaseEmbedder):
1314
1415
Arguments:
1516
embedding_model: A sentence-transformers embedding model
17+
model2vec: Indicates whether `embedding_model` is a model2vec model.
18+
NOTE: Only works if `embedding_model` is a string.
19+
Otherwise, you can pass the model2vec model directly to `embedding_model`.
1620
1721
Examples:
1822
To create a model, you can load in a string pointing to a
@@ -25,20 +29,35 @@ class SentenceTransformerBackend(BaseEmbedder):
2529
```
2630
2731
or you can instantiate a model yourself:
32+
2833
```python
2934
from bertopic.backend import SentenceTransformerBackend
3035
from sentence_transformers import SentenceTransformer
3136
3237
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
3338
sentence_model = SentenceTransformerBackend(embedding_model)
3439
```
40+
41+
If you want to use a model2vec model without having to install model2vec,
42+
you can pass the model2vec model as a string:
43+
44+
```python
45+
from bertopic.backend import SentenceTransformerBackend
46+
from sentence_transformers import SentenceTransformer
47+
48+
sentence_model = SentenceTransformerBackend("minishlab/potion-base-8M", model2vec=True)
50+
```
3551
"""
3652

37-
def __init__(self, embedding_model: Union[str, SentenceTransformer]):
53+
def __init__(self, embedding_model: Union[str, SentenceTransformer], model2vec: bool = False):
3854
super().__init__()
3955

4056
self._hf_model = None
41-
if isinstance(embedding_model, SentenceTransformer):
57+
if model2vec and isinstance(embedding_model, str):
58+
static_embedding = StaticEmbedding.from_model2vec(embedding_model)
59+
self.embedding_model = SentenceTransformer(modules=[static_embedding])
60+
elif isinstance(embedding_model, SentenceTransformer):
4261
self.embedding_model = embedding_model
4362
elif isinstance(embedding_model, str):
4463
self.embedding_model = SentenceTransformer(embedding_model)

bertopic/backend/_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@ def select_backend(embedding_model, language: str = None, verbose: bool = False)
124124

125125
return HFTransformerBackend(embedding_model)
126126

127+
# Model2Vec embeddings
128+
if "model2vec" in str(type(embedding_model)):
129+
from ._model2vec import Model2VecBackend
130+
131+
return Model2VecBackend(embedding_model)
132+
127133
# Select embedding model based on language
128134
if language:
129135
try:

docs/getting_started/embeddings/embeddings.md

Lines changed: 76 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ This modularity allows us not only to choose any embedding model to convert our
1414
When new state-of-the-art pre-trained embedding models are released, BERTopic will be able to use them. As a result, BERTopic grows with any new models being released.
1515
Out of the box, BERTopic supports several embedding techniques. In this section, we will go through several of them and how they can be implemented.
1616

17-
### **Sentence Transformers**
17+
## **Sentence Transformers**
1818
You can select any model from sentence-transformers [here](https://www.sbert.net/docs/pretrained_models.html)
1919
and pass it through BERTopic with `embedding_model`:
2020

@@ -47,7 +47,70 @@ topic_model = BERTopic(embedding_model=sentence_model)
4747
topic_model = BERTopic(embedding_model=embedding_model)
4848
```
4949

50-
### 🤗 Hugging Face Transformers
50+
## **Model2Vec**
51+
To use a blazingly fast [Model2Vec](https://github.com/MinishLab/model2vec) model, you first need to install model2vec:
52+
53+
```
54+
pip install model2vec
55+
```
56+
57+
Then, you can load in any of their models and pass it to BERTopic like so:
58+
59+
```python
60+
from model2vec import StaticModel
61+
embedding_model = StaticModel.from_pretrained("minishlab/potion-base-8M")
62+
63+
topic_model = BERTopic(embedding_model=embedding_model)
64+
```
65+
66+
### **Distillation**
67+
68+
These models are extremely versatile and can be distilled from existing embedding models (like those compatible with `sentence-transformers`).
69+
This distillation process doesn't require a vocabulary (as it uses the tokenizer's vocabulary) but can benefit from having one. Fortunately, this allows you to
70+
use the vocabulary from your input documents to distill a model yourself.
71+
72+
Doing so requires you to install some additional dependencies of model2vec like so:
73+
74+
```
75+
pip install model2vec[distill]
76+
```
77+
78+
To then distill common embedding models, you need to import the `Model2VecBackend` from BERTopic:
79+
80+
```python
81+
from bertopic.backend import Model2VecBackend
82+
83+
# Choose a model to distill (a non-Model2Vec model)
84+
embedding_model = Model2VecBackend(
85+
"sentence-transformers/all-MiniLM-L6-v2",
86+
distill=True
87+
)
88+
89+
topic_model = BERTopic(embedding_model=embedding_model)
90+
```
91+
92+
You can also choose a custom vectorizer for creating the vocabulary and define custom arguments for the distillation process:
93+
94+
```python
95+
from bertopic.backend import Model2VecBackend
96+
from sklearn.feature_extraction.text import CountVectorizer
97+
98+
# Choose a model to distill (a non-Model2Vec model)
99+
embedding_model = Model2VecBackend(
100+
"sentence-transformers/all-MiniLM-L6-v2",
101+
distill=True,
102+
distill_kwargs={"pca_dims": 256, "apply_zipf": True, "use_subword": True},
103+
distill_vectorizer=CountVectorizer(ngram_range=(1, 3))
104+
)
105+
106+
topic_model = BERTopic(embedding_model=embedding_model)
107+
```
108+
109+
!!! tip "Tip!"
110+
You can save the resulting model with `topic_model.embedding_model.embedding_model.save_pretrained("m2v_model")`.
111+
112+
113+
## **🤗 Hugging Face Transformers**
51114
To use a Hugging Face transformers model, load in a pipeline and point
52115
to any model found on their model hub (https://huggingface.co/models):
53116

@@ -61,7 +124,7 @@ topic_model = BERTopic(embedding_model=embedding_model)
61124
!!! tip "Tip!"
62125
These transformers also work quite well using `sentence-transformers`, which has great optimization tricks that make using it a bit faster.
63126

64-
### **Flair**
127+
## **Flair**
65128
[Flair](https://github.com/flairNLP/flair) allows you to choose almost any embedding model that
66129
is publicly available. Flair can be used as follows:
67130

@@ -87,7 +150,7 @@ document_glove_embeddings = DocumentPoolEmbeddings([glove_embedding])
87150
topic_model = BERTopic(embedding_model=document_glove_embeddings)
88151
```
89152

90-
### **Spacy**
153+
## **Spacy**
91154
[Spacy](https://github.com/explosion/spaCy) is an amazing framework for processing text. There are
92155
many models available across many languages for modeling text.
93156

@@ -128,7 +191,7 @@ require_gpu(0)
128191
topic_model = BERTopic(embedding_model=nlp)
129192
```
130193

131-
### **Universal Sentence Encoder (USE)**
194+
## **Universal Sentence Encoder (USE)**
132195
The Universal Sentence Encoder encodes text into high-dimensional vectors that are used here
133196
for embedding the documents. The model is trained and optimized for greater-than-word length text,
134197
such as sentences, phrases, or short paragraphs.
@@ -141,7 +204,7 @@ embedding_model = tensorflow_hub.load("https://tfhub.dev/google/universal-senten
141204
topic_model = BERTopic(embedding_model=embedding_model)
142205
```
143206

144-
### **Gensim**
207+
## **Gensim**
145208
BERTopic supports the `gensim.downloader` module, which allows it to download any word embedding model supported by Gensim.
146209
Typically, these are Glove, Word2Vec, or FastText embeddings:
147210

@@ -155,7 +218,7 @@ topic_model = BERTopic(embedding_model=ft)
155218
Gensim is primarily used for Word Embedding models. This works typically best for short documents since the word embeddings are pooled.
156219

157220

158-
### **Scikit-Learn Embeddings**
221+
## **Scikit-Learn Embeddings**
159222
Scikit-Learn is a framework for more than just machine learning.
160223
It offers many preprocessing tools, some of which can be used to create representations
161224
for text. Many of these tools are relatively lightweight and do not require a GPU.
@@ -187,7 +250,7 @@ topic_model = BERTopic(embedding_model=pipe)
187250
it does not support the `bertopic.representation` models.
188251

189252

190-
### OpenAI
253+
## **OpenAI**
191254
To use OpenAI's external API, we need to define our key and explicitly call `bertopic.backend.OpenAIBackend`
192255
to be used in our topic model:
193256

@@ -202,7 +265,7 @@ topic_model = BERTopic(embedding_model=embedding_model)
202265
```
203266

204267

205-
### Cohere
268+
## **Cohere**
206269
To use Cohere's external API, we need to define our key and explicitly call `bertopic.backend.CohereBackend`
207270
to be used in our topic model:
208271

@@ -216,7 +279,7 @@ embedding_model = CohereBackend(client)
216279
topic_model = BERTopic(embedding_model=embedding_model)
217280
```
218281

219-
### Multimodal
282+
## **Multimodal**
220283
To create embeddings for both text and images in the same vector space, we can use the `MultiModalBackend`.
221284
This model uses a clip-vit based model that is capable of embedding text, images, or both:
222285

@@ -235,7 +298,7 @@ doc_image_embeddings = model.embed(docs, images)
235298
```
236299

237300

238-
### **Custom Backend**
301+
## **Custom Backend**
239302
If your backend or model cannot be found in the ones currently available, you can use the `bertopic.backend.BaseEmbedder` class to
240303
create your backend. Below, you will find an example of creating a SentenceTransformer backend for BERTopic:
241304

@@ -260,7 +323,7 @@ custom_embedder = CustomEmbedder(embedding_model=embedding_model)
260323
topic_model = BERTopic(embedding_model=custom_embedder)
261324
```
262325

263-
### **Custom Embeddings**
326+
## **Custom Embeddings**
264327
The base models in BERTopic are BERT-based models that work well with document similarity tasks. Your documents,
265328
however, might be too specific for a general pre-trained model to be used. Fortunately, you can use the embedding
266329
model in BERTopic to create document features.
@@ -283,7 +346,7 @@ topics, probs = topic_model.fit_transform(docs, embeddings)
283346
As you can see above, we used a SentenceTransformer model to create the embedding. You could also have used
284347
`🤗 transformers`, `Doc2Vec`, or any other embedding method.
285348

286-
#### **TF-IDF**
349+
### **TF-IDF**
287350
As mentioned above, any embedding technique can be used. However, when running UMAP, the typical distance metric is
288351
`cosine` which does not work quite well for a TF-IDF matrix. Instead, BERTopic will recognize that a sparse matrix
289352
is passed and use `hellinger` instead which works quite well for the similarity between probability distributions.

0 commit comments

Comments
 (0)