|
1 | 1 | # ruff: noqa: E402 |
2 | 2 | import yaml |
3 | 3 | import warnings |
| 4 | +import importlib |
4 | 5 |
|
5 | 6 | warnings.filterwarnings("ignore", category=FutureWarning) |
6 | 7 | warnings.filterwarnings("ignore", category=UserWarning) |
|
26 | 27 | from collections import defaultdict, Counter |
27 | 28 | from scipy.sparse import csr_matrix |
28 | 29 | from scipy.cluster import hierarchy as sch |
| 30 | +from importlib.util import find_spec |
29 | 31 |
|
30 | 32 | # Typing |
31 | 33 | import sys |
|
34 | 36 | from typing import Literal |
35 | 37 | else: |
36 | 38 | from typing_extensions import Literal |
37 | | -from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable |
| 39 | +from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable, TYPE_CHECKING |
| 40 | + |
| 41 | +# Plotting |
| 42 | +if find_spec("plotly") is None: |
| 43 | + from bertopic._utils import MockPlotlyModule |
| 44 | + plotting = MockPlotlyModule() |
| 45 | + |
| 46 | +else: |
| 47 | + from bertopic import plotting |
| 48 | + if TYPE_CHECKING: |
| 49 | + import plotly.graph_objs as go |
| 50 | + import matplotlib.figure as fig |
| 51 | + |
38 | 52 |
|
39 | 53 | # Models |
40 | 54 | try: |
|
72 | 86 | ) |
73 | 87 | import bertopic._save_utils as save_utils |
74 | 88 |
|
75 | | - |
76 | 89 | logger = MyLogger() |
77 | 90 | logger.configure("WARNING") |
78 | 91 |
|
79 | | -try: |
80 | | - from bertopic import plotting |
81 | | - import plotly.graph_objects as go |
82 | | - |
83 | | -except ModuleNotFoundError as e: |
84 | | - if "No module named 'plotly'" in str(e): |
85 | | - logger.warning("Plotly is not installed. Please install it to use the plotting functions.") |
86 | | - from bertopic._utils import mock_plotly_go as go, MockPlotting |
87 | | - |
88 | | - plotting = MockPlotting(logger) |
89 | | - else: |
90 | | - raise ModuleNotFoundError(e) |
91 | | - |
92 | 92 |
|
93 | 93 | class BERTopic: |
94 | 94 | """BERTopic is a topic modeling technique that leverages BERT embeddings and |
@@ -2415,7 +2415,7 @@ def visualize_topics( |
2415 | 2415 | title: str = "<b>Intertopic Distance Map</b>", |
2416 | 2416 | width: int = 650, |
2417 | 2417 | height: int = 650, |
2418 | | - ) -> go.Figure: |
| 2418 | + ) -> "go.Figure": |
2419 | 2419 | """Visualize topics, their sizes, and their corresponding words. |
2420 | 2420 |
|
2421 | 2421 | This visualization is highly inspired by LDAvis, a great visualization |
@@ -2473,7 +2473,7 @@ def visualize_documents( |
2473 | 2473 | title: str = "<b>Documents and Topics</b>", |
2474 | 2474 | width: int = 1200, |
2475 | 2475 | height: int = 750, |
2476 | | - ) -> go.Figure: |
| 2476 | + ) -> "go.Figure": |
2477 | 2477 | """Visualize documents and their topics in 2D. |
2478 | 2478 |
|
2479 | 2479 | Arguments: |
@@ -2575,7 +2575,7 @@ def visualize_document_datamap( |
2575 | 2575 | topic_prefix: bool = False, |
2576 | 2576 | datamap_kwds: dict = {}, |
2577 | 2577 | int_datamap_kwds: dict = {}, |
2578 | | - ): |
| 2578 | + ) -> "fig.Figure": |
2579 | 2579 | """Visualize documents and their topics in 2D as a static plot for publication using |
2580 | 2580 | DataMapPlot. This works best if there are between 5 and 60 topics. It is therefore best |
2581 | 2581 | to use a sufficiently large `min_topic_size` or set `nr_topics` when building the model. |
@@ -2686,7 +2686,7 @@ def visualize_hierarchical_documents( |
2686 | 2686 | title: str = "<b>Hierarchical Documents and Topics</b>", |
2687 | 2687 | width: int = 1200, |
2688 | 2688 | height: int = 750, |
2689 | | - ) -> go.Figure: |
| 2689 | + ) -> "go.Figure": |
2690 | 2690 | """Visualize documents and their topics in 2D at different levels of hierarchy. |
2691 | 2691 |
|
2692 | 2692 | Arguments: |
@@ -2798,7 +2798,7 @@ def visualize_term_rank( |
2798 | 2798 | title: str = "<b>Term score decline per Topic</b>", |
2799 | 2799 | width: int = 800, |
2800 | 2800 | height: int = 500, |
2801 | | - ) -> go.Figure: |
| 2801 | + ) -> "go.Figure": |
2802 | 2802 | """Visualize the ranks of all terms across all topics. |
2803 | 2803 |
|
2804 | 2804 | Each topic is represented by a set of words. These words, however, |
@@ -2863,7 +2863,7 @@ def visualize_topics_over_time( |
2863 | 2863 | title: str = "<b>Topics over Time</b>", |
2864 | 2864 | width: int = 1250, |
2865 | 2865 | height: int = 450, |
2866 | | - ) -> go.Figure: |
| 2866 | + ) -> "go.Figure": |
2867 | 2867 | """Visualize topics over time. |
2868 | 2868 |
|
2869 | 2869 | Arguments: |
@@ -2919,7 +2919,7 @@ def visualize_topics_per_class( |
2919 | 2919 | title: str = "<b>Topics per Class</b>", |
2920 | 2920 | width: int = 1250, |
2921 | 2921 | height: int = 900, |
2922 | | - ) -> go.Figure: |
| 2922 | + ) -> "go.Figure": |
2923 | 2923 | """Visualize topics per class. |
2924 | 2924 |
|
2925 | 2925 | Arguments: |
@@ -2973,7 +2973,7 @@ def visualize_distribution( |
2973 | 2973 | title: str = "<b>Topic Probability Distribution</b>", |
2974 | 2974 | width: int = 800, |
2975 | 2975 | height: int = 600, |
2976 | | - ) -> go.Figure: |
| 2976 | + ) -> "go.Figure": |
2977 | 2977 | """Visualize the distribution of topic probabilities. |
2978 | 2978 |
|
2979 | 2979 | Arguments: |
@@ -3080,7 +3080,7 @@ def visualize_hierarchy( |
3080 | 3080 | linkage_function: Callable[[csr_matrix], np.ndarray] = None, |
3081 | 3081 | distance_function: Callable[[csr_matrix], csr_matrix] = None, |
3082 | 3082 | color_threshold: int = 1, |
3083 | | - ) -> go.Figure: |
| 3083 | + ) -> "go.Figure": |
3084 | 3084 | """Visualize a hierarchical structure of the topics. |
3085 | 3085 |
|
3086 | 3086 | A ward linkage function is used to perform the |
@@ -3176,7 +3176,7 @@ def visualize_heatmap( |
3176 | 3176 | title: str = "<b>Similarity Matrix</b>", |
3177 | 3177 | width: int = 800, |
3178 | 3178 | height: int = 800, |
3179 | | - ) -> go.Figure: |
| 3179 | + ) -> "go.Figure": |
3180 | 3180 | """Visualize a heatmap of the topic's similarity matrix. |
3181 | 3181 |
|
3182 | 3182 | Based on the cosine similarity matrix between c-TF-IDFs or semantic embeddings of the topics, |
@@ -3236,7 +3236,7 @@ def visualize_barchart( |
3236 | 3236 | width: int = 250, |
3237 | 3237 | height: int = 250, |
3238 | 3238 | autoscale: bool = False, |
3239 | | - ) -> go.Figure: |
| 3239 | + ) -> "go.Figure": |
3240 | 3240 | """Visualize a barchart of selected topics. |
3241 | 3241 |
|
3242 | 3242 | Arguments: |
|
0 commit comments