Skip to content

Commit ed94ce0

Browse files
committed
simpler type hints when plotly is not installed
1 parent 8e5a7f2 commit ed94ce0

File tree

4 files changed

+40
-55
lines changed

4 files changed

+40
-55
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ test:
44
test-no-plotly:
55
uv sync --extra test
66
uv pip uninstall plotly
7-
pytest tests/test_other.py -k plotly
7+
pytest tests/test_other.py -k plotly --pdb
88
uv sync --extra test
99
pytest tests/test_other.py -k plotly
1010

bertopic/_bertopic.py

Lines changed: 26 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
# ruff: noqa: E402
22
import yaml
33
import warnings
4+
import importlib
45

56
warnings.filterwarnings("ignore", category=FutureWarning)
67
warnings.filterwarnings("ignore", category=UserWarning)
@@ -26,6 +27,7 @@
2627
from collections import defaultdict, Counter
2728
from scipy.sparse import csr_matrix
2829
from scipy.cluster import hierarchy as sch
30+
from importlib.util import find_spec
2931

3032
# Typing
3133
import sys
@@ -34,7 +36,19 @@
3436
from typing import Literal
3537
else:
3638
from typing_extensions import Literal
37-
from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable
39+
from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable, TYPE_CHECKING
40+
41+
# Plotting
42+
if find_spec("plotly") is None:
43+
from bertopic._utils import MockPlotlyModule
44+
plotting = MockPlotlyModule()
45+
46+
else:
47+
from bertopic import plotting
48+
if TYPE_CHECKING:
49+
import plotly.graph_objs as go
50+
import matplotlib.figure as fig
51+
3852

3953
# Models
4054
try:
@@ -72,23 +86,9 @@
7286
)
7387
import bertopic._save_utils as save_utils
7488

75-
7689
logger = MyLogger()
7790
logger.configure("WARNING")
7891

79-
try:
80-
from bertopic import plotting
81-
import plotly.graph_objects as go
82-
83-
except ModuleNotFoundError as e:
84-
if "No module named 'plotly'" in str(e):
85-
logger.warning("Plotly is not installed. Please install it to use the plotting functions.")
86-
from bertopic._utils import mock_plotly_go as go, MockPlotting
87-
88-
plotting = MockPlotting(logger)
89-
else:
90-
raise ModuleNotFoundError(e)
91-
9292

9393
class BERTopic:
9494
"""BERTopic is a topic modeling technique that leverages BERT embeddings and
@@ -2415,7 +2415,7 @@ def visualize_topics(
24152415
title: str = "<b>Intertopic Distance Map</b>",
24162416
width: int = 650,
24172417
height: int = 650,
2418-
) -> go.Figure:
2418+
) -> "go.Figure":
24192419
"""Visualize topics, their sizes, and their corresponding words.
24202420
24212421
This visualization is highly inspired by LDAvis, a great visualization
@@ -2473,7 +2473,7 @@ def visualize_documents(
24732473
title: str = "<b>Documents and Topics</b>",
24742474
width: int = 1200,
24752475
height: int = 750,
2476-
) -> go.Figure:
2476+
) -> "go.Figure":
24772477
"""Visualize documents and their topics in 2D.
24782478
24792479
Arguments:
@@ -2575,7 +2575,7 @@ def visualize_document_datamap(
25752575
topic_prefix: bool = False,
25762576
datamap_kwds: dict = {},
25772577
int_datamap_kwds: dict = {},
2578-
):
2578+
) -> "fig.Figure":
25792579
"""Visualize documents and their topics in 2D as a static plot for publication using
25802580
DataMapPlot. This works best if there are between 5 and 60 topics. It is therefore best
25812581
to use a sufficiently large `min_topic_size` or set `nr_topics` when building the model.
@@ -2686,7 +2686,7 @@ def visualize_hierarchical_documents(
26862686
title: str = "<b>Hierarchical Documents and Topics</b>",
26872687
width: int = 1200,
26882688
height: int = 750,
2689-
) -> go.Figure:
2689+
) -> "go.Figure":
26902690
"""Visualize documents and their topics in 2D at different levels of hierarchy.
26912691
26922692
Arguments:
@@ -2798,7 +2798,7 @@ def visualize_term_rank(
27982798
title: str = "<b>Term score decline per Topic</b>",
27992799
width: int = 800,
28002800
height: int = 500,
2801-
) -> go.Figure:
2801+
) -> "go.Figure":
28022802
"""Visualize the ranks of all terms across all topics.
28032803
28042804
Each topic is represented by a set of words. These words, however,
@@ -2863,7 +2863,7 @@ def visualize_topics_over_time(
28632863
title: str = "<b>Topics over Time</b>",
28642864
width: int = 1250,
28652865
height: int = 450,
2866-
) -> go.Figure:
2866+
) -> "go.Figure":
28672867
"""Visualize topics over time.
28682868
28692869
Arguments:
@@ -2919,7 +2919,7 @@ def visualize_topics_per_class(
29192919
title: str = "<b>Topics per Class</b>",
29202920
width: int = 1250,
29212921
height: int = 900,
2922-
) -> go.Figure:
2922+
) -> "go.Figure":
29232923
"""Visualize topics per class.
29242924
29252925
Arguments:
@@ -2973,7 +2973,7 @@ def visualize_distribution(
29732973
title: str = "<b>Topic Probability Distribution</b>",
29742974
width: int = 800,
29752975
height: int = 600,
2976-
) -> go.Figure:
2976+
) -> "go.Figure":
29772977
"""Visualize the distribution of topic probabilities.
29782978
29792979
Arguments:
@@ -3080,7 +3080,7 @@ def visualize_hierarchy(
30803080
linkage_function: Callable[[csr_matrix], np.ndarray] = None,
30813081
distance_function: Callable[[csr_matrix], csr_matrix] = None,
30823082
color_threshold: int = 1,
3083-
) -> go.Figure:
3083+
) -> "go.Figure":
30843084
"""Visualize a hierarchical structure of the topics.
30853085
30863086
A ward linkage function is used to perform the
@@ -3176,7 +3176,7 @@ def visualize_heatmap(
31763176
title: str = "<b>Similarity Matrix</b>",
31773177
width: int = 800,
31783178
height: int = 800,
3179-
) -> go.Figure:
3179+
) -> "go.Figure":
31803180
"""Visualize a heatmap of the topic's similarity matrix.
31813181
31823182
Based on the cosine similarity matrix between c-TF-IDFs or semantic embeddings of the topics,
@@ -3236,7 +3236,7 @@ def visualize_barchart(
32363236
width: int = 250,
32373237
height: int = 250,
32383238
autoscale: bool = False,
3239-
) -> go.Figure:
3239+
) -> "go.Figure":
32403240
"""Visualize a barchart of selected topics.
32413241
32423242
Arguments:

bertopic/_utils.py

Lines changed: 5 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Iterable
55
from scipy.sparse import csr_matrix
66
from scipy.spatial.distance import squareform
7-
from typing import Optional, Union, Tuple
7+
from typing import Optional, Union, Tuple, Any
88

99

1010
class MyLogger:
@@ -228,26 +228,11 @@ def to_ndarray(array: Union[np.ndarray, csr_matrix]) -> np.ndarray:
228228
return to_ndarray(repr_) if output_ndarray else repr_, ctfidf_used
229229

230230

231-
# Visualization mocks in case plotly is not installed
232-
class MockPlotting:
233-
"""Mock plotting module when plotly is not installed."""
231+
class MockPlotlyModule:
232+
"""Mock module that raises an error when plotly functions are called."""
234233

235-
def __init__(self, logger: MyLogger):
236-
self.logger = logger
237-
238-
def __getattr__(self, name):
234+
def __getattr__(self, name: str) -> Any:
239235
def mock_function(*args, **kwargs):
240-
self.logger.warning(f"Plotly is not installed. Cannot use {name} visualization function.")
241-
return MockFigure()
236+
raise ImportError(f"Plotly is required to use '{name}'. " "Install it with uv pip install plotly")
242237

243238
return mock_function
244-
245-
246-
class MockFigure:
247-
"""Mock class for plotly.graph_objects.Figure when plotly is not installed."""
248-
249-
def __init__(self, *args, **kwargs):
250-
pass
251-
252-
253-
mock_plotly_go = type("MockPlotly", (), {"Figure": MockFigure})()

tests/test_other.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,9 @@
22
from bertopic.dimensionality import BaseDimensionalityReduction
33

44
try:
5-
import plotly.graph_objects as go
6-
7-
figure_type = go.Figure
5+
from plotly.graph_objects import Figure
86
except ImportError:
9-
from bertopic._utils import MockFigure
10-
11-
figure_type = MockFigure
7+
Figure = None
128

139

1410
def test_load_save_model():
@@ -41,5 +37,9 @@ def test_no_plotly():
4137
umap_model=BaseDimensionalityReduction(),
4238
)
4339
model.fit(["hello", "hi", "goodbye", "goodbye", "whats up"] * 10)
44-
out = model.visualize_topics()
45-
assert isinstance(out, figure_type)
40+
41+
try:
42+
out = model.visualize_topics()
43+
assert isinstance(out, Figure) if Figure else False
44+
except ImportError as e:
45+
assert "Plotly is required to use" in str(e)

0 commit comments

Comments
 (0)