Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
test:
pytest

test-no-plotly:
uv sync --extra test
uv pip uninstall plotly
pytest tests/test_other.py -k plotly --pdb
uv sync --extra test
pytest tests/test_other.py -k plotly

coverage:
pytest --cov

Expand Down
43 changes: 27 additions & 16 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from collections import defaultdict, Counter
from scipy.sparse import csr_matrix
from scipy.cluster import hierarchy as sch
from importlib.util import find_spec

# Typing
import sys
Expand All @@ -34,7 +35,21 @@
from typing import Literal
else:
from typing_extensions import Literal
from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable
from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable, TYPE_CHECKING

# Plotting
if find_spec("plotly") is None:
from bertopic._utils import MockPlotlyModule

plotting = MockPlotlyModule()

else:
from bertopic import plotting

if TYPE_CHECKING:
import plotly.graph_objs as go
import matplotlib.figure as fig


# Models
try:
Expand All @@ -53,7 +68,6 @@
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer

# BERTopic
from bertopic import plotting
from bertopic.cluster import BaseCluster
from bertopic.backend import BaseEmbedder
from bertopic.representation._mmr import mmr
Expand All @@ -73,9 +87,6 @@
)
import bertopic._save_utils as save_utils

# Visualization
import plotly.graph_objects as go

logger = MyLogger()
logger.configure("WARNING")

Expand Down Expand Up @@ -2405,7 +2416,7 @@ def visualize_topics(
title: str = "<b>Intertopic Distance Map</b>",
width: int = 650,
height: int = 650,
) -> go.Figure:
) -> "go.Figure":
"""Visualize topics, their sizes, and their corresponding words.

This visualization is highly inspired by LDAvis, a great visualization
Expand Down Expand Up @@ -2463,7 +2474,7 @@ def visualize_documents(
title: str = "<b>Documents and Topics</b>",
width: int = 1200,
height: int = 750,
) -> go.Figure:
) -> "go.Figure":
"""Visualize documents and their topics in 2D.

Arguments:
Expand Down Expand Up @@ -2565,7 +2576,7 @@ def visualize_document_datamap(
topic_prefix: bool = False,
datamap_kwds: dict = {},
int_datamap_kwds: dict = {},
):
) -> "fig.Figure":
"""Visualize documents and their topics in 2D as a static plot for publication using
DataMapPlot. This works best if there are between 5 and 60 topics. It is therefore best
to use a sufficiently large `min_topic_size` or set `nr_topics` when building the model.
Expand Down Expand Up @@ -2676,7 +2687,7 @@ def visualize_hierarchical_documents(
title: str = "<b>Hierarchical Documents and Topics</b>",
width: int = 1200,
height: int = 750,
) -> go.Figure:
) -> "go.Figure":
"""Visualize documents and their topics in 2D at different levels of hierarchy.

Arguments:
Expand Down Expand Up @@ -2788,7 +2799,7 @@ def visualize_term_rank(
title: str = "<b>Term score decline per Topic</b>",
width: int = 800,
height: int = 500,
) -> go.Figure:
) -> "go.Figure":
"""Visualize the ranks of all terms across all topics.

Each topic is represented by a set of words. These words, however,
Expand Down Expand Up @@ -2853,7 +2864,7 @@ def visualize_topics_over_time(
title: str = "<b>Topics over Time</b>",
width: int = 1250,
height: int = 450,
) -> go.Figure:
) -> "go.Figure":
"""Visualize topics over time.

Arguments:
Expand Down Expand Up @@ -2909,7 +2920,7 @@ def visualize_topics_per_class(
title: str = "<b>Topics per Class</b>",
width: int = 1250,
height: int = 900,
) -> go.Figure:
) -> "go.Figure":
"""Visualize topics per class.

Arguments:
Expand Down Expand Up @@ -2963,7 +2974,7 @@ def visualize_distribution(
title: str = "<b>Topic Probability Distribution</b>",
width: int = 800,
height: int = 600,
) -> go.Figure:
) -> "go.Figure":
"""Visualize the distribution of topic probabilities.

Arguments:
Expand Down Expand Up @@ -3070,7 +3081,7 @@ def visualize_hierarchy(
linkage_function: Callable[[csr_matrix], np.ndarray] = None,
distance_function: Callable[[csr_matrix], csr_matrix] = None,
color_threshold: int = 1,
) -> go.Figure:
) -> "go.Figure":
"""Visualize a hierarchical structure of the topics.

A ward linkage function is used to perform the
Expand Down Expand Up @@ -3166,7 +3177,7 @@ def visualize_heatmap(
title: str = "<b>Similarity Matrix</b>",
width: int = 800,
height: int = 800,
) -> go.Figure:
) -> "go.Figure":
"""Visualize a heatmap of the topic's similarity matrix.

Based on the cosine similarity matrix between c-TF-IDFs or semantic embeddings of the topics,
Expand Down Expand Up @@ -3226,7 +3237,7 @@ def visualize_barchart(
width: int = 250,
height: int = 250,
autoscale: bool = False,
) -> go.Figure:
) -> "go.Figure":
"""Visualize a barchart of selected topics.

Arguments:
Expand Down
12 changes: 11 additions & 1 deletion bertopic/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Iterable
from scipy.sparse import csr_matrix
from scipy.spatial.distance import squareform
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, Any


class MyLogger:
Expand Down Expand Up @@ -226,3 +226,13 @@ def to_ndarray(array: Union[np.ndarray, csr_matrix]) -> np.ndarray:
repr_, ctfidf_used = embeddings, False

return to_ndarray(repr_) if output_ndarray else repr_, ctfidf_used


class MockPlotlyModule:
"""Mock module that raises an error when plotly functions are called."""

def __getattr__(self, name: str) -> Any:
def mock_function(*args, **kwargs):
raise ImportError(f"Plotly is required to use '{name}'. Install it with uv pip install plotly")

return mock_function
4 changes: 2 additions & 2 deletions docs/getting_started/tips_and_tricks/tips_and_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,11 @@ embeddings = normalize(embeddings)

The default embedding model in BERTopic is one of the amazing sentence-transformers models, namely `"all-MiniLM-L6-v2"`. Although this model performs well out of the box, it typically needs a GPU to transform the documents into embeddings in a reasonable time. Moreover, the installation requires `pytorch` which often results in a rather large environment, memory-wise.

Fortunately, it is possible to install BERTopic without `sentence-transformers`, `UMAP`, and/or `HDBSCAN`. This can be to reduce your docker images for inference or when you do not use `pytorch` but for instance [Model2Vec](https://github.com/MinishLab/model2vec) instead. The installation can be done as follows:
Fortunately, it is possible to install BERTopic without `sentence-transformers`, `UMAP`, `HDBSCAN` and/or `plotly`. This can be to reduce your docker images for inference or when you do not use `pytorch` but for instance [Model2Vec](https://github.com/MinishLab/model2vec) instead. The installation can be done as follows:

```bash
pip install --no-deps bertopic
pip install --upgrade numpy pandas scikit-learn tqdm plotly pyyaml
pip install --upgrade numpy pandas scikit-learn tqdm pyyaml
```

This installs a bare-bones version of BERTopic. If you want to use UMAP and Model2Vec for instance, you'll need to first install them:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_other.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from bertopic import BERTopic
from bertopic.dimensionality import BaseDimensionalityReduction

try:
from plotly.graph_objects import Figure
except ImportError:
Figure = None


def test_load_save_model():
Expand All @@ -20,3 +26,20 @@ def test_get_params():
assert params["n_gram_range"] == (1, 1)
assert params["min_topic_size"] == 10
assert params["language"] == "english"


def test_no_plotly():
model = BERTopic(
language="Dutch",
embedding_model=None,
min_topic_size=2,
top_n_words=1,
umap_model=BaseDimensionalityReduction(),
)
model.fit(["hello", "hi", "goodbye", "goodbye", "whats up"] * 10)

try:
out = model.visualize_topics()
assert isinstance(out, Figure) if Figure else False
except ImportError as e:
assert "Plotly is required to use" in str(e)
Loading