Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 11 additions & 41 deletions ovos_solver_openai_persona/__init__.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,29 @@
from typing import Optional

import warnings
from ovos_solver_openai_persona.engines import OpenAIChatCompletionsSolver


class OpenAIPersonaSolver(OpenAIChatCompletionsSolver):
    """Deprecated "Persona" engine kept only for backwards compatibility.

    NOTE(review): the scraped diff showed both the pre-PR implementation
    (persona-aware prompt building) and this shim; the merge deleted the
    former, so only the post-merge shim is kept here. The class now does
    nothing beyond warning and deferring to OpenAIChatCompletionsSolver.
    """

    def __init__(self, *args, **kwargs):
        """Construct the parent solver while warning that this alias is deprecated.

        Args:
            *args: forwarded unchanged to OpenAIChatCompletionsSolver.
            **kwargs: forwarded unchanged to OpenAIChatCompletionsSolver.
        """
        # stacklevel=2 attributes the warning to the caller's line,
        # not to this __init__.
        warnings.warn(
            "use OpenAIChatCompletionsSolver instead",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)

# for ovos-persona: demo config pointing the persona framework at a
# remote OpenAI-compatible llama endpoint.
LLAMA_DEMO = {
    "name": "Remote LLama",
    "solvers": [
        "ovos-solver-openai-persona-plugin"
    ],
    "ovos-solver-openai-persona-plugin": {
        "api_url": "https://llama.smartgic.io/v1",
        "key": "sk-xxxx"  # placeholder; supply a real key at runtime
    }
}


if __name__ == "__main__":
    # Manual smoke test against the remote llama demo endpoint.
    # Uses the post-merge config key; the old "ovos-solver-openai-plugin"
    # key no longer exists in LLAMA_DEMO and would raise KeyError.
    bot = OpenAIChatCompletionsSolver(LLAMA_DEMO["ovos-solver-openai-persona-plugin"])
    # for utt in bot.stream_utterances("describe quantum mechanics in simple terms"):
    #     print(utt)
    # Quantum mechanics is a branch of physics that studies the behavior of atoms and particles at the smallest scales.
Expand Down
2 changes: 1 addition & 1 deletion ovos_solver_openai_persona/dialog_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, name="ovos-dialog-transformer-openai-plugin", priority=10, co
"key": self.config.get("key"),
'api_url': self.config.get('api_url', 'https://api.openai.com/v1'),
"enable_memory": False,
"initial_prompt": "your task is to rewrite text as if it was spoken by a different character"
"system_prompt": self.config.get("system_prompt") or "Your task is to rewrite text as if it was spoken by a different character"
})

def transform(self, dialog: str, context: dict = None) -> Tuple[str, dict]:
Expand Down
25 changes: 17 additions & 8 deletions ovos_solver_openai_persona/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, config=None,
enable_tx=enable_tx, enable_cache=enable_cache,
internal_lang=internal_lang)
self.api_url = f"{self.config.get('api_url', 'https://api.openai.com/v1')}/completions"
self.engine = self.config.get("model", "text-davinci-002") # "ada" cheaper and faster, "davinci" better
self.engine = self.config.get("model", "gpt-4o-mini")
self.key = self.config.get("key")
if not self.key:
LOG.error("key not set in config")
Expand Down Expand Up @@ -107,7 +107,12 @@ def __init__(self, config=None,
self.memory = config.get("enable_memory", True)
self.max_utts = config.get("memory_size", 3)
self.qa_pairs = [] # tuple of q+a
self.initial_prompt = config.get("initial_prompt", "You are a helpful assistant.")
if "initial_prompt" in config:
LOG.warning("'initial_prompt' config option is deprecated, use 'system_prompt' instead")
self.system_prompt = config.get("system_prompt") or config.get("initial_prompt")
if not self.system_prompt:
self.system_prompt = "You are a helpful assistant."
LOG.error(f"system prompt not set in config! defaulting to '{self.system_prompt}'")

# OpenAI API integration
def _do_api_request(self, messages):
Expand Down Expand Up @@ -179,19 +184,19 @@ def _do_streaming_api_request(self, messages):
continue
yield chunk["choices"][0]["delta"]["content"]

def get_chat_history(self, initial_prompt=None):
def get_chat_history(self, system_prompt=None):
qa = self.qa_pairs[-1 * self.max_utts:]
initial_prompt = initial_prompt or self.initial_prompt or "You are a helpful assistant."
system_prompt = system_prompt or self.system_prompt or "You are a helpful assistant."
messages = [
{"role": "system", "content": initial_prompt},
{"role": "system", "content": system_prompt},
]
for q, a in qa:
messages.append({"role": "user", "content": q})
messages.append({"role": "assistant", "content": a})
return messages

def get_messages(self, utt, initial_prompt=None) -> MessageList:
messages = self.get_chat_history(initial_prompt)
def get_messages(self, utt, system_prompt=None) -> MessageList:
messages = self.get_chat_history(system_prompt)
messages.append({"role": "user", "content": utt})
return messages

Expand All @@ -209,6 +214,8 @@ def continue_chat(self, messages: MessageList,
Returns:
Optional[str]: The generated response or None if no response could be generated.
"""
if not messages or messages[0]["role"] != "system":
messages = [{"role": "system", "content": self.system_prompt }] + messages
response = self._do_api_request(messages)
answer = post_process_sentence(response)
if not answer or not answer.strip("?") or not answer.strip("_"):
Expand All @@ -218,7 +225,7 @@ def continue_chat(self, messages: MessageList,
self.qa_pairs.append((query, answer))
return answer

def stream_chat_utterances(self, messages: List[Dict[str, str]],
def stream_chat_utterances(self, messages: MessageList,
lang: Optional[str] = None,
units: Optional[str] = None) -> Iterable[str]:
"""
Expand All @@ -232,6 +239,8 @@ def stream_chat_utterances(self, messages: List[Dict[str, str]],
Returns:
Iterable[str]: An iterable of utterances.
"""
if not messages or messages[0]["role"] != "system":
messages = [{"role": "system", "content": self.system_prompt }] + messages
answer = ""
query = messages[-1]["content"]
if self.memory:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_version():


# Plugin entry points advertised to OVOS via setuptools.
PERSONA_ENTRY_POINT = 'Remote Llama=ovos_solver_openai_persona:LLAMA_DEMO'
# post-merge: this entry point targets the plain completions solver in
# engines.py, not the deprecated OpenAIPersonaSolver.
PLUGIN_ENTRY_POINT = 'ovos-solver-openai-plugin=ovos_solver_openai_persona.engines:OpenAICompletionsSolver'
DIALOG_PLUGIN_ENTRY_POINT = 'ovos-dialog-transformer-openai-plugin=ovos_solver_openai_persona.dialog_transformers:OpenAIDialogTransformer'
SUMMARIZER_ENTRY_POINT = 'ovos-summarizer-openai-plugin=ovos_solver_openai_persona.summarizer:OpenAISummarizer'

Expand Down