Skip to content

Commit 7724551

Browse files
authored
Allow execution without plotly (#2401)
1 parent 6669201 commit 7724551

File tree

5 files changed

+70
-19
lines changed

5 files changed

+70
-19
lines changed

Makefile

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
test:
22
pytest
33

4+
# Run the plotly-related tests twice: first with plotly uninstalled, then again
# after the test environment (including plotly) has been restored.
# The environment is re-synced even when the plotly-less run fails, so a failure
# does not leave plotly uninstalled; the pytest exit status is preserved.
# No `--pdb` here: an interactive debugger must not be opened from a make target
# that may run non-interactively.
test-no-plotly:
	uv sync --extra test
	uv pip uninstall plotly
	pytest tests/test_other.py -k plotly; status=$$?; uv sync --extra test; exit $$status
	pytest tests/test_other.py -k plotly
10+
411
coverage:
512
pytest --cov
613

bertopic/_bertopic.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from collections import defaultdict, Counter
2828
from scipy.sparse import csr_matrix
2929
from scipy.cluster import hierarchy as sch
30+
from importlib.util import find_spec
3031

3132
# Typing
3233
import sys
@@ -35,7 +36,21 @@
3536
from typing import Literal
3637
else:
3738
from typing_extensions import Literal
38-
from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable
39+
from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable, TYPE_CHECKING
40+
41+
# Plotting
42+
if find_spec("plotly") is None:
43+
from bertopic._utils import MockPlotlyModule
44+
45+
plotting = MockPlotlyModule()
46+
47+
else:
48+
from bertopic import plotting
49+
50+
if TYPE_CHECKING:
51+
import plotly.graph_objs as go
52+
import matplotlib.figure as fig
53+
3954

4055
# Models
4156
try:
@@ -54,7 +69,6 @@
5469
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
5570

5671
# BERTopic
57-
from bertopic import plotting
5872
from bertopic.cluster import BaseCluster
5973
from bertopic.backend import BaseEmbedder
6074
from bertopic.representation._mmr import mmr
@@ -74,9 +88,6 @@
7488
)
7589
import bertopic._save_utils as save_utils
7690

77-
# Visualization
78-
import plotly.graph_objects as go
79-
8091
logger = MyLogger()
8192
logger.configure("WARNING")
8293

@@ -2541,7 +2552,7 @@ def visualize_topics(
25412552
title: str = "<b>Intertopic Distance Map</b>",
25422553
width: int = 650,
25432554
height: int = 650,
2544-
) -> go.Figure:
2555+
) -> "go.Figure":
25452556
"""Visualize topics, their sizes, and their corresponding words.
25462557
25472558
This visualization is highly inspired by LDAvis, a great visualization
@@ -2599,7 +2610,7 @@ def visualize_documents(
25992610
title: str = "<b>Documents and Topics</b>",
26002611
width: int = 1200,
26012612
height: int = 750,
2602-
) -> go.Figure:
2613+
) -> "go.Figure":
26032614
"""Visualize documents and their topics in 2D.
26042615
26052616
Arguments:
@@ -2701,7 +2712,7 @@ def visualize_document_datamap(
27012712
topic_prefix: bool = False,
27022713
datamap_kwds: dict = {},
27032714
int_datamap_kwds: dict = {},
2704-
):
2715+
) -> "fig.Figure":
27052716
"""Visualize documents and their topics in 2D as a static plot for publication using
27062717
DataMapPlot. This works best if there are between 5 and 60 topics. It is therefore best
27072718
to use a sufficiently large `min_topic_size` or set `nr_topics` when building the model.
@@ -2812,7 +2823,7 @@ def visualize_hierarchical_documents(
28122823
title: str = "<b>Hierarchical Documents and Topics</b>",
28132824
width: int = 1200,
28142825
height: int = 750,
2815-
) -> go.Figure:
2826+
) -> "go.Figure":
28162827
"""Visualize documents and their topics in 2D at different levels of hierarchy.
28172828
28182829
Arguments:
@@ -2924,7 +2935,7 @@ def visualize_term_rank(
29242935
title: str = "<b>Term score decline per Topic</b>",
29252936
width: int = 800,
29262937
height: int = 500,
2927-
) -> go.Figure:
2938+
) -> "go.Figure":
29282939
"""Visualize the ranks of all terms across all topics.
29292940
29302941
Each topic is represented by a set of words. These words, however,
@@ -2989,7 +3000,7 @@ def visualize_topics_over_time(
29893000
title: str = "<b>Topics over Time</b>",
29903001
width: int = 1250,
29913002
height: int = 450,
2992-
) -> go.Figure:
3003+
) -> "go.Figure":
29933004
"""Visualize topics over time.
29943005
29953006
Arguments:
@@ -3045,7 +3056,7 @@ def visualize_topics_per_class(
30453056
title: str = "<b>Topics per Class</b>",
30463057
width: int = 1250,
30473058
height: int = 900,
3048-
) -> go.Figure:
3059+
) -> "go.Figure":
30493060
"""Visualize topics per class.
30503061
30513062
Arguments:
@@ -3099,7 +3110,7 @@ def visualize_distribution(
30993110
title: str = "<b>Topic Probability Distribution</b>",
31003111
width: int = 800,
31013112
height: int = 600,
3102-
) -> go.Figure:
3113+
) -> "go.Figure":
31033114
"""Visualize the distribution of topic probabilities.
31043115
31053116
Arguments:
@@ -3206,7 +3217,7 @@ def visualize_hierarchy(
32063217
linkage_function: Callable[[csr_matrix], np.ndarray] = None,
32073218
distance_function: Callable[[csr_matrix], csr_matrix] = None,
32083219
color_threshold: int = 1,
3209-
) -> go.Figure:
3220+
) -> "go.Figure":
32103221
"""Visualize a hierarchical structure of the topics.
32113222
32123223
A ward linkage function is used to perform the
@@ -3302,7 +3313,7 @@ def visualize_heatmap(
33023313
title: str = "<b>Similarity Matrix</b>",
33033314
width: int = 800,
33043315
height: int = 800,
3305-
) -> go.Figure:
3316+
) -> "go.Figure":
33063317
"""Visualize a heatmap of the topic's similarity matrix.
33073318
33083319
Based on the cosine similarity matrix between c-TF-IDFs or semantic embeddings of the topics,
@@ -3362,7 +3373,7 @@ def visualize_barchart(
33623373
width: int = 250,
33633374
height: int = 250,
33643375
autoscale: bool = False,
3365-
) -> go.Figure:
3376+
) -> "go.Figure":
33663377
"""Visualize a barchart of selected topics.
33673378
33683379
Arguments:

bertopic/_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Iterable
55
from scipy.sparse import csr_matrix
66
from scipy.spatial.distance import squareform
7-
from typing import Optional, Union, Tuple
7+
from typing import Optional, Union, Tuple, Any
88

99

1010
class MyLogger:
@@ -226,3 +226,13 @@ def to_ndarray(array: Union[np.ndarray, csr_matrix]) -> np.ndarray:
226226
repr_, ctfidf_used = embeddings, False
227227

228228
return to_ndarray(repr_) if output_ndarray else repr_, ctfidf_used
229+
230+
231+
class MockPlotlyModule:
    """Placeholder used in place of the plotting module when plotly is not installed.

    Looking up any regular attribute on an instance yields a callable that raises
    an `ImportError` naming that attribute once it is called, so the missing
    dependency is only reported when a plotting function is actually used.
    """

    def __getattr__(self, name: str) -> Any:
        """Return a placeholder callable for the attribute `name`.

        Arguments:
            name: Name of the attribute that was looked up on the mock.

        Returns:
            A function accepting any positional and keyword arguments that
            raises `ImportError` when called.

        Raises:
            AttributeError: If `name` is a dunder name (e.g. `__deepcopy__`).
        """
        # `__getattr__` only runs when normal lookup fails. Protocol probes such as
        # `__deepcopy__`, `__getstate__` or `__wrapped__` must be reported as missing;
        # otherwise `copy`, `pickle` and `inspect` would call the placeholder and
        # fail with a misleading "Plotly is required" error.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

        def mock_function(*args, **kwargs):
            raise ImportError(f"Plotly is required to use '{name}'. Install it with uv pip install plotly")

        return mock_function

docs/getting_started/tips_and_tricks/tips_and_tricks.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,11 @@ embeddings = normalize(embeddings)
196196

197197
The default embedding model in BERTopic is one of the amazing sentence-transformers models, namely `"all-MiniLM-L6-v2"`. Although this model performs well out of the box, it typically needs a GPU to transform the documents into embeddings in a reasonable time. Moreover, the installation requires `pytorch` which often results in a rather large environment, memory-wise.
198198

199-
Fortunately, it is possible to install BERTopic without `sentence-transformers`, `UMAP`, and/or `HDBSCAN`. This can be to reduce your docker images for inference or when you do not use `pytorch` but for instance [Model2Vec](https://github.com/MinishLab/model2vec) instead. The installation can be done as follows:
199+
Fortunately, it is possible to install BERTopic without `sentence-transformers`, `UMAP`, `HDBSCAN`, and/or `plotly`. This can be useful for reducing the size of your Docker images for inference, or when you do not use `pytorch` but, for instance, [Model2Vec](https://github.com/MinishLab/model2vec) instead. The installation can be done as follows:
200200

201201
```bash
202202
pip install --no-deps bertopic
203-
pip install --upgrade numpy pandas scikit-learn tqdm plotly pyyaml
203+
pip install --upgrade numpy pandas scikit-learn tqdm pyyaml
204204
```
205205

206206
This installs a bare-bones version of BERTopic. If you want to use UMAP and Model2Vec for instance, you'll need to first install them:

tests/test_other.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
from bertopic import BERTopic
2+
from bertopic.dimensionality import BaseDimensionalityReduction
3+
4+
try:
5+
from plotly.graph_objects import Figure
6+
except ImportError:
7+
Figure = None
28

39

410
def test_load_save_model():
@@ -20,3 +26,20 @@ def test_get_params():
2026
assert params["n_gram_range"] == (1, 1)
2127
assert params["min_topic_size"] == 10
2228
assert params["language"] == "english"
29+
30+
31+
def test_no_plotly():
    """Check `visualize_topics` both with and without plotly installed.

    With plotly available the call must return a plotly `Figure`; without it the
    call must raise an `ImportError` carrying the "Plotly is required" message.
    """
    documents = ["hello", "hi", "goodbye", "goodbye", "whats up"] * 10
    topic_model = BERTopic(
        language="Dutch",
        embedding_model=None,
        min_topic_size=2,
        top_n_words=1,
        umap_model=BaseDimensionalityReduction(),
    )
    topic_model.fit(documents)

    try:
        figure = topic_model.visualize_topics()
    except ImportError as error:
        # Plotly is missing: the mock plotting module must explain what to install.
        assert "Plotly is required to use" in str(error)
    else:
        # No ImportError: this is only valid when plotly (and thus Figure) is importable.
        assert Figure and isinstance(figure, Figure)

0 commit comments

Comments
 (0)