|
27 | 27 | from collections import defaultdict, Counter |
28 | 28 | from scipy.sparse import csr_matrix |
29 | 29 | from scipy.cluster import hierarchy as sch |
| 30 | +from importlib.util import find_spec |
30 | 31 |
|
31 | 32 | # Typing |
32 | 33 | import sys |
|
35 | 36 | from typing import Literal |
36 | 37 | else: |
37 | 38 | from typing_extensions import Literal |
38 | | -from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable |
| 39 | +from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable, TYPE_CHECKING |
| 40 | + |
| 41 | +# Plotting |
| 42 | +if find_spec("plotly") is None: |
| 43 | + from bertopic._utils import MockPlotlyModule |
| 44 | + |
| 45 | + plotting = MockPlotlyModule() |
| 46 | + |
| 47 | +else: |
| 48 | + from bertopic import plotting |
| 49 | + |
| 50 | + if TYPE_CHECKING: |
| 51 | + import plotly.graph_objs as go |
| 52 | + import matplotlib.figure as fig |
| 53 | + |
39 | 54 |
|
40 | 55 | # Models |
41 | 56 | try: |
|
54 | 69 | from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer |
55 | 70 |
|
56 | 71 | # BERTopic |
57 | | -from bertopic import plotting |
58 | 72 | from bertopic.cluster import BaseCluster |
59 | 73 | from bertopic.backend import BaseEmbedder |
60 | 74 | from bertopic.representation._mmr import mmr |
|
74 | 88 | ) |
75 | 89 | import bertopic._save_utils as save_utils |
76 | 90 |
|
77 | | -# Visualization |
78 | | -import plotly.graph_objects as go |
79 | | - |
80 | 91 | logger = MyLogger() |
81 | 92 | logger.configure("WARNING") |
82 | 93 |
|
@@ -2541,7 +2552,7 @@ def visualize_topics( |
2541 | 2552 | title: str = "<b>Intertopic Distance Map</b>", |
2542 | 2553 | width: int = 650, |
2543 | 2554 | height: int = 650, |
2544 | | - ) -> go.Figure: |
| 2555 | + ) -> "go.Figure": |
2545 | 2556 | """Visualize topics, their sizes, and their corresponding words. |
2546 | 2557 |
|
2547 | 2558 | This visualization is highly inspired by LDAvis, a great visualization |
@@ -2599,7 +2610,7 @@ def visualize_documents( |
2599 | 2610 | title: str = "<b>Documents and Topics</b>", |
2600 | 2611 | width: int = 1200, |
2601 | 2612 | height: int = 750, |
2602 | | - ) -> go.Figure: |
| 2613 | + ) -> "go.Figure": |
2603 | 2614 | """Visualize documents and their topics in 2D. |
2604 | 2615 |
|
2605 | 2616 | Arguments: |
@@ -2701,7 +2712,7 @@ def visualize_document_datamap( |
2701 | 2712 | topic_prefix: bool = False, |
2702 | 2713 | datamap_kwds: dict = {}, |
2703 | 2714 | int_datamap_kwds: dict = {}, |
2704 | | - ): |
| 2715 | + ) -> "fig.Figure": |
2705 | 2716 | """Visualize documents and their topics in 2D as a static plot for publication using |
2706 | 2717 | DataMapPlot. This works best if there are between 5 and 60 topics. It is therefore best |
2707 | 2718 | to use a sufficiently large `min_topic_size` or set `nr_topics` when building the model. |
@@ -2812,7 +2823,7 @@ def visualize_hierarchical_documents( |
2812 | 2823 | title: str = "<b>Hierarchical Documents and Topics</b>", |
2813 | 2824 | width: int = 1200, |
2814 | 2825 | height: int = 750, |
2815 | | - ) -> go.Figure: |
| 2826 | + ) -> "go.Figure": |
2816 | 2827 | """Visualize documents and their topics in 2D at different levels of hierarchy. |
2817 | 2828 |
|
2818 | 2829 | Arguments: |
@@ -2924,7 +2935,7 @@ def visualize_term_rank( |
2924 | 2935 | title: str = "<b>Term score decline per Topic</b>", |
2925 | 2936 | width: int = 800, |
2926 | 2937 | height: int = 500, |
2927 | | - ) -> go.Figure: |
| 2938 | + ) -> "go.Figure": |
2928 | 2939 | """Visualize the ranks of all terms across all topics. |
2929 | 2940 |
|
2930 | 2941 | Each topic is represented by a set of words. These words, however, |
@@ -2989,7 +3000,7 @@ def visualize_topics_over_time( |
2989 | 3000 | title: str = "<b>Topics over Time</b>", |
2990 | 3001 | width: int = 1250, |
2991 | 3002 | height: int = 450, |
2992 | | - ) -> go.Figure: |
| 3003 | + ) -> "go.Figure": |
2993 | 3004 | """Visualize topics over time. |
2994 | 3005 |
|
2995 | 3006 | Arguments: |
@@ -3045,7 +3056,7 @@ def visualize_topics_per_class( |
3045 | 3056 | title: str = "<b>Topics per Class</b>", |
3046 | 3057 | width: int = 1250, |
3047 | 3058 | height: int = 900, |
3048 | | - ) -> go.Figure: |
| 3059 | + ) -> "go.Figure": |
3049 | 3060 | """Visualize topics per class. |
3050 | 3061 |
|
3051 | 3062 | Arguments: |
@@ -3099,7 +3110,7 @@ def visualize_distribution( |
3099 | 3110 | title: str = "<b>Topic Probability Distribution</b>", |
3100 | 3111 | width: int = 800, |
3101 | 3112 | height: int = 600, |
3102 | | - ) -> go.Figure: |
| 3113 | + ) -> "go.Figure": |
3103 | 3114 | """Visualize the distribution of topic probabilities. |
3104 | 3115 |
|
3105 | 3116 | Arguments: |
@@ -3206,7 +3217,7 @@ def visualize_hierarchy( |
3206 | 3217 | linkage_function: Callable[[csr_matrix], np.ndarray] = None, |
3207 | 3218 | distance_function: Callable[[csr_matrix], csr_matrix] = None, |
3208 | 3219 | color_threshold: int = 1, |
3209 | | - ) -> go.Figure: |
| 3220 | + ) -> "go.Figure": |
3210 | 3221 | """Visualize a hierarchical structure of the topics. |
3211 | 3222 |
|
3212 | 3223 | A ward linkage function is used to perform the |
@@ -3302,7 +3313,7 @@ def visualize_heatmap( |
3302 | 3313 | title: str = "<b>Similarity Matrix</b>", |
3303 | 3314 | width: int = 800, |
3304 | 3315 | height: int = 800, |
3305 | | - ) -> go.Figure: |
| 3316 | + ) -> "go.Figure": |
3306 | 3317 | """Visualize a heatmap of the topic's similarity matrix. |
3307 | 3318 |
|
3308 | 3319 | Based on the cosine similarity matrix between c-TF-IDFs or semantic embeddings of the topics, |
@@ -3362,7 +3373,7 @@ def visualize_barchart( |
3362 | 3373 | width: int = 250, |
3363 | 3374 | height: int = 250, |
3364 | 3375 | autoscale: bool = False, |
3365 | | - ) -> go.Figure: |
| 3376 | + ) -> "go.Figure": |
3366 | 3377 | """Visualize a barchart of selected topics. |
3367 | 3378 |
|
3368 | 3379 | Arguments: |
|
0 commit comments