Skip to content

Commit d665d3f

Browse files
authored
v0.14.1 - ChatGPT support and improved Prompting (#1057)
1 parent 5e63dac commit d665d3f

File tree

7 files changed

+308
-67
lines changed

7 files changed

+308
-67
lines changed

bertopic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from bertopic._bertopic import BERTopic
22

3-
__version__ = "0.14.0"
3+
__version__ = "0.14.1"
44

55
__all__ = [
66
"BERTopic",

bertopic/_bertopic.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,6 +2005,8 @@ def reduce_outliers(self,
20052005
def visualize_topics(self,
20062006
topics: List[int] = None,
20072007
top_n_topics: int = None,
2008+
custom_labels: bool = False,
2009+
title: str = "<b>Intertopic Distance Map</b>",
20082010
width: int = 650,
20092011
height: int = 650) -> go.Figure:
20102012
""" Visualize topics, their sizes, and their corresponding words
@@ -2015,6 +2017,9 @@ def visualize_topics(self,
20152017
Arguments:
20162018
topics: A selection of topics to visualize
20172019
top_n_topics: Only select the top n most frequent topics
2020+
custom_labels: Whether to use custom topic labels that were defined using
2021+
`topic_model.set_topic_labels`.
2022+
title: Title of the plot.
20182023
width: The width of the figure.
20192024
height: The height of the figure.
20202025
@@ -2037,6 +2042,8 @@ def visualize_topics(self,
20372042
return plotting.visualize_topics(self,
20382043
topics=topics,
20392044
top_n_topics=top_n_topics,
2045+
custom_labels=custom_labels,
2046+
title=title,
20402047
width=width,
20412048
height=height)
20422049

@@ -2049,6 +2056,7 @@ def visualize_documents(self,
20492056
hide_annotations: bool = False,
20502057
hide_document_hover: bool = False,
20512058
custom_labels: bool = False,
2059+
title: str = "<b>Documents and Topics</b>",
20522060
width: int = 1200,
20532061
height: int = 750) -> go.Figure:
20542062
""" Visualize documents and their topics in 2D
@@ -2071,6 +2079,7 @@ def visualize_documents(self,
20712079
specific points. Helps to speed up generation of visualization.
20722080
custom_labels: Whether to use custom topic labels that were defined using
20732081
`topic_model.set_topic_labels`.
2082+
title: Title of the plot.
20742083
width: The width of the figure.
20752084
height: The height of the figure.
20762085
@@ -2129,6 +2138,7 @@ def visualize_documents(self,
21292138
hide_annotations=hide_annotations,
21302139
hide_document_hover=hide_document_hover,
21312140
custom_labels=custom_labels,
2141+
title=title,
21322142
width=width,
21332143
height=height)
21342144

@@ -2143,6 +2153,7 @@ def visualize_hierarchical_documents(self,
21432153
hide_document_hover: bool = True,
21442154
nr_levels: int = 10,
21452155
custom_labels: bool = False,
2156+
title: str = "<b>Hierarchical Documents and Topics</b>",
21462157
width: int = 1200,
21472158
height: int = 750) -> go.Figure:
21482159
""" Visualize documents and their topics in 2D at different levels of hierarchy
@@ -2174,6 +2185,7 @@ def visualize_hierarchical_documents(self,
21742185
`topic_model.set_topic_labels`.
21752186
NOTE: Custom labels are only generated for the original
21762187
un-merged topics.
2188+
title: Title of the plot.
21772189
width: The width of the figure.
21782190
height: The height of the figure.
21792191
@@ -2235,13 +2247,15 @@ def visualize_hierarchical_documents(self,
22352247
hide_document_hover=hide_document_hover,
22362248
nr_levels=nr_levels,
22372249
custom_labels=custom_labels,
2250+
title=title,
22382251
width=width,
22392252
height=height)
22402253

22412254
def visualize_term_rank(self,
22422255
topics: List[int] = None,
22432256
log_scale: bool = False,
22442257
custom_labels: bool = False,
2258+
title: str = "<b>Term score decline per Topic</b>",
22452259
width: int = 800,
22462260
height: int = 500) -> go.Figure:
22472261
""" Visualize the ranks of all terms across all topics
@@ -2257,6 +2271,7 @@ def visualize_term_rank(self,
22572271
log_scale: Whether to represent the ranking on a log scale
22582272
custom_labels: Whether to use custom topic labels that were defined using
22592273
`topic_model.set_topic_labels`.
2274+
title: Title of the plot.
22602275
width: The width of the figure.
22612276
height: The height of the figure.
22622277
@@ -2292,6 +2307,7 @@ def visualize_term_rank(self,
22922307
topics=topics,
22932308
log_scale=log_scale,
22942309
custom_labels=custom_labels,
2310+
title=title,
22952311
width=width,
22962312
height=height)
22972313

@@ -2301,6 +2317,7 @@ def visualize_topics_over_time(self,
23012317
topics: List[int] = None,
23022318
normalize_frequency: bool = False,
23032319
custom_labels: bool = False,
2320+
title: str = "<b>Topics over Time</b>",
23042321
width: int = 1250,
23052322
height: int = 450) -> go.Figure:
23062323
""" Visualize topics over time
@@ -2313,6 +2330,7 @@ def visualize_topics_over_time(self,
23132330
normalize_frequency: Whether to normalize each topic's frequency individually
23142331
custom_labels: Whether to use custom topic labels that were defined using
23152332
`topic_model.set_topic_labels`.
2333+
title: Title of the plot.
23162334
width: The width of the figure.
23172335
height: The height of the figure.
23182336
@@ -2342,6 +2360,7 @@ def visualize_topics_over_time(self,
23422360
topics=topics,
23432361
normalize_frequency=normalize_frequency,
23442362
custom_labels=custom_labels,
2363+
title=title,
23452364
width=width,
23462365
height=height)
23472366

@@ -2351,6 +2370,7 @@ def visualize_topics_per_class(self,
23512370
topics: List[int] = None,
23522371
normalize_frequency: bool = False,
23532372
custom_labels: bool = False,
2373+
title: str = "<b>Topics per Class</b>",
23542374
width: int = 1250,
23552375
height: int = 900) -> go.Figure:
23562376
""" Visualize topics per class
@@ -2363,6 +2383,7 @@ def visualize_topics_per_class(self,
23632383
normalize_frequency: Whether to normalize each topic's frequency individually
23642384
custom_labels: Whether to use custom topic labels that were defined using
23652385
`topic_model.set_topic_labels`.
2386+
title: Title of the plot.
23662387
width: The width of the figure.
23672388
height: The height of the figure.
23682389
@@ -2392,13 +2413,15 @@ def visualize_topics_per_class(self,
23922413
topics=topics,
23932414
normalize_frequency=normalize_frequency,
23942415
custom_labels=custom_labels,
2416+
title=title,
23952417
width=width,
23962418
height=height)
23972419

23982420
def visualize_distribution(self,
23992421
probabilities: np.ndarray,
24002422
min_probability: float = 0.015,
24012423
custom_labels: bool = False,
2424+
title: str = "<b>Topic Probability Distribution</b>",
24022425
width: int = 800,
24032426
height: int = 600) -> go.Figure:
24042427
""" Visualize the distribution of topic probabilities
@@ -2409,6 +2432,7 @@ def visualize_distribution(self,
24092432
All others are ignored.
24102433
custom_labels: Whether to use custom topic labels that were defined using
24112434
`topic_model.set_topic_labels`.
2435+
title: Title of the plot.
24122436
width: The width of the figure.
24132437
height: The height of the figure.
24142438
@@ -2433,6 +2457,7 @@ def visualize_distribution(self,
24332457
probabilities=probabilities,
24342458
min_probability=min_probability,
24352459
custom_labels=custom_labels,
2460+
title=title,
24362461
width=width,
24372462
height=height)
24382463

@@ -2492,6 +2517,7 @@ def visualize_hierarchy(self,
24922517
topics: List[int] = None,
24932518
top_n_topics: int = None,
24942519
custom_labels: bool = False,
2520+
title: str = "<b>Hierarchical Clustering</b>",
24952521
width: int = 1000,
24962522
height: int = 600,
24972523
hierarchical_topics: pd.DataFrame = None,
@@ -2514,6 +2540,7 @@ def visualize_hierarchy(self,
25142540
`topic_model.set_topic_labels`.
25152541
NOTE: Custom labels are only generated for the original
25162542
un-merged topics.
2543+
title: Title of the plot.
25172544
width: The width of the figure. Only works if orientation is set to 'left'
25182545
height: The height of the figure. Only works if orientation is set to 'bottom'
25192546
hierarchical_topics: A dataframe that contains a hierarchy of topics
@@ -2570,6 +2597,7 @@ def visualize_hierarchy(self,
25702597
topics=topics,
25712598
top_n_topics=top_n_topics,
25722599
custom_labels=custom_labels,
2600+
title=title,
25732601
width=width,
25742602
height=height,
25752603
hierarchical_topics=hierarchical_topics,
@@ -2583,6 +2611,7 @@ def visualize_heatmap(self,
25832611
top_n_topics: int = None,
25842612
n_clusters: int = None,
25852613
custom_labels: bool = False,
2614+
title: str = "<b>Similarity Matrix</b>",
25862615
width: int = 800,
25872616
height: int = 800) -> go.Figure:
25882617
""" Visualize a heatmap of the topic's similarity matrix
@@ -2597,6 +2626,7 @@ def visualize_heatmap(self,
25972626
matrix by those clusters.
25982627
custom_labels: Whether to use custom topic labels that were defined using
25992628
`topic_model.set_topic_labels`.
2629+
title: Title of the plot.
26002630
width: The width of the figure.
26012631
height: The height of the figure.
26022632
@@ -2625,6 +2655,7 @@ def visualize_heatmap(self,
26252655
top_n_topics=top_n_topics,
26262656
n_clusters=n_clusters,
26272657
custom_labels=custom_labels,
2658+
title=title,
26282659
width=width,
26292660
height=height)
26302661

@@ -3333,9 +3364,9 @@ def _map_probabilities(self,
33333364

33343365
# Map array of probabilities (probability for assigned topic per document)
33353366
if probabilities is not None:
3336-
if len(probabilities.shape) == 2 and self.get_topic(-1):
3367+
if len(probabilities.shape) == 2:
33373368
mapped_probabilities = np.zeros((probabilities.shape[0],
3338-
len(set(mappings.values())) - 1))
3369+
len(set(mappings.values())) - self._outliers))
33393370
for from_topic, to_topic in mappings.items():
33403371
if to_topic != -1 and from_topic != -1:
33413372
mapped_probabilities[:, to_topic] += probabilities[:, from_topic]

bertopic/representation/_cohere.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import numpy as np
1+
import time
22
import pandas as pd
33
from scipy.sparse import csr_matrix
4-
from typing import Mapping, List, Tuple, Union
5-
from sklearn.metrics.pairwise import cosine_similarity
4+
from typing import Mapping, List, Tuple
65
from bertopic.representation._base import BaseRepresentation
76

87

@@ -28,7 +27,11 @@
2827
Keywords: deliver weeks product shipping long delivery received arrived arrive week
2928
Topic name: Shipping and delivery issues
3029
---
31-
"""
30+
Topic:
31+
Sample texts from this topic:
32+
[DOCUMENTS]
33+
Keywords: [KEYWORDS]
34+
Topic name:"""
3235

3336

3437
class Cohere(BaseRepresentation):
@@ -46,6 +49,8 @@ class Cohere(BaseRepresentation):
4649
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
4750
to decide where the keywords and documents need to be
4851
inserted.
52+
delay_in_seconds: The delay in seconds between consecutive prompts
53+
in order to prevent RateLimitErrors.
4954
5055
Usage:
5156
@@ -79,11 +84,13 @@ def __init__(self,
7984
client,
8085
model: str = "xlarge",
8186
prompt: str = None,
87+
delay_in_seconds: float = None,
8288
):
8389
self.client = client
8490
self.model = model
8591
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
8692
self.default_prompt_ = DEFAULT_PROMPT
93+
self.delay_in_seconds = delay_in_seconds
8794

8895
def extract_topics(self,
8996
topic_model,
@@ -109,6 +116,11 @@ def extract_topics(self,
109116
updated_topics = {}
110117
for topic, docs in repr_docs_mappings.items():
111118
prompt = self._create_prompt(docs, topic, topics)
119+
120+
# Delay
121+
if self.delay_in_seconds:
122+
time.sleep(self.delay_in_seconds)
123+
112124
request = self.client.generate(model=self.model,
113125
prompt=prompt,
114126
max_tokens=50,
@@ -118,26 +130,30 @@ def extract_topics(self,
118130
updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]
119131

120132
return updated_topics
121-
133+
122134
def _create_prompt(self, docs, topic, topics):
123135
keywords = list(zip(*topics[topic]))[0]
124136

125-
# Use a prompt that leverages either keywords or documents in
126-
# a custom location
127-
prompt = ""
128-
if "[KEYWORDS]" in self.prompt:
129-
prompt += self.prompt.replace("[KEYWORDS]", keywords)
130-
if "[DOCUMENTS]" in self.prompt:
131-
to_replace = ""
132-
for doc in docs:
133-
to_replace += f"- {doc[:255]}\n"
134-
prompt += self.prompt.replace("[DOCUMENTS]", to_replace)
135-
136-
# Use the default prompt
137-
if "[KEYWORDS]" and "[DOCUMENTS]" not in self.prompt:
138-
prompt = self.prompt + 'Topic:\nSample texts from this topic:\n'
139-
for doc in docs:
140-
prompt += f"- {doc[:255]}\n"
141-
prompt += "Keywords: " + " ".join(keywords)
142-
prompt += "\nTopic name:"
137+
# Use the Default Chat Prompt
138+
if self.prompt == self.prompt == DEFAULT_PROMPT:
139+
prompt = self.prompt.replace("[KEYWORDS]", " ".join(keywords))
140+
prompt = self._replace_documents(prompt, docs)
141+
142+
# Use a custom prompt that leverages keywords, documents or both using
143+
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
144+
else:
145+
prompt = self.prompt
146+
if "[KEYWORDS]" in prompt:
147+
prompt = prompt.replace("[KEYWORDS]", " ".join(keywords))
148+
if "[DOCUMENTS]" in prompt:
149+
prompt = self._replace_documents(prompt, docs)
150+
151+
return prompt
152+
153+
@staticmethod
154+
def _replace_documents(prompt, docs):
155+
to_replace = ""
156+
for doc in docs:
157+
to_replace += f"- {doc[:255]}\n"
158+
prompt = prompt.replace("[DOCUMENTS]", to_replace)
143159
return prompt

0 commit comments

Comments
 (0)