Skip to content

Commit 4dae76a

Browse files
committed
reuse openai client for chat completion
1 parent bdb869a commit 4dae76a

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

minds/minds.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from typing import List, Union, Iterable
2-
from urllib.parse import urlparse, urlunparse
3-
from datetime import datetime
4-
2+
import utils
53
from openai import OpenAI
64

75
import minds.exceptions as exc
@@ -10,7 +8,6 @@
108

119
DEFAULT_PROMPT_TEMPLATE = 'Use your database tools to answer the user\'s question: {{question}}'
1210

13-
1411
class Mind:
1512
def __init__(
1613
self, client, name,
@@ -35,7 +32,11 @@ def __init__(
3532
self.parameters = parameters
3633
self.created_at = created_at
3734
self.updated_at = updated_at
38-
35+
base_url = utils.create_base_url_openai(self.api.base_url)
36+
self.openai_client = OpenAI(
37+
api_key=self.api.api_key,
38+
base_url=base_url
39+
)
3940
self.datasources = datasources
4041

4142
def __repr__(self):
@@ -156,23 +157,7 @@ def completion(self, message: str, stream: bool = False) -> Union[str, Iterable[
156157
157158
:return: string if stream mode is off or iterator of ChoiceDelta objects (by openai)
158159
"""
159-
parsed = urlparse(self.api.base_url)
160-
161-
netloc = parsed.netloc
162-
if netloc == 'mdb.ai':
163-
llm_host = 'llm.mdb.ai'
164-
else:
165-
llm_host = 'ai.' + netloc
166-
167-
parsed = parsed._replace(path='', netloc=llm_host)
168-
169-
base_url = urlunparse(parsed)
170-
openai_client = OpenAI(
171-
api_key=self.api.api_key,
172-
base_url=base_url
173-
)
174-
175-
response = openai_client.chat.completions.create(
160+
response = self.openai_client.chat.completions.create(
176161
model=self.name,
177162
messages=[
178163
{'role': 'user', 'content': message}

minds/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from urllib.parse import urlparse, urlunparse
2+
3+
def create_base_url_openai(base_url: str) -> str:
4+
parsed = urlparse(base_url)
5+
6+
netloc = parsed.netloc
7+
if netloc == 'mdb.ai':
8+
llm_host = 'llm.mdb.ai'
9+
else:
10+
llm_host = 'ai.' + netloc
11+
12+
parsed = parsed._replace(path='', netloc=llm_host)
13+
14+
return urlunparse(parsed)
15+

0 commit comments

Comments
 (0)