Commit 0c930d2

Light-weight installation without UMAP and HDBSCAN (#2289)

1 parent 68cc1a7 commit 0c930d2

14 files changed: +169 -634 lines

README.md

Lines changed: 2 additions & 0 deletions
````diff
@@ -63,6 +63,8 @@ pip install bertopic[flair,gensim,spacy,use]
 pip install bertopic[vision]
 ```
 
+For a *light-weight installation* without transformers, UMAP and/or HDBSCAN (for training with Model2Vec or perhaps for inference), see [this tutorial](https://maartengr.github.io/BERTopic/getting_started/tips_and_tricks/tips_and_tricks.html#lightweight-installation).
+
 ## Getting Started
 For an in-depth overview of the features of BERTopic
 you can check the [**full documentation**](https://maartengr.github.io/BERTopic/) or you can follow along
````
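As a sketch of what that light-weight path can look like after this change (assuming `umap-learn` and `hdbscan` are absent and that BERTopic's backend accepts a Model2Vec `StaticModel`, as the linked tutorial describes; the model name is only an example):

```python
# Hypothetical light-weight pipeline: with umap-learn/hdbscan not installed,
# BERTopic falls back to PCA and scikit-learn's HDBSCAN (per this commit).
from sklearn.datasets import fetch_20newsgroups
from model2vec import StaticModel
from bertopic import BERTopic

docs = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))["data"]

embedding_model = StaticModel.from_pretrained("minishlab/potion-base-8M")  # example model
topic_model = BERTopic(embedding_model=embedding_model)
topics, probs = topic_model.fit_transform(docs)
```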

bertopic/_bertopic.py

Lines changed: 58 additions & 33 deletions
```diff
@@ -37,11 +37,18 @@
 from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable
 
 # Models
-import hdbscan
-from umap import UMAP
+try:
+    from hdbscan import HDBSCAN
+
+    HAS_HDBSCAN = True
+except (ImportError, ModuleNotFoundError):
+    HAS_HDBSCAN = False
+    from sklearn.cluster import HDBSCAN as SK_HDBSCAN
+
 from sklearn.preprocessing import normalize
 from sklearn import __version__ as sklearn_version
 from sklearn.cluster import AgglomerativeClustering
+from sklearn.decomposition import PCA
 from sklearn.metrics.pairwise import cosine_similarity
 from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
```

```diff
@@ -143,8 +150,8 @@ def __init__(
         zeroshot_topic_list: List[str] = None,
         zeroshot_min_similarity: float = 0.7,
         embedding_model=None,
-        umap_model: UMAP = None,
-        hdbscan_model: hdbscan.HDBSCAN = None,
+        umap_model=None,
+        hdbscan_model=None,
         vectorizer_model: CountVectorizer = None,
         ctfidf_model: TfidfTransformer = None,
         representation_model: BaseRepresentation = None,
```
```diff
@@ -247,22 +254,38 @@ def __init__(
         self.representation_model = representation_model
 
         # UMAP or another algorithm that has .fit and .transform functions
-        self.umap_model = umap_model or UMAP(
-            n_neighbors=15,
-            n_components=5,
-            min_dist=0.0,
-            metric="cosine",
-            low_memory=self.low_memory,
-        )
+        if umap_model is not None:
+            self.umap_model = umap_model
+        else:
+            try:
+                from umap import UMAP
+
+                self.umap_model = UMAP(
+                    n_neighbors=15,
+                    n_components=5,
+                    min_dist=0.0,
+                    metric="cosine",
+                    low_memory=self.low_memory,
+                )
+            except (ImportError, ModuleNotFoundError):
+                self.umap_model = PCA(n_components=5)
 
         # HDBSCAN or another clustering algorithm that has .fit and .predict functions and
         # the .labels_ variable to extract the labels
-        self.hdbscan_model = hdbscan_model or hdbscan.HDBSCAN(
-            min_cluster_size=self.min_topic_size,
-            metric="euclidean",
-            cluster_selection_method="eom",
-            prediction_data=True,
-        )
+
+        if hdbscan_model is not None:
+            self.hdbscan_model = hdbscan_model
+        elif HAS_HDBSCAN:
+            self.hdbscan_model = HDBSCAN(
+                min_cluster_size=self.min_topic_size,
+                metric="euclidean",
+                cluster_selection_method="eom",
+                prediction_data=True,
+            )
+        else:
+            self.hdbscan_model = SK_HDBSCAN(
+                min_cluster_size=self.min_topic_size, metric="euclidean", cluster_selection_method="eom", n_jobs=-1
+            )
 
         # Public attributes
         self.topics_ = None
```
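Because the fallback depends on what happens to be installed, callers who want deterministic behavior can pass both models explicitly and bypass the try/except entirely; a sketch using only scikit-learn:

```python
from sklearn.cluster import HDBSCAN
from sklearn.decomposition import PCA
from bertopic import BERTopic

# Explicit models skip the import fallback; both expose the .fit/.fit_predict
# and .labels_ interfaces that BERTopic expects.
topic_model = BERTopic(
    umap_model=PCA(n_components=5),
    hdbscan_model=HDBSCAN(min_cluster_size=10, metric="euclidean"),
)
```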
```diff
@@ -326,7 +349,7 @@ def fit(
         images: List[str] = None,
         y: Union[List[int], np.ndarray] = None,
     ):
-        """Fit the models (Bert, UMAP, and, HDBSCAN) on a collection of documents and generate topics.
+        """Fit the models on a collection of documents and generate topics.
 
         Arguments:
             documents: A list of documents to fit on
```
```diff
@@ -684,9 +707,7 @@ def partial_fit(
         # Checks
         check_embeddings_shape(embeddings, documents)
         if not hasattr(self.hdbscan_model, "partial_fit"):
-            raise ValueError(
-                "In order to use `.partial_fit`, the cluster model should have " "a `.partial_fit` function."
-            )
+            raise ValueError("In order to use `.partial_fit`, the cluster model should have a `.partial_fit` function.")
 
         # Prepare documents
         if isinstance(documents, str):
```
```diff
@@ -1524,7 +1545,7 @@ def update_topics(
 
         if top_n_words > 100:
             logger.warning(
-                "Note that extracting more than 100 words from a sparse " "can slow down computation quite a bit."
+                "Note that extracting more than 100 words from a sparse can slow down computation quite a bit."
             )
         self.top_n_words = top_n_words
         self.vectorizer_model = vectorizer_model or CountVectorizer(ngram_range=n_gram_range)
```
```diff
@@ -2007,7 +2028,7 @@ def set_topic_labels(self, topic_labels: Union[List[str], Mapping[int, str]]) ->
             custom_labels = topic_labels
         else:
             raise ValueError(
-                "Make sure that `topic_labels` contains the same number " "of labels as there are topics."
+                "Make sure that `topic_labels` contains the same number of labels as there are topics."
             )
 
         self.custom_labels_ = custom_labels
```
```diff
@@ -2124,9 +2145,7 @@ def merge_topics(
             for topic in topic_group:
                 mapping[topic] = topic_group[0]
         else:
-            raise ValueError(
-                "Make sure that `topics_to_merge` is either" "a list of topics or a list of list of topics."
-            )
+            raise ValueError("Make sure that `topics_to_merge` is either a list of topics or a list of list of topics.")
 
         # Track mappings and sizes of topics for merging topic embeddings
         mappings = defaultdict(list)
```
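The removed lines also show why these implicitly concatenated literals were error-prone: Python joins adjacent string literals with no separator, so a missing space at the seam silently corrupts the message:

```python
# Adjacent string literals are concatenated at compile time with no separator
msg = "Make sure that `topics_to_merge` is either" "a list of topics."
print(msg)  # "... is eithera list of topics." (the space at the seam is lost)
```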
```diff
@@ -3769,7 +3788,7 @@ def _cluster_embeddings(
         partial_fit: bool = False,
         y: np.ndarray = None,
     ) -> Tuple[pd.DataFrame, np.ndarray]:
-        """Cluster UMAP embeddings with HDBSCAN.
+        """Cluster UMAP reduced embeddings with HDBSCAN.
 
         Arguments:
             umap_embeddings: The reduced sentence embeddings with UMAP
```
```diff
@@ -4473,12 +4492,18 @@ def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)
             self.c_tf_idf_, self.topic_embeddings_, use_ctfidf, output_ndarray=True
         )[0]
         norm_data = normalize(embeddings, norm="l2")
-        predictions = hdbscan.HDBSCAN(
-            min_cluster_size=2,
-            metric="euclidean",
-            cluster_selection_method="eom",
-            prediction_data=True,
-        ).fit_predict(norm_data[self._outliers :])
+
+        if HAS_HDBSCAN:
+            predictions = HDBSCAN(
+                min_cluster_size=2,
+                metric="euclidean",
+                cluster_selection_method="eom",
+                prediction_data=True,
+            ).fit_predict(norm_data[self._outliers :])
+        else:
+            predictions = SK_HDBSCAN(
+                min_cluster_size=2, metric="euclidean", cluster_selection_method="eom", n_jobs=-1
+            ).fit_predict(norm_data[self._outliers :])
 
         # Map similar topics
         mapped_topics = {
```
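One caveat worth noting: `sklearn.cluster.HDBSCAN` only exists in scikit-learn 1.3 and later, so the fallback presumes a reasonably recent scikit-learn. A minimal guard (a sketch; the `packaging` library is assumed available):

```python
import sklearn
from packaging.version import Version

# sklearn.cluster.HDBSCAN was introduced in scikit-learn 1.3
if Version(sklearn.__version__) < Version("1.3"):
    raise RuntimeError("The scikit-learn HDBSCAN fallback needs scikit-learn >= 1.3")
```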

bertopic/_save_utils.py

Lines changed: 22 additions & 8 deletions
```diff
@@ -461,22 +461,36 @@ def get_package_versions():
     try:
         import platform
         from numpy import __version__ as np_version
+        from pandas import __version__ as pandas_version
+        from sklearn import __version__ as sklearn_version
+        from plotly import __version__ as plotly_version
 
         try:
             from importlib.metadata import version
 
             hdbscan_version = version("hdbscan")
-        except:  # noqa: E722
+        except (ImportError, ModuleNotFoundError):
             hdbscan_version = None
 
-        from umap import __version__ as umap_version
-        from pandas import __version__ as pandas_version
-        from sklearn import __version__ as sklearn_version
-        from sentence_transformers import __version__ as sbert_version
-        from numba import __version__ as numba_version
-        from transformers import __version__ as transformers_version
+        try:
+            from umap import __version__ as umap_version
+        except (ImportError, ModuleNotFoundError):
+            umap_version = None
 
-        from plotly import __version__ as plotly_version
+        try:
+            from sentence_transformers import __version__ as sbert_version
+        except (ImportError, ModuleNotFoundError):
+            sbert_version = None
+
+        try:
+            from numba import __version__ as numba_version
+        except (ImportError, ModuleNotFoundError):
+            numba_version = None
+
+        try:
+            from transformers import __version__ as transformers_version
+        except (ImportError, ModuleNotFoundError):
+            transformers_version = None
 
         return {
             "Numpy": np_version,
```

bertopic/_utils.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -74,10 +74,7 @@ def check_is_fitted(topic_model):
     Raises:
         ValueError: If the matches were not found.
     """
-    msg = (
-        "This %(name)s instance is not fitted yet. Call 'fit' with "
-        "appropriate arguments before using this estimator."
-    )
+    msg = "This %(name)s instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator."
 
     if topic_model.topics_ is None:
         raise ValueError(msg % {"name": type(topic_model).__name__})
@@ -131,11 +128,11 @@ def validate_distance_matrix(X, n_samples):
         # check it has correct size
         n = s[0]
         if n != (n_samples * (n_samples - 1) / 2):
-            raise ValueError("The condensed distance matrix must have " "shape (n*(n-1)/2,).")
+            raise ValueError("The condensed distance matrix must have shape (n*(n-1)/2,).")
     elif len(s) == 2:
         # check it has correct size
         if (s[0] != n_samples) or (s[1] != n_samples):
-            raise ValueError("The distance matrix must be of shape " "(n, n) where n is the number of samples.")
+            raise ValueError("The distance matrix must be of shape (n, n) where n is the number of samples.")
         # force zero diagonal and convert to condensed
         np.fill_diagonal(X, 0)
         X = squareform(X)
```
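For context on the shapes these messages describe: SciPy's condensed form stores the strict upper triangle of an n x n symmetric distance matrix as a vector of n*(n-1)/2 entries, for example:

```python
import numpy as np
from scipy.spatial.distance import squareform

X = np.array([
    [0.0, 0.3, 0.8],
    [0.3, 0.0, 0.5],
    [0.8, 0.5, 0.0],
])
condensed = squareform(X)  # array([0.3, 0.8, 0.5]), shape (3,) == 3*(3-1)/2
```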

bertopic/cluster/_utils.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -1,4 +1,3 @@
-import hdbscan
 import numpy as np
 
 
@@ -15,6 +14,11 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
         embeddings: Input embeddings for "approximate_predict"
             and "membership_vector"
     """
+    try:
+        import hdbscan
+    except (ImportError, ModuleNotFoundError):
+        hdbscan = type("hdbscan", (), {"HDBSCAN": None})()
+
     # Approximate predict
     if func == "approximate_predict":
         if isinstance(model, hdbscan.HDBSCAN):
@@ -62,6 +66,11 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
 
 
 def is_supported_hdbscan(model):
     """Check whether the input model is a supported HDBSCAN-like model."""
+    try:
+        import hdbscan
+    except (ImportError, ModuleNotFoundError):
+        hdbscan = type("hdbscan", (), {"HDBSCAN": None})()
+
     if isinstance(model, hdbscan.HDBSCAN):
         return True
```
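The `type("hdbscan", (), {"HDBSCAN": None})()` expression builds a throwaway module-like object so that `hdbscan.HDBSCAN` still resolves when the package is missing. A variant sketch that also keeps the `isinstance` check itself valid is to expose a placeholder class instead of `None` (not the commit's code; `isinstance` rejects `None` as a type):

```python
try:
    import hdbscan
except (ImportError, ModuleNotFoundError):

    class _NeverInstantiated:
        """Placeholder type: nothing is ever an instance of it."""

    # Module-like stand-in exposing the one attribute the callers touch
    hdbscan = type("hdbscan", (), {"HDBSCAN": _NeverInstantiated})()


def uses_native_hdbscan(model) -> bool:
    # Always False when the real hdbscan package is missing
    return isinstance(model, hdbscan.HDBSCAN)
```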

bertopic/plotting/_approximate_distribution.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -75,7 +75,7 @@ def visualize_approximate_distribution(
     df = pd.DataFrame(topic_token_distribution).T
 
     df.columns = [f"{token}_{i}" for i, token in enumerate(tokens)]
-    df.columns = [f"{token}{' '*i}" for i, token in enumerate(tokens)]
+    df.columns = [f"{token}{' ' * i}" for i, token in enumerate(tokens)]
     df.index = list(topic_model.topic_labels_.values())[topic_model._outliers :]
     df = df.loc[(df.sum(axis=1) != 0), :]
```
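Only the spacing around `*` changed here, but the line is worth a note: padding each token with `i` trailing spaces keeps duplicate tokens unique as DataFrame column names while rendering identically:

```python
tokens = ["the", "cat", "the"]
cols = [f"{token}{' ' * i}" for i, token in enumerate(tokens)]
# ['the', 'cat ', 'the  ']; unique column labels even though "the" repeats
```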

bertopic/plotting/_datamap.py

Lines changed: 9 additions & 3 deletions
```diff
@@ -1,7 +1,6 @@
 import numpy as np
 import pandas as pd
 from typing import List, Union
-from umap import UMAP
 from warnings import warn
 
 try:
@@ -122,8 +121,15 @@ def visualize_document_datamap(
 
     # Reduce input embeddings
     if reduced_embeddings is None:
-        umap_model = UMAP(n_neighbors=15, n_components=2, min_dist=0.15, metric="cosine").fit(embeddings_to_reduce)
-        embeddings_2d = umap_model.embedding_
+        try:
+            from umap import UMAP
+
+            umap_model = UMAP(n_neighbors=15, n_components=2, min_dist=0.15, metric="cosine").fit(embeddings_to_reduce)
+            embeddings_2d = umap_model.embedding_
+        except (ImportError, ModuleNotFoundError):
+            raise ModuleNotFoundError(
+                "UMAP is required if the embeddings are not yet reduced in dimensionality. Please install it using `pip install umap-learn`."
+            )
     else:
         embeddings_2d = reduced_embeddings
```

bertopic/plotting/_documents.py

Lines changed: 9 additions & 3 deletions
```diff
@@ -2,7 +2,6 @@
 import pandas as pd
 import plotly.graph_objects as go
 
-from umap import UMAP
 from typing import List, Union
 
 
@@ -120,8 +119,15 @@ def visualize_documents(
 
     # Reduce input embeddings
     if reduced_embeddings is None:
-        umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine").fit(embeddings_to_reduce)
-        embeddings_2d = umap_model.embedding_
+        try:
+            from umap import UMAP
+
+            umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine").fit(embeddings_to_reduce)
+            embeddings_2d = umap_model.embedding_
+        except (ImportError, ModuleNotFoundError):
+            raise ModuleNotFoundError(
+                "UMAP is required if the embeddings are not yet reduced in dimensionality. Please install it using `pip install umap-learn`."
+            )
     elif sample is not None and reduced_embeddings is not None:
         embeddings_2d = reduced_embeddings[indices]
     elif sample is None and reduced_embeddings is not None:
```
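With `umap-learn` absent, both document visualizers now raise a clear `ModuleNotFoundError`; the escape hatch is to precompute 2-D coordinates and pass them via `reduced_embeddings`. A sketch using PCA purely as an illustration (`embeddings` and `docs` are whatever was used to fit the model):

```python
from sklearn.decomposition import PCA

# Precomputed 2-D coordinates sidestep the UMAP requirement entirely
reduced_embeddings = PCA(n_components=2).fit_transform(embeddings)
topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings)
```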

bertopic/plotting/_heatmap.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -77,7 +77,7 @@ def visualize_heatmap(
     sorted_topics = topics
     if n_clusters:
         if n_clusters >= len(set(topics)):
-            raise ValueError("Make sure to set `n_clusters` lower than " "the total number of unique topics.")
+            raise ValueError("Make sure to set `n_clusters` lower than the total number of unique topics.")
 
     distance_matrix = cosine_similarity(embeddings[topics])
     Z = linkage(distance_matrix, "ward")
```
