Skip to content

Commit 7d2aa5b

Browse files
authored
Interactive DataMapPlot (#1853) and deprecate non-chat OpenAI models (#2287)
1 parent f3900ad commit 7d2aa5b

File tree

12 files changed

+648
-119
lines changed

12 files changed

+648
-119
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ venv/
7575
ENV/
7676
env.bak/
7777
venv.bak/
78+
*.lock
7879

7980
# Artifacts
8081
.idea

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ from bertopic.representation import OpenAI
161161

162162
# Fine-tune topic representations with GPT
163163
client = openai.OpenAI(api_key="sk-...")
164-
representation_model = OpenAI(client, model="gpt-3.5-turbo", chat=True)
164+
representation_model = OpenAI(client, model="gpt-4o-mini", chat=True)
165165
topic_model = BERTopic(representation_model=representation_model)
166166
```
167167

bertopic/_bertopic.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2532,38 +2532,51 @@ def visualize_documents(
25322532

25332533
def visualize_document_datamap(
25342534
self,
2535-
docs: List[str],
2535+
docs: List[str] = None,
25362536
topics: List[int] = None,
25372537
embeddings: np.ndarray = None,
25382538
reduced_embeddings: np.ndarray = None,
25392539
custom_labels: Union[bool, str] = False,
25402540
title: str = "Documents and Topics",
25412541
sub_title: Union[str, None] = None,
25422542
width: int = 1200,
2543-
height: int = 1200,
2544-
**datamap_kwds,
2543+
height: int = 750,
2544+
interactive: bool = False,
2545+
enable_search: bool = False,
2546+
topic_prefix: bool = False,
2547+
datamap_kwds: dict = {},
2548+
int_datamap_kwds: dict = {},
25452549
):
25462550
"""Visualize documents and their topics in 2D as a static plot for publication using
25472551
DataMapPlot. This works best if there are between 5 and 60 topics. It is therefore best
25482552
to use a sufficiently large `min_topic_size` or set `nr_topics` when building the model.
25492553
25502554
Arguments:
25512555
topic_model: A fitted BERTopic instance.
2552-
docs: The documents you used when calling either `fit` or `fit_transform`
2556+
docs: The documents you used when calling either `fit` or `fit_transform`.
25532557
topics: A selection of topics to visualize.
2554-
Not to be confused with the topics that you get from .fit_transform. For example, if you want to visualize only topics 1 through 5: topics = [1, 2, 3, 4, 5]. Documents not in these topics will be shown as noise points.
2558+
Not to be confused with the topics that you get from `.fit_transform`.
2559+
For example, if you want to visualize only topics 1 through 5:
2560+
`topics = [1, 2, 3, 4, 5]`. Documents not in these topics will be shown
2561+
as noise points.
25552562
embeddings: The embeddings of all documents in `docs`.
25562563
reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
25572564
custom_labels: If bool, whether to use custom topic labels that were defined using
2558-
`topic_model.set_topic_labels`.
2559-
If `str`, it uses labels from other aspects, e.g., "Aspect1".
2565+
`topic_model.set_topic_labels`.
2566+
If `str`, it uses labels from other aspects, e.g., "Aspect1".
25602567
title: Title of the plot.
25612568
sub_title: Sub-title of the plot.
25622569
width: The width of the figure.
25632570
height: The height of the figure.
2564-
**datamap_kwds: All further keyword args will be passed on to DataMapPlot's
2565-
`create_plot` function. See the DataMapPlot documentation
2566-
for more details.
2571+
interactive: Whether to create an interactive plot using DataMapPlot's `create_interactive_plot`.
2572+
enable_search: Whether to enable search in the interactive plot. Only works if `interactive=True`.
2573+
topic_prefix: Whether to prefix the default topic names with the topic number, e.g. `Topic-1: `.
2574+
datamap_kwds: Keyword args to be passed on to DataMapPlot's `create_plot` function
2575+
if you are not using the interactive version.
2576+
See the DataMapPlot documentation for more details.
2577+
int_datamap_kwds: Keyword args to be passed on to DataMapPlot's `create_interactive_plot` function
2578+
if you are using the interactive version.
2579+
See the DataMapPlot documentation for more details.
25672580
25682581
Returns:
25692582
figure: A Matplotlib Figure object, or the object returned by DataMapPlot's `create_interactive_plot` if `interactive=True`.
@@ -2610,7 +2623,6 @@ def visualize_document_datamap(
26102623
```
26112624
"""
26122625
check_is_fitted(self)
2613-
check_documents_type(docs)
26142626
return plotting.visualize_document_datamap(
26152627
self,
26162628
docs,
@@ -2622,7 +2634,11 @@ def visualize_document_datamap(
26222634
sub_title,
26232635
width,
26242636
height,
2625-
**datamap_kwds,
2637+
interactive,
2638+
enable_search,
2639+
topic_prefix,
2640+
datamap_kwds,
2641+
int_datamap_kwds,
26262642
)
26272643

26282644
def visualize_hierarchical_documents(

bertopic/plotting/_datamap.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,27 @@ class Figure(object):
1717

1818
def visualize_document_datamap(
1919
topic_model,
20-
docs: List[str],
20+
docs: List[str] = None,
2121
topics: List[int] = None,
2222
embeddings: np.ndarray = None,
2323
reduced_embeddings: np.ndarray = None,
2424
custom_labels: Union[bool, str] = False,
2525
title: str = "Documents and Topics",
2626
sub_title: Union[str, None] = None,
2727
width: int = 1200,
28-
height: int = 1200,
29-
**datamap_kwds,
28+
height: int = 750,
29+
interactive: bool = False,
30+
enable_search: bool = False,
31+
topic_prefix: bool = False,
32+
datamap_kwds: dict = {},
33+
int_datamap_kwds: dict = {},
3034
) -> Figure:
3135
"""Visualize documents and their topics in 2D as a static plot for publication using
3236
DataMapPlot.
3337
3438
Arguments:
3539
topic_model: A fitted BERTopic instance.
36-
docs: The documents you used when calling either `fit` or `fit_transform`
40+
docs: The documents you used when calling either `fit` or `fit_transform`.
3741
topics: A selection of topics to visualize.
3842
Not to be confused with the topics that you get from `.fit_transform`.
3943
For example, if you want to visualize only topics 1 through 5:
@@ -48,9 +52,15 @@ def visualize_document_datamap(
4852
sub_title: Sub-title of the plot.
4953
width: The width of the figure.
5054
height: The height of the figure.
51-
**datamap_kwds: All further keyword args will be passed on to DataMapPlot's
52-
`create_plot` function. See the DataMapPlot documentation
53-
for more details.
55+
interactive: Whether to create an interactive plot using DataMapPlot's `create_interactive_plot`.
56+
enable_search: Whether to enable search in the interactive plot. Only works if `interactive=True`.
57+
topic_prefix: Whether to prefix the default topic names with the topic number, e.g. `Topic-1: `.
58+
datamap_kwds: Keyword args to be passed on to DataMapPlot's `create_plot` function
59+
if you are not using the interactive version.
60+
See the DataMapPlot documentation for more details.
61+
int_datamap_kwds: Keyword args to be passed on to DataMapPlot's `create_interactive_plot` function
62+
if you are using the interactive version.
63+
See the DataMapPlot documentation for more details.
5464
5565
Returns:
5666
figure: A Matplotlib Figure object, or the object returned by DataMapPlot's `create_interactive_plot` if `interactive=True`.
@@ -127,10 +137,13 @@ def visualize_document_datamap(
127137
elif topic_model.custom_labels_ is not None and custom_labels:
128138
names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
129139
else:
130-
names = [
131-
f"Topic-{topic}: " + " ".join([word for word, value in topic_model.get_topic(topic)][:3])
132-
for topic in unique_topics
133-
]
140+
if topic_prefix:
141+
names = [
142+
f"Topic-{topic}: " + " ".join([word for word, value in topic_model.get_topic(topic)][:3])
143+
for topic in unique_topics
144+
]
145+
else:
146+
names = [" ".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
134147

135148
topic_name_mapping = {topic_num: topic_name for topic_num, topic_name in zip(unique_topics, names)}
136149
topic_name_mapping[-1] = "Unlabelled"
@@ -145,14 +158,25 @@ def visualize_document_datamap(
145158
# Map in topic names and plot
146159
named_topic_per_doc = pd.Series(topic_per_doc).map(topic_name_mapping).values
147160

148-
figure, axes = datamapplot.create_plot(
149-
embeddings_2d,
150-
named_topic_per_doc,
151-
figsize=(width / 100, height / 100),
152-
dpi=100,
153-
title=title,
154-
sub_title=sub_title,
155-
**datamap_kwds,
156-
)
161+
if interactive:
162+
figure = datamapplot.create_interactive_plot(
163+
embeddings_2d,
164+
named_topic_per_doc,
165+
hover_text=docs,
166+
enable_search=enable_search,
167+
width=width,
168+
height=height,
169+
**int_datamap_kwds,
170+
)
171+
else:
172+
figure, _ = datamapplot.create_plot(
173+
embeddings_2d,
174+
named_topic_per_doc,
175+
figsize=(width / 100, height / 100),
176+
dpi=100,
177+
title=title,
178+
sub_title=sub_title,
179+
**datamap_kwds,
180+
)
157181

158182
return figure

0 commit comments

Comments
 (0)