Skip to content

Commit 3326ab9

Browse files
authored
Support custom endpoints and models for LLMChat (#88)
1 parent 0d238c0 commit 3326ab9

File tree

3 files changed

+27
-10
lines changed

3 files changed

+27
-10
lines changed

dailalib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "3.17.0"
1+
__version__ = "3.18.0"
22

33
import os
44
# stop LiteLLM from querying at all to the remote server

dailalib/api/litellm/litellm_api.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
active_model = None
1515
active_prompt_style = None
16+
active_custom_endpoint = None
1617

1718
_l = logging.getLogger(__name__)
1819

@@ -64,9 +65,10 @@ def __init__(
6465
self.prompts_by_name = {p.name: p for p in prompts}
6566

6667
# update the globals (for threading hacks)
67-
global active_model, active_prompt_style
68-
active_model = self.model
68+
global active_model, active_prompt_style, active_custom_endpoint
69+
active_model = self.model if not self.custom_model else self.custom_model
6970
active_prompt_style = self.prompt_style
71+
active_custom_endpoint = self.custom_endpoint
7072

7173
def load_or_create_config(self, new_config=None) -> bool:
7274
if new_config:
@@ -195,7 +197,6 @@ def query_model(
195197
# get the answer
196198
try:
197199
answer = response.choices[0].message.content
198-
if self.custom_endpoint: print(answer)
199200
except (KeyError, IndexError) as e:
200201
answer = None
201202

@@ -261,10 +262,19 @@ def _set_model(self, model):
261262
global active_model
262263
active_model = model
263264

265+
def _set_custom_endpoint(self, custom_endpoint):
266+
self.custom_endpoint = custom_endpoint
267+
global active_custom_endpoint
268+
active_custom_endpoint = custom_endpoint
269+
264270
def get_model(self):
    """Return the identifier of the model currently in use.

    When a custom endpoint is active its value takes precedence over the
    configured model name; both are read from module-level globals.
    """
    # TODO: this hack needs to be refactored later
    global active_model, active_custom_endpoint
    if active_custom_endpoint:
        return str(active_custom_endpoint)
    return str(active_model)
274+
275+
def get_custom_endpoint(self):
    """Return the module-level ``active_custom_endpoint`` (``None`` when no
    custom endpoint has been configured)."""
    global active_custom_endpoint
    endpoint = active_custom_endpoint
    return endpoint
268278

269279
#
270280
# LLM Settings

dailalib/llm_chat/llm_chat_ui.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class LLMChatClient(QWidget):
2828
def __init__(self, ai_api: "LiteLLMAIAPI", parent=None, context: Context = None):
2929
super(LLMChatClient, self).__init__(parent)
3030
self.model = ai_api.get_model()
31+
self.custom_endpoint = ai_api.get_custom_endpoint()
3132
self.ai_api = ai_api
3233
self.context = context
3334
self.setWindowTitle('LLM Chat')
@@ -65,6 +66,7 @@ def __init__(self, ai_api: "LiteLLMAIAPI", parent=None, context: Context = None)
6566

6667
# Chat history
6768
self.chat_history = []
69+
self.thread = None
6870

6971
# model check
7072
if not self.model:
@@ -152,7 +154,9 @@ def send_message(self, add_text=True, role="user"):
152154
self.send_button.setDisabled(True)
153155

154156
# Start a thread to get the response
155-
self.thread = LLMThread(self.chat_history, self.model)
157+
self.thread = LLMThread(
158+
self.chat_history, self.model, custom_endpoint=self.custom_endpoint, api_key=self.ai_api.api_key
159+
)
156160
self.thread.response_received.connect(lambda msg: self.receive_message(msg))
157161
self.thread.start()
158162

@@ -183,10 +187,12 @@ def closeEvent(self, event):
183187
class LLMThread(QThread):
184188
response_received = pyqtSignal(str)
185189

186-
def __init__(self, chat_history, model_name):
190+
def __init__(self, chat_history, model_name, custom_endpoint=None, api_key=None):
    """Worker thread that queries the LLM off the UI thread.

    :param chat_history:    prior chat messages; copied so the UI-owned list
                            can keep mutating while the thread runs
    :param model_name:      model identifier passed to litellm
    :param custom_endpoint: optional api_base for a self-hosted endpoint
    :param api_key:         key forwarded to litellm; a placeholder is used
                            instead when a custom endpoint is set
    """
    super().__init__()
    self.chat_history = list(chat_history)
    self.model_name = model_name
    self.custom_endpoint = custom_endpoint
    self.api_key = api_key
190196

191197
def run(self):
192198
import litellm
@@ -196,8 +202,9 @@ def run(self):
196202
response = litellm.completion(
197203
model=self.model_name,
198204
messages=self.chat_history,
199-
timeout=60,
200-
205+
timeout=60 if not self.custom_endpoint else 300,
206+
api_base=self.custom_endpoint if self.custom_endpoint else None, # Use custom endpoint if set
207+
api_key=self.api_key if not self.custom_endpoint else "dummy"  # In most cases a custom endpoint doesn't need the api_key
201208
)
202209
litellm.modify_params = False
203210

0 commit comments

Comments
 (0)