diff --git a/Makefile b/Makefile
index bc7f2ba1..bc9f0cd8 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,13 @@
test:
pytest
+# Run the plotly tests once without plotly installed and once with it restored.
+# The first run never drops into a debugger, and plotly is re-installed even when
+# that run fails, so a failure cannot leave the environment without plotly.
+test-no-plotly:
+	uv sync --extra test
+	uv pip uninstall plotly
+	pytest tests/test_other.py -k plotly; status=$$?; uv sync --extra test && exit $$status
+	pytest tests/test_other.py -k plotly
+
coverage:
pytest --cov
diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py
index a59c0798..1ac66a13 100644
--- a/bertopic/_bertopic.py
+++ b/bertopic/_bertopic.py
@@ -26,6 +26,7 @@
from collections import defaultdict, Counter
from scipy.sparse import csr_matrix
from scipy.cluster import hierarchy as sch
+from importlib.util import find_spec
# Typing
import sys
@@ -34,7 +35,21 @@
from typing import Literal
else:
from typing_extensions import Literal
-from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable
+from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable, TYPE_CHECKING
+
+# Plotting
+# plotly is optional: without it, `plotting` becomes a stand-in whose functions raise
+# an ImportError only when they are called.
+if find_spec("plotly") is None:
+    from bertopic._utils import MockPlotlyModule
+
+    plotting = MockPlotlyModule()
+else:
+    from bertopic import plotting
+
+# Only needed for the quoted return annotations; never imported at runtime, so this
+# does not depend on which plotting implementation was selected above.
+if TYPE_CHECKING:
+    import plotly.graph_objs as go
+    import matplotlib.figure as fig
+
# Models
try:
@@ -53,7 +68,6 @@
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
# BERTopic
-from bertopic import plotting
from bertopic.cluster import BaseCluster
from bertopic.backend import BaseEmbedder
from bertopic.representation._mmr import mmr
@@ -73,9 +87,6 @@
)
import bertopic._save_utils as save_utils
-# Visualization
-import plotly.graph_objects as go
-
logger = MyLogger()
logger.configure("WARNING")
@@ -2405,7 +2416,7 @@ def visualize_topics(
title: str = "Intertopic Distance Map",
width: int = 650,
height: int = 650,
- ) -> go.Figure:
+ ) -> "go.Figure":
"""Visualize topics, their sizes, and their corresponding words.
This visualization is highly inspired by LDAvis, a great visualization
@@ -2463,7 +2474,7 @@ def visualize_documents(
title: str = "Documents and Topics",
width: int = 1200,
height: int = 750,
- ) -> go.Figure:
+ ) -> "go.Figure":
"""Visualize documents and their topics in 2D.
Arguments:
@@ -2565,7 +2576,7 @@ def visualize_document_datamap(
topic_prefix: bool = False,
datamap_kwds: dict = {},
int_datamap_kwds: dict = {},
- ):
+ ) -> "fig.Figure":
"""Visualize documents and their topics in 2D as a static plot for publication using
DataMapPlot. This works best if there are between 5 and 60 topics. It is therefore best
to use a sufficiently large `min_topic_size` or set `nr_topics` when building the model.
@@ -2676,7 +2687,7 @@ def visualize_hierarchical_documents(
title: str = "Hierarchical Documents and Topics",
width: int = 1200,
height: int = 750,
- ) -> go.Figure:
+ ) -> "go.Figure":
"""Visualize documents and their topics in 2D at different levels of hierarchy.
Arguments:
@@ -2788,7 +2799,7 @@ def visualize_term_rank(
title: str = "Term score decline per Topic",
width: int = 800,
height: int = 500,
- ) -> go.Figure:
+ ) -> "go.Figure":
"""Visualize the ranks of all terms across all topics.
Each topic is represented by a set of words. These words, however,
@@ -2853,7 +2864,7 @@ def visualize_topics_over_time(
title: str = "Topics over Time",
width: int = 1250,
height: int = 450,
- ) -> go.Figure:
+ ) -> "go.Figure":
"""Visualize topics over time.
Arguments:
@@ -2909,7 +2920,7 @@ def visualize_topics_per_class(
title: str = "Topics per Class",
width: int = 1250,
height: int = 900,
- ) -> go.Figure:
+ ) -> "go.Figure":
"""Visualize topics per class.
Arguments:
@@ -2963,7 +2974,7 @@ def visualize_distribution(
title: str = "Topic Probability Distribution",
width: int = 800,
height: int = 600,
- ) -> go.Figure:
+ ) -> "go.Figure":
"""Visualize the distribution of topic probabilities.
Arguments:
@@ -3070,7 +3081,7 @@ def visualize_hierarchy(
linkage_function: Callable[[csr_matrix], np.ndarray] = None,
distance_function: Callable[[csr_matrix], csr_matrix] = None,
color_threshold: int = 1,
- ) -> go.Figure:
+ ) -> "go.Figure":
"""Visualize a hierarchical structure of the topics.
A ward linkage function is used to perform the
@@ -3166,7 +3177,7 @@ def visualize_heatmap(
title: str = "Similarity Matrix",
width: int = 800,
height: int = 800,
- ) -> go.Figure:
+ ) -> "go.Figure":
"""Visualize a heatmap of the topic's similarity matrix.
Based on the cosine similarity matrix between c-TF-IDFs or semantic embeddings of the topics,
@@ -3226,7 +3237,7 @@ def visualize_barchart(
width: int = 250,
height: int = 250,
autoscale: bool = False,
- ) -> go.Figure:
+ ) -> "go.Figure":
"""Visualize a barchart of selected topics.
Arguments:
diff --git a/bertopic/_utils.py b/bertopic/_utils.py
index 526178d0..035c6acb 100644
--- a/bertopic/_utils.py
+++ b/bertopic/_utils.py
@@ -4,7 +4,7 @@
from collections.abc import Iterable
from scipy.sparse import csr_matrix
from scipy.spatial.distance import squareform
-from typing import Optional, Union, Tuple
+from typing import Optional, Union, Tuple, Any
class MyLogger:
@@ -226,3 +226,13 @@ def to_ndarray(array: Union[np.ndarray, csr_matrix]) -> np.ndarray:
repr_, ctfidf_used = embeddings, False
return to_ndarray(repr_) if output_ndarray else repr_, ctfidf_used
+
+
+class MockPlotlyModule:
+    """Stand-in for the plotting module when plotly is not installed.
+
+    Any regular attribute looked up on an instance resolves to a callable that raises
+    an `ImportError` when it is called, so the missing dependency is only reported
+    once a plotting function is actually used.
+    """
+
+    def __getattr__(self, name: str) -> Any:
+        """Return a callable for `name` that raises `ImportError` when invoked.
+
+        Raises:
+            AttributeError: if `name` is a dunder name, which is never mocked.
+        """
+        # Dunder lookups come from protocols such as copy and inspect, which expect
+        # an AttributeError for hooks that are not implemented.
+        if name.startswith("__") and name.endswith("__"):
+            raise AttributeError(name)
+
+        def mock_function(*args, **kwargs):
+            raise ImportError(f"Plotly is required to use '{name}'. Install it with `pip install plotly`.")
+
+        return mock_function
diff --git a/docs/getting_started/tips_and_tricks/tips_and_tricks.md b/docs/getting_started/tips_and_tricks/tips_and_tricks.md
index 115dcc5e..a6b684bf 100644
--- a/docs/getting_started/tips_and_tricks/tips_and_tricks.md
+++ b/docs/getting_started/tips_and_tricks/tips_and_tricks.md
@@ -196,11 +196,11 @@ embeddings = normalize(embeddings)
The default embedding model in BERTopic is one of the amazing sentence-transformers models, namely `"all-MiniLM-L6-v2"`. Although this model performs well out of the box, it typically needs a GPU to transform the documents into embeddings in a reasonable time. Moreover, the installation requires `pytorch` which often results in a rather large environment, memory-wise.
-Fortunately, it is possible to install BERTopic without `sentence-transformers`, `UMAP`, and/or `HDBSCAN`. This can be to reduce your docker images for inference or when you do not use `pytorch` but for instance [Model2Vec](https://github.com/MinishLab/model2vec) instead. The installation can be done as follows:
+Fortunately, it is possible to install BERTopic without `sentence-transformers`, `UMAP`, `HDBSCAN`, and/or `plotly`. This can be useful to reduce the size of your Docker images for inference, or when you do not use `pytorch` but, for instance, [Model2Vec](https://github.com/MinishLab/model2vec) instead. The installation can be done as follows:
```bash
pip install --no-deps bertopic
-pip install --upgrade numpy pandas scikit-learn tqdm plotly pyyaml
+pip install --upgrade numpy pandas scikit-learn tqdm pyyaml
```
This installs a bare-bones version of BERTopic. If you want to use UMAP and Model2Vec for instance, you'll need to first install them:
diff --git a/tests/test_other.py b/tests/test_other.py
index 309c9900..59dea1a7 100644
--- a/tests/test_other.py
+++ b/tests/test_other.py
@@ -1,4 +1,10 @@
from bertopic import BERTopic
+from bertopic.dimensionality import BaseDimensionalityReduction
+
+try:
+ from plotly.graph_objects import Figure
+except ImportError:
+ Figure = None
def test_load_save_model():
@@ -20,3 +26,20 @@ def test_get_params():
assert params["n_gram_range"] == (1, 1)
assert params["min_topic_size"] == 10
assert params["language"] == "english"
+
+
+def test_no_plotly():
+    """Check `visualize_topics` both with and without plotly installed.
+
+    With plotly importable the call must return a plotly `Figure`; without it the
+    call must raise the `ImportError` explaining that plotly is required.
+    """
+    model = BERTopic(
+        language="Dutch",
+        embedding_model=None,
+        min_topic_size=2,
+        top_n_words=1,
+        umap_model=BaseDimensionalityReduction(),
+    )
+    model.fit(["hello", "hi", "goodbye", "goodbye", "whats up"] * 10)
+
+    if Figure is not None:
+        # plotly is importable, so the real plotting code must run and return a figure.
+        assert isinstance(model.visualize_topics(), Figure)
+        return
+
+    # plotly is missing: the call has to fail with the explanatory ImportError.
+    try:
+        model.visualize_topics()
+    except ImportError as e:
+        assert "Plotly is required to use" in str(e)
+    else:
+        raise AssertionError("visualize_topics() should raise ImportError when plotly is not installed")