Skip to content

Commit 5cad563

Browse files
authored
Added system prompts (#2145)
1 parent 641b3d0 commit 5cad563

File tree

3 files changed

+79
-29
lines changed

3 files changed

+79
-29
lines changed

bertopic/representation/_cohere.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
Keywords: [KEYWORDS]
3636
Topic name:"""
3737

38+
DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts."
39+
3840

3941
class Cohere(BaseRepresentation):
4042
"""Use the Cohere API to generate topic labels based on their
@@ -51,6 +53,8 @@ class Cohere(BaseRepresentation):
5153
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
5254
to decide where the keywords and documents need to be
5355
inserted.
56+
system_prompt: The system prompt to be used in the model. If no system prompt is given,
57+
`self.default_system_prompt_` is used instead.
5458
delay_in_seconds: The delay in seconds between consecutive prompts
5559
in order to prevent RateLimitErrors.
5660
nr_docs: The number of documents to pass to Cohere if a prompt
@@ -107,8 +111,9 @@ class Cohere(BaseRepresentation):
107111
def __init__(
108112
self,
109113
client,
110-
model: str = "xlarge",
114+
model: str = "command-r",
111115
prompt: str = None,
116+
system_prompt: str = None,
112117
delay_in_seconds: float = None,
113118
nr_docs: int = 4,
114119
diversity: float = None,
@@ -118,7 +123,9 @@ def __init__(
118123
self.client = client
119124
self.model = model
120125
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
126+
self.system_prompt = system_prompt if system_prompt is not None else DEFAULT_SYSTEM_PROMPT
121127
self.default_prompt_ = DEFAULT_PROMPT
128+
self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT
122129
self.delay_in_seconds = delay_in_seconds
123130
self.nr_docs = nr_docs
124131
self.diversity = diversity
@@ -162,14 +169,14 @@ def extract_topics(
162169
if self.delay_in_seconds:
163170
time.sleep(self.delay_in_seconds)
164171

165-
request = self.client.generate(
172+
request = self.client.chat(
166173
model=self.model,
167-
prompt=prompt,
174+
preamble=self.system_prompt,
175+
message=prompt,
168176
max_tokens=50,
169-
num_generations=1,
170177
stop_sequences=["\n"],
171178
)
172-
label = request.generations[0].text.strip()
179+
label = request.text.strip()
173180
updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]
174181

175182
return updated_topics

bertopic/representation/_llamacpp.py

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,34 @@
88

99

1010
DEFAULT_PROMPT = """
11-
Q: I have a topic that contains the following documents:
11+
This is a list of texts where each collection of texts describes a topic. After each collection of texts, the name of the topic they represent is mentioned as a short, highly descriptive title
12+
---
13+
Topic:
14+
Sample texts from this topic:
15+
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
16+
- Meat, but especially beef, is the worst food in terms of emissions.
17+
- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.
18+
19+
Keywords: meat beef eat eating emissions steak food health processed chicken
20+
Topic name: Environmental impacts of eating meat
21+
---
22+
Topic:
23+
Sample texts from this topic:
24+
- I have ordered the product weeks ago but it still has not arrived!
25+
- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
26+
- I got a message stating that I received the monitor but that is not true!
27+
- It took a month longer to deliver than was advised...
28+
29+
Keywords: deliver weeks product shipping long delivery received arrived arrive week
30+
Topic name: Shipping and delivery issues
31+
---
32+
Topic:
33+
Sample texts from this topic:
1234
[DOCUMENTS]
35+
Keywords: [KEYWORDS]
36+
Topic name:"""
1337

14-
The topic is described by the following keywords: '[KEYWORDS]'.
15-
16-
Based on the above information, can you give a short label of the topic?
17-
A: """
38+
DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts."
1839

1940

2041
class LlamaCPP(BaseRepresentation):
@@ -28,6 +49,8 @@ class LlamaCPP(BaseRepresentation):
2849
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
2950
to decide where the keywords and documents need to be
3051
inserted.
52+
system_prompt: The system prompt to be used in the model. If no system prompt is given,
53+
`self.default_system_prompt_` is used instead.
3154
pipeline_kwargs: Kwargs that you can pass to the `llama_cpp.Llama`
3255
when it is called such as `max_tokens` to be generated.
3356
nr_docs: The number of documents to pass to the LLM if a prompt
@@ -93,14 +116,15 @@ def __init__(
93116
self,
94117
model: Union[str, Llama],
95118
prompt: str = None,
119+
system_prompt: str = None,
96120
pipeline_kwargs: Mapping[str, Any] = {},
97121
nr_docs: int = 4,
98122
diversity: float = None,
99123
doc_length: int = None,
100124
tokenizer: Union[str, Callable] = None,
101125
):
102126
if isinstance(model, str):
103-
self.model = Llama(model_path=model, n_gpu_layers=-1, stop="Q:")
127+
self.model = Llama(model_path=model, n_gpu_layers=-1, stop="\n", chat_format="ChatML")
104128
elif isinstance(model, Llama):
105129
self.model = model
106130
else:
@@ -110,7 +134,9 @@ def __init__(
110134
"local LLM or a ` llama_cpp.Llama` object."
111135
)
112136
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
137+
self.system_prompt = system_prompt if system_prompt is not None else DEFAULT_SYSTEM_PROMPT
113138
self.default_prompt_ = DEFAULT_PROMPT
139+
self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT
114140
self.pipeline_kwargs = pipeline_kwargs
115141
self.nr_docs = nr_docs
116142
self.diversity = diversity
@@ -151,33 +177,39 @@ def extract_topics(
151177
self.prompts_.append(prompt)
152178

153179
# Extract result from generator and use that as label
154-
topic_description = self.model(prompt, **self.pipeline_kwargs)["choices"]
155-
topic_description = [(description["text"].replace(prompt, ""), 1) for description in topic_description]
156-
157-
if len(topic_description) < 10:
158-
topic_description += [("", 0) for _ in range(10 - len(topic_description))]
159-
160-
updated_topics[topic] = topic_description
180+
# topic_description = self.model(prompt, **self.pipeline_kwargs)["choices"]
181+
topic_description = self.model.create_chat_completion(
182+
messages=[{"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}],
183+
**self.pipeline_kwargs,
184+
)
185+
label = topic_description["choices"][0]["message"]["content"].strip()
186+
updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]
161187

162188
return updated_topics
163189

164190
def _create_prompt(self, docs, topic, topics):
165-
keywords = ", ".join(list(zip(*topics[topic]))[0])
191+
keywords = list(zip(*topics[topic]))[0]
166192

167-
# Use the default prompt and replace keywords
193+
# Use the Default Chat Prompt
168194
if self.prompt == DEFAULT_PROMPT:
169-
prompt = self.prompt.replace("[KEYWORDS]", keywords)
195+
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
196+
prompt = self._replace_documents(prompt, docs)
170197

171-
# Use a prompt that leverages either keywords or documents in
172-
# a custom location
198+
# Use a custom prompt that leverages keywords, documents or both using
199+
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
173200
else:
174201
prompt = self.prompt
175202
if "[KEYWORDS]" in prompt:
176-
prompt = prompt.replace("[KEYWORDS]", keywords)
203+
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
177204
if "[DOCUMENTS]" in prompt:
178-
to_replace = ""
179-
for doc in docs:
180-
to_replace += f"- {doc}\n"
181-
prompt = prompt.replace("[DOCUMENTS]", to_replace)
205+
prompt = self._replace_documents(prompt, docs)
206+
207+
return prompt
182208

209+
@staticmethod
210+
def _replace_documents(prompt, docs):
211+
to_replace = ""
212+
for doc in docs:
213+
to_replace += f"- {doc}\n"
214+
prompt = prompt.replace("[DOCUMENTS]", to_replace)
183215
return prompt

bertopic/representation/_openai.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
topic: <topic label>
5050
"""
5151

52+
DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts."
53+
5254

5355
class OpenAI(BaseRepresentation):
5456
r"""Using the OpenAI API to generate topic labels based
@@ -74,6 +76,8 @@ class OpenAI(BaseRepresentation):
7476
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
7577
to decide where the keywords and documents need to be
7678
inserted.
79+
system_prompt: The system prompt to be used in the model. If no system prompt is given,
80+
`self.default_system_prompt_` is used instead.
7781
delay_in_seconds: The delay in seconds between consecutive prompts
7882
in order to prevent RateLimitErrors.
7983
exponential_backoff: Retry requests with a random exponential backoff.
@@ -145,6 +149,7 @@ def __init__(
145149
client,
146150
model: str = "text-embedding-3-small",
147151
prompt: str = None,
152+
system_prompt: str = None,
148153
generator_kwargs: Mapping[str, Any] = {},
149154
delay_in_seconds: float = None,
150155
exponential_backoff: bool = False,
@@ -162,7 +167,13 @@ def __init__(
162167
else:
163168
self.prompt = prompt
164169

170+
if chat and system_prompt is None:
171+
self.system_prompt = DEFAULT_SYSTEM_PROMPT
172+
else:
173+
self.system_prompt = system_prompt
174+
165175
self.default_prompt_ = DEFAULT_CHAT_PROMPT if chat else DEFAULT_PROMPT
176+
self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT
166177
self.delay_in_seconds = delay_in_seconds
167178
self.exponential_backoff = exponential_backoff
168179
self.chat = chat
@@ -219,7 +230,7 @@ def extract_topics(
219230

220231
if self.chat:
221232
messages = [
222-
{"role": "system", "content": "You are a helpful assistant."},
233+
{"role": "system", "content": self.system_prompt},
223234
{"role": "user", "content": prompt},
224235
]
225236
kwargs = {

0 commit comments

Comments (0)