Skip to content

Commit 7d186df

Browse files
committed
use openai api for openai types of requests, making it compatible with some other providers
1 parent 1d92bd7 commit 7d186df

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

py-src/data_formulator/agents/client_utils.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
2-
from litellm import completion
2+
import litellm
3+
import openai
34
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
45

56
class Client(object):
@@ -14,11 +15,17 @@ def __init__(self, endpoint, model, api_key=None, api_base=None, api_version=No
1415

1516
# other params, including temperature, max_completion_tokens, api_base, api_version
1617
self.params = {
17-
"api_key": api_key,
1818
"temperature": 0.7,
1919
"max_completion_tokens": 1200,
2020
}
2121

22+
if api_key is not None and api_key != "":
23+
self.params["api_key"] = api_key
24+
if api_base is not None and api_base != "":
25+
self.params["api_base"] = api_base
26+
if api_version is not None and api_version != "":
27+
self.params["api_version"] = api_version
28+
2229
if self.endpoint == "gemini":
2330
if model.startswith("gemini/"):
2431
self.model = model
@@ -52,9 +59,19 @@ def get_completion(self, messages):
5259
Supports OpenAI, Azure, Ollama, and other providers via LiteLLM.
5360
"""
5461
# Configure LiteLLM
55-
return completion(
56-
model=self.model,
57-
messages=messages,
58-
drop_params=True,
59-
**self.params
60-
)
62+
63+
if self.endpoint == "openai":
64+
client = openai.OpenAI(api_key=self.params["api_key"], base_url=self.params["api_base"] if "api_base" in self.params else None)
65+
66+
return client.chat.completions.create(
67+
model=self.model,
68+
messages=messages,
69+
**{k: v for k, v in self.params.items() if k != "api_key"}
70+
)
71+
else:
72+
return litellm.completion(
73+
model=self.model,
74+
messages=messages,
75+
drop_params=True,
76+
**self.params
77+
)

src/views/ModelSelectionDialog.tsx

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
181181
PaperComponent={({ children }) => (
182182
<Paper>
183183
<Typography sx={{ p: 1, color: 'gray', fontStyle: 'italic', fontSize: '0.75rem' }}>
184-
suggestions
184+
examples
185185
</Typography>
186186
{children}
187187
</Paper>
@@ -226,7 +226,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
226226
PaperComponent={({ children }) => (
227227
<Paper>
228228
<Typography sx={{ p: 1, color: 'gray', fontStyle: 'italic', fontSize: 'small' }}>
229-
suggestions
229+
examples
230230
</Typography>
231231
{children}
232232
</Paper>
@@ -267,23 +267,17 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
267267

268268
let model = {endpoint, model: newModel, api_key: newApiKey, api_base: newApiBase, api_version: newApiVersion, id: id};
269269

270-
console.log("checkpont 2")
271-
272270
dispatch(dfActions.addModel(model));
273271
dispatch(dfActions.selectModel(id));
274272
setTempSelectedModeId(id);
275273

276-
console.log("checkpont 3")
277-
278274
testModel(model);
279275

280276
setNewEndpoint("");
281277
setNewModel("");
282278
setNewApiKey(undefined);
283279
setNewApiBase(undefined);
284280
setNewApiVersion(undefined);
285-
286-
console.log("checkpont 4")
287281
}}>
288282
<AddCircleIcon />
289283
</IconButton>

0 commit comments

Comments
 (0)