Skip to content

Commit 234553f

Browse files
authored
update base class to use the new ChatMessageSolver api (#15)
* semver * update parent class * fix
1 parent 99b7198 commit 234553f

File tree

6 files changed

+73
-40
lines changed

6 files changed

+73
-40
lines changed

.github/workflows/build_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
- name: Setup Python
1414
uses: actions/setup-python@v1
1515
with:
16-
python-version: 3.8
16+
python-version: "3.11"
1717
- name: Install Build Tools
1818
run: |
1919
python -m pip install build wheel

.github/workflows/license_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
- name: Setup Python
1717
uses: actions/setup-python@v1
1818
with:
19-
python-version: 3.8
19+
python-version: "3.11"
2020
- name: Install Build Tools
2121
run: |
2222
python -m pip install build wheel

.github/workflows/publish_stable.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
- name: Setup Python
2727
uses: actions/setup-python@v1
2828
with:
29-
python-version: 3.8
29+
python-version: "3.11"
3030
- name: Install Build Tools
3131
run: |
3232
python -m pip install build wheel

.github/workflows/release_workflow.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
- name: Setup Python
4747
uses: actions/setup-python@v1
4848
with:
49-
python-version: 3.8
49+
python-version: "3.11"
5050
- name: Install Build Tools
5151
run: |
5252
python -m pip install build wheel

ovos_solver_openai_persona/__init__.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,11 @@ def __init__(self, config=None):
1212
self.default_persona = config.get("persona") or "helpful, creative, clever, and very friendly."
1313

1414
def get_chat_history(self, persona=None):
15-
qa = self.qa_pairs[-1 * self.max_utts:]
1615
persona = persona or self.default_persona
1716
initial_prompt = f"You are a helpful assistant. " \
1817
f"You give short and factual answers. " \
1918
f"You are {persona}"
20-
messages = [
21-
{"role": "system", "content": initial_prompt},
22-
]
23-
for q, a in qa:
24-
messages.append({"role": "user", "content": q})
25-
messages.append({"role": "assistant", "content": a})
26-
return messages
19+
return super().get_chat_history(initial_prompt)
2720

2821
# officially exported Solver methods
2922
def get_spoken_answer(self, query: str,
@@ -40,13 +33,9 @@ def get_spoken_answer(self, query: str,
4033
Returns:
4134
str: The spoken answer as a text response.
4235
"""
43-
messages = self.get_prompt(query, self.default_persona)
44-
response = self._do_api_request(messages)
45-
answer = response.strip()
36+
answer = super().get_spoken_answer(query, lang, units)
4637
if not answer or not answer.strip("?") or not answer.strip("_"):
4738
return None
48-
if self.memory:
49-
self.qa_pairs.append((query, answer))
5039
return answer
5140

5241

@@ -64,14 +53,14 @@ def get_spoken_answer(self, query: str,
6453

6554

6655
if __name__ == "__main__":
67-
bot = OpenAIPersonaSolver({"key": "sk-xxxx", "api_url": "https://llama.smartgic.io/v1"})
68-
for utt in bot.stream_utterances("describe quantum mechanics in simple terms"):
69-
print(utt)
56+
bot = OpenAIPersonaSolver(LLAMA_DEMO["ovos-solver-openai-persona-plugin"])
57+
#for utt in bot.stream_utterances("describe quantum mechanics in simple terms"):
58+
# print(utt)
7059
# Quantum mechanics is a branch of physics that studies the behavior of atoms and particles at the smallest scales.
7160
# It describes how these particles interact with each other, move, and change energy levels.
7261
# Think of it like playing with toy building blocks that represent particles.
7362
# Instead of rigid structures, these particles can be in different energy levels or "states." Quantum mechanics helps scientists understand and predict these states, making it crucial for many fields like chemistry, materials science, and engineering.
7463

7564
# Quantum mechanics is a branch of physics that deals with the behavior of particles on a very small scale, such as atoms and subatomic particles. It explores the idea that particles can exist in multiple states at once and that their behavior is not predictable in the traditional sense.
76-
print(bot.spoken_answer("Quem encontrou o caminho maritimo para o Brazil", lang="pt-pt"))
65+
print(bot.spoken_answer("what is the definition of computer", lang="en-US"))
7766
# O português Pedro Álvares Cabral encontrou o caminho marítimo para o Brasil em 1500. Ele foi o responsável por descobrir o litoral brasileiro, embora Cristóvão Colombo tenha chegado à América do Sul em 1498, cinco anos antes. Cabral desembarcou na atual costa de Alagoas, no Nordeste do Brasil.

ovos_solver_openai_persona/engines.py

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import json
2-
from typing import Optional, Iterable
2+
from typing import Optional, Iterable, List, Dict
33

44
import requests
55
from ovos_plugin_manager.templates.solvers import QuestionSolver
66
from ovos_utils.log import LOG
77

8+
from ovos_plugin_manager.templates.solvers import ChatMessageSolver
9+
10+
MessageList = List[Dict[str, str]] # for typing
811

912
class OpenAICompletionsSolver(QuestionSolver):
1013
enable_tx = False
@@ -69,14 +72,14 @@ def get_spoken_answer(self, query: str,
6972
return answer
7073

7174

72-
class OpenAIChatCompletionsSolver(QuestionSolver):
75+
class OpenAIChatCompletionsSolver(ChatMessageSolver):
7376
enable_tx = False
7477
priority = 25
7578

7679
def __init__(self, config=None):
7780
super().__init__(config)
7881
self.api_url = f"{self.config.get('api_url', 'https://api.openai.com/v1')}/chat/completions"
79-
self.engine = self.config.get("model", "gpt-3.5-turbo") # "ada" cheaper and faster, "davinci" better
82+
self.engine = self.config.get("model", "gpt-4o-mini") # "ada" cheaper and faster, "davinci" better
8083
self.stop_token = "<|im_end|>"
8184
self.key = self.config.get("key")
8285
if not self.key:
@@ -154,7 +157,7 @@ def _do_streaming_api_request(self, messages):
154157

155158
def get_chat_history(self, initial_prompt=None):
156159
qa = self.qa_pairs[-1 * self.max_utts:]
157-
initial_prompt = self.initial_prompt or "You are a helpful assistant."
160+
initial_prompt = initial_prompt or self.initial_prompt or "You are a helpful assistant."
158161
messages = [
159162
{"role": "system", "content": initial_prompt},
160163
]
@@ -163,37 +166,82 @@ def get_chat_history(self, initial_prompt=None):
163166
messages.append({"role": "assistant", "content": a})
164167
return messages
165168

166-
def get_prompt(self, utt, initial_prompt=None):
169+
def get_messages(self, utt, initial_prompt=None) -> MessageList:
167170
messages = self.get_chat_history(initial_prompt)
168171
messages.append({"role": "user", "content": utt})
169172
return messages
170173

171174
# abstract Solver methods
172-
def stream_utterances(self, query: str,
173-
lang: Optional[str] = None,
174-
units: Optional[str] = None) -> Iterable[str]:
175+
def continue_chat(self, messages: MessageList,
176+
lang: Optional[str],
177+
units: Optional[str] = None) -> Optional[str]:
178+
"""Generate a response based on the chat history.
179+
180+
Args:
181+
messages (List[Dict[str, str]]): List of chat messages, each containing 'role' and 'content'.
182+
lang (Optional[str]): The language code for the response. If None, will be auto-detected.
183+
units (Optional[str]): Optional unit system for numerical values.
184+
185+
Returns:
186+
Optional[str]: The generated response or None if no response could be generated.
175187
"""
176-
Stream utterances for the given query as they become available.
188+
response = self._do_api_request(messages)
189+
answer = response.strip()
190+
if not answer or not answer.strip("?") or not answer.strip("_"):
191+
return None
192+
if self.memory:
193+
query = messages[-1]["content"]
194+
self.qa_pairs.append((query, answer))
195+
return answer
196+
197+
def stream_chat_utterances(self, messages: List[Dict[str, str]],
198+
lang: Optional[str] = None,
199+
units: Optional[str] = None) -> Iterable[str]:
200+
"""
201+
Stream utterances for the given chat history as they become available.
177202
178203
Args:
179-
query (str): The query text.
204+
messages: The chat messages.
180205
lang (Optional[str]): Optional language code. Defaults to None.
181206
units (Optional[str]): Optional units for the query. Defaults to None.
182207
183208
Returns:
184209
Iterable[str]: An iterable of utterances.
185210
"""
186-
messages = self.get_prompt(query)
187211
answer = ""
212+
query = messages[-1]["content"]
213+
if self.memory:
214+
self.qa_pairs.append((query, answer))
215+
188216
for chunk in self._do_streaming_api_request(messages):
189217
answer += chunk
190218
if any(chunk.endswith(p) for p in [".", "!", "?", "\n", ":"]):
191219
if len(chunk) >= 2 and chunk[-2].isdigit() and chunk[-1] == ".":
192220
continue  # don't split numbers
193221
if answer.strip():
222+
if self.memory:
223+
full_ans = f"{self.qa_pairs[-1][-1]}\n{answer}".strip()
224+
self.qa_pairs[-1] = (query, full_ans)
194225
yield answer
195226
answer = ""
196227

228+
def stream_utterances(self, query: str,
229+
lang: Optional[str] = None,
230+
units: Optional[str] = None) -> Iterable[str]:
231+
"""
232+
Stream utterances for the given query as they become available.
233+
234+
Args:
235+
query (str): The query text.
236+
lang (Optional[str]): Optional language code. Defaults to None.
237+
units (Optional[str]): Optional units for the query. Defaults to None.
238+
239+
Returns:
240+
Iterable[str]: An iterable of utterances.
241+
"""
242+
messages = self.get_messages(query)
243+
yield from self.stream_chat_utterances(messages, lang, units)
244+
197245
def get_spoken_answer(self, query: str,
198246
lang: Optional[str] = None,
199247
units: Optional[str] = None) -> Optional[str]:
@@ -208,11 +256,7 @@ def get_spoken_answer(self, query: str,
208256
Returns:
209257
str: The spoken answer as a text response.
210258
"""
211-
messages = self.get_prompt(query)
212-
response = self._do_api_request(messages)
213-
answer = response.strip()
214-
if not answer or not answer.strip("?") or not answer.strip("_"):
215-
return None
216-
if self.memory:
217-
self.qa_pairs.append((query, answer))
218-
return answer
259+
messages = self.get_messages(query)
260+
# just for api compat since it's a subclass, shouldn't be directly used
261+
return self.continue_chat(messages=messages, lang=lang, units=units)
262+

0 commit comments

Comments (0)