diff --git a/Makefile b/Makefile index bc7f2ba1..bc9f0cd8 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,13 @@ test: pytest +test-no-plotly: + uv sync --extra test + uv pip uninstall plotly + pytest tests/test_other.py -k plotly --pdb + uv sync --extra test + pytest tests/test_other.py -k plotly + coverage: pytest --cov diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index a59c0798..1ac66a13 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -26,6 +26,7 @@ from collections import defaultdict, Counter from scipy.sparse import csr_matrix from scipy.cluster import hierarchy as sch +from importlib.util import find_spec # Typing import sys @@ -34,7 +35,21 @@ from typing import Literal else: from typing_extensions import Literal -from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable +from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable, TYPE_CHECKING + +# Plotting +if find_spec("plotly") is None: + from bertopic._utils import MockPlotlyModule + + plotting = MockPlotlyModule() + +else: + from bertopic import plotting + + if TYPE_CHECKING: + import plotly.graph_objs as go + import matplotlib.figure as fig + # Models try: @@ -53,7 +68,6 @@ from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer # BERTopic -from bertopic import plotting from bertopic.cluster import BaseCluster from bertopic.backend import BaseEmbedder from bertopic.representation._mmr import mmr @@ -73,9 +87,6 @@ ) import bertopic._save_utils as save_utils -# Visualization -import plotly.graph_objects as go - logger = MyLogger() logger.configure("WARNING") @@ -2405,7 +2416,7 @@ def visualize_topics( title: str = "Intertopic Distance Map", width: int = 650, height: int = 650, - ) -> go.Figure: + ) -> "go.Figure": """Visualize topics, their sizes, and their corresponding words. This visualization is highly inspired by LDAvis, a great visualization @@ -2463,7 +2474,7 @@ def visualize_documents( title: str = "Documents and Topics", width: int = 1200, height: int = 750, - ) -> go.Figure: + ) -> "go.Figure": """Visualize documents and their topics in 2D. Arguments: @@ -2565,7 +2576,7 @@ def visualize_document_datamap( topic_prefix: bool = False, datamap_kwds: dict = {}, int_datamap_kwds: dict = {}, - ): + ) -> "fig.Figure": """Visualize documents and their topics in 2D as a static plot for publication using DataMapPlot. This works best if there are between 5 and 60 topics. It is therefore best to use a sufficiently large `min_topic_size` or set `nr_topics` when building the model. @@ -2676,7 +2687,7 @@ def visualize_hierarchical_documents( title: str = "Hierarchical Documents and Topics", width: int = 1200, height: int = 750, - ) -> go.Figure: + ) -> "go.Figure": """Visualize documents and their topics in 2D at different levels of hierarchy. Arguments: @@ -2788,7 +2799,7 @@ def visualize_term_rank( title: str = "Term score decline per Topic", width: int = 800, height: int = 500, - ) -> go.Figure: + ) -> "go.Figure": """Visualize the ranks of all terms across all topics. Each topic is represented by a set of words. These words, however, @@ -2853,7 +2864,7 @@ def visualize_topics_over_time( title: str = "Topics over Time", width: int = 1250, height: int = 450, - ) -> go.Figure: + ) -> "go.Figure": """Visualize topics over time. Arguments: @@ -2909,7 +2920,7 @@ def visualize_topics_per_class( title: str = "Topics per Class", width: int = 1250, height: int = 900, - ) -> go.Figure: + ) -> "go.Figure": """Visualize topics per class. Arguments: @@ -2963,7 +2974,7 @@ def visualize_distribution( title: str = "Topic Probability Distribution", width: int = 800, height: int = 600, - ) -> go.Figure: + ) -> "go.Figure": """Visualize the distribution of topic probabilities. Arguments: @@ -3070,7 +3081,7 @@ def visualize_hierarchy( linkage_function: Callable[[csr_matrix], np.ndarray] = None, distance_function: Callable[[csr_matrix], csr_matrix] = None, color_threshold: int = 1, - ) -> go.Figure: + ) -> "go.Figure": """Visualize a hierarchical structure of the topics. A ward linkage function is used to perform the @@ -3166,7 +3177,7 @@ def visualize_heatmap( title: str = "Similarity Matrix", width: int = 800, height: int = 800, - ) -> go.Figure: + ) -> "go.Figure": """Visualize a heatmap of the topic's similarity matrix. Based on the cosine similarity matrix between c-TF-IDFs or semantic embeddings of the topics, @@ -3226,7 +3237,7 @@ def visualize_barchart( width: int = 250, height: int = 250, autoscale: bool = False, - ) -> go.Figure: + ) -> "go.Figure": """Visualize a barchart of selected topics. Arguments: diff --git a/bertopic/_utils.py b/bertopic/_utils.py index 526178d0..035c6acb 100644 --- a/bertopic/_utils.py +++ b/bertopic/_utils.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from scipy.sparse import csr_matrix from scipy.spatial.distance import squareform -from typing import Optional, Union, Tuple +from typing import Optional, Union, Tuple, Any class MyLogger: @@ -226,3 +226,13 @@ def to_ndarray(array: Union[np.ndarray, csr_matrix]) -> np.ndarray: repr_, ctfidf_used = embeddings, False return to_ndarray(repr_) if output_ndarray else repr_, ctfidf_used + + +class MockPlotlyModule: + """Mock module that raises an error when plotly functions are called.""" + + def __getattr__(self, name: str) -> Any: + def mock_function(*args, **kwargs): + raise ImportError(f"Plotly is required to use '{name}'. Install it with uv pip install plotly") + + return mock_function diff --git a/docs/getting_started/tips_and_tricks/tips_and_tricks.md b/docs/getting_started/tips_and_tricks/tips_and_tricks.md index 115dcc5e..a6b684bf 100644 --- a/docs/getting_started/tips_and_tricks/tips_and_tricks.md +++ b/docs/getting_started/tips_and_tricks/tips_and_tricks.md @@ -196,11 +196,11 @@ embeddings = normalize(embeddings) The default embedding model in BERTopic is one of the amazing sentence-transformers models, namely `"all-MiniLM-L6-v2"`. Although this model performs well out of the box, it typically needs a GPU to transform the documents into embeddings in a reasonable time. Moreover, the installation requires `pytorch` which often results in a rather large environment, memory-wise. -Fortunately, it is possible to install BERTopic without `sentence-transformers`, `UMAP`, and/or `HDBSCAN`. This can be to reduce your docker images for inference or when you do not use `pytorch` but for instance [Model2Vec](https://github.com/MinishLab/model2vec) instead. The installation can be done as follows: +Fortunately, it is possible to install BERTopic without `sentence-transformers`, `UMAP`, `HDBSCAN` and/or `plotly`. This can be to reduce your docker images for inference or when you do not use `pytorch` but for instance [Model2Vec](https://github.com/MinishLab/model2vec) instead. The installation can be done as follows: ```bash pip install --no-deps bertopic -pip install --upgrade numpy pandas scikit-learn tqdm plotly pyyaml +pip install --upgrade numpy pandas scikit-learn tqdm pyyaml ``` This installs a bare-bones version of BERTopic. If you want to use UMAP and Model2Vec for instance, you'll need to first install them: diff --git a/tests/test_other.py b/tests/test_other.py index 309c9900..59dea1a7 100644 --- a/tests/test_other.py +++ b/tests/test_other.py @@ -1,4 +1,10 @@ from bertopic import BERTopic +from bertopic.dimensionality import BaseDimensionalityReduction + +try: + from plotly.graph_objects import Figure +except ImportError: + Figure = None def test_load_save_model(): @@ -20,3 +26,20 @@ def test_get_params(): assert params["n_gram_range"] == (1, 1) assert params["min_topic_size"] == 10 assert params["language"] == "english" + + +def test_no_plotly(): + model = BERTopic( + language="Dutch", + embedding_model=None, + min_topic_size=2, + top_n_words=1, + umap_model=BaseDimensionalityReduction(), + ) + model.fit(["hello", "hi", "goodbye", "goodbye", "whats up"] * 10) + + try: + out = model.visualize_topics() + assert isinstance(out, Figure) if Figure else False + except ImportError as e: + assert "Plotly is required to use" in str(e)