Skip to content

Commit e5323e6

Browse files
committed
update support for other models (tested with openai, azure, anthropic, ollama with codellama)
1 parent 5e856b0 commit e5323e6

File tree

6 files changed

+103
-80
lines changed

6 files changed

+103
-80
lines changed

.gitignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11

2-
3-
*openai-keys.env
2+
*api-keys.env
43
**/*.ipynb_checkpoints/
54

65
.DS_Store

api-keys.env.template

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# OpenAI Configuration
2+
OPENAI_ENABLED=true
3+
OPENAI_API_KEY=#your-openai-api-key
4+
OPENAI_MODELS=gpt-4o,gpt-4o-mini
5+
6+
# Azure OpenAI Configuration
7+
AZURE_ENABLED=true
8+
AZURE_API_BASE=https://your-azure-openai-endpoint.openai.azure.com/
9+
AZURE_API_VERSION=2024-02-15-preview
10+
AZURE_MODELS=gpt-4o
11+
12+
# Anthropic Configuration
13+
ANTHROPIC_ENABLED=true
14+
ANTHROPIC_API_KEY=#your-anthropic-api-key
15+
ANTHROPIC_MODELS=claude-3-5-sonnet-20241022,claude-3-5-haiku-20241022
16+
17+
# Ollama Configuration
18+
OLLAMA_ENABLED=true
19+
OLLAMA_API_BASE=http://localhost:11434
20+
OLLAMA_MODELS=codellama

py-src/data_formulator/agents/client_utils.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,8 @@ class Client(object):
1010
"""
1111
def __init__(self, endpoint, model, api_key=None, api_base=None, api_version=None):
1212

13-
if endpoint == "default":
14-
self.endpoint = os.getenv("ENDPOINT", "azure_openai")
15-
self.model = model
16-
api_base = os.getenv("API_BASE")
17-
else:
18-
self.endpoint = endpoint
19-
self.model = model
13+
self.endpoint = endpoint
14+
self.model = model
2015

2116
# other params, including temperature, max_completion_tokens, api_base, api_version
2217
self.params = {
@@ -35,7 +30,7 @@ def __init__(self, endpoint, model, api_key=None, api_base=None, api_version=No
3530
self.model = model
3631
else:
3732
self.model = f"anthropic/{model}"
38-
elif self.endpoint == "azure_openai":
33+
elif self.endpoint == "azure":
3934
self.params["api_base"] = api_base
4035
self.params["api_version"] = api_version if api_version else "2024-02-15-preview"
4136
if api_key is None or api_key == "":

py-src/data_formulator/app.py

Lines changed: 51 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,9 @@
4343

4444
print(APP_ROOT)
4545

46-
# try to look for stored openAI keys information from the ROOT dir,
47-
# this file might be in one of the two locations
48-
load_dotenv(os.path.join(APP_ROOT, "..", "..", 'openai-keys.env'))
49-
load_dotenv(os.path.join(APP_ROOT, 'openai-keys.env'))
46+
# Load the single environment file
47+
load_dotenv(os.path.join(APP_ROOT, "..", "..", 'api-keys.env'))
48+
load_dotenv(os.path.join(APP_ROOT, 'api-keys.env'))
5049

5150
import os
5251

@@ -131,46 +130,56 @@ def get_datasets(path):
131130

132131
@app.route('/check-available-models', methods=['GET', 'POST'])
133132
def check_available_models():
134-
135133
results = []
134+
135+
# Define configurations for different providers
136+
providers = ['openai', 'azure', 'anthropic', 'gemini', 'ollama']
136137

137-
# dont need to check if it's empty
138-
if os.getenv("ENDPOINT") is None:
139-
return json.dumps(results)
140-
141-
endpoint = os.getenv("ENDPOINT")
142-
models = [model.strip() for model in os.getenv("MODELS").split(',')]
143-
api_base = os.getenv("API_BASE")
144-
145-
print("endpoint", endpoint)
146-
print("models", models)
147-
print("api_base", api_base)
148-
149-
for model in models:
150-
try:
151-
client = Client(endpoint, model, api_key=None, api_base=api_base, api_version=None)
152-
response = client.get_completion(
153-
messages=[
154-
{"role": "system", "content": "You are a helpful assistant."},
155-
{"role": "user", "content": "Respond 'I can hear you.' if you can hear me. Do not say anything other than 'I can hear you.'"},
156-
]
157-
)
158-
159-
print(f"model: {model}")
160-
print(f"welcome message: {response.choices[0].message.content}")
161-
162-
if "I can hear you." in response.choices[0].message.content:
163-
results.append({
164-
"id": f"default-{model}",
165-
"endpoint": "default",
166-
"key": "",
167-
"model": model
168-
})
169-
except Exception as e:
170-
print(f"Error: {e}")
171-
error_message = str(e)
172-
173-
138+
for provider in providers:
139+
# Skip if provider is not enabled
140+
if not os.getenv(f"{provider.upper()}_ENABLED", "").lower() == "true":
141+
continue
142+
143+
api_key = os.getenv(f"{provider.upper()}_API_KEY", "")
144+
api_base = os.getenv(f"{provider.upper()}_API_BASE", "")
145+
api_version = os.getenv(f"{provider.upper()}_API_VERSION", "")
146+
models = os.getenv(f"{provider.upper()}_MODELS", "")
147+
148+
if not (api_key or api_base):
149+
continue
150+
151+
if not models:
152+
continue
153+
154+
# Build config for each model
155+
for model in models.split(","):
156+
model = model.strip()
157+
if not model:
158+
continue
159+
160+
model_config = {
161+
"id": f"{provider}-{model}-{api_key}-{api_base}-{api_version}",
162+
"endpoint": provider,
163+
"model": model,
164+
"api_key": api_key,
165+
"api_base": api_base,
166+
"api_version": api_version
167+
}
168+
169+
try:
170+
client = get_client(model_config)
171+
response = client.get_completion(
172+
messages=[
173+
{"role": "system", "content": "You are a helpful assistant."},
174+
{"role": "user", "content": "Respond 'I can hear you.' if you can hear me."},
175+
]
176+
)
177+
178+
if "I can hear you." in response.choices[0].message.content:
179+
results.append(model_config)
180+
except Exception as e:
181+
print(f"Error testing {provider} model {model}: {e}")
182+
174183
return json.dumps(results)
175184

176185
@app.route('/test-model', methods=['GET', 'POST'])

src/app/dfSlice.tsx

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ export const fetchCodeExpl = createAsyncThunk(
203203
export const fetchAvailableModels = createAsyncThunk(
204204
"dataFormulatorSlice/fetchAvailableModels",
205205
async () => {
206-
console.log(">>> call agent to infer semantic types <<<")
206+
console.log(">>> call agent to fetch available models <<<")
207207
let message = {
208208
method: 'POST',
209209
headers: { 'Content-Type': 'application/json', },
@@ -256,7 +256,7 @@ export const dataFormulatorSlice = createSlice({
256256

257257
let savedState = action.payload;
258258

259-
state.models = savedState.models.filter((m: any) => m.endpoint != 'default');
259+
state.models = savedState.models;
260260
state.selectedModelId = state.models.length > 0 ? state.models[0].id : undefined;
261261
state.testedModels = []; // models should be tested again
262262

@@ -651,23 +651,25 @@ export const dataFormulatorSlice = createSlice({
651651
})
652652
.addCase(fetchAvailableModels.fulfilled, (state, action) => {
653653
let defaultModels = action.payload;
654-
state.models = [...defaultModels, ...state.models.filter(e => !defaultModels.map((m: any) => m.endpoint).includes(e.endpoint))];
655-
656-
console.log("defaultModels", defaultModels);
657-
console.log("state.models", state.models);
658-
console.log("state.testedModels", state.testedModels);
659654

655+
state.models = [
656+
...defaultModels,
657+
...state.models.filter(e => !defaultModels.map((m: ModelConfig) => m.endpoint).includes(e.endpoint))
658+
];
659+
660660
state.testedModels = [
661-
...defaultModels.map((m: any) => {return {id: `default-${m.model}`, status: 'ok'}}) ,
662-
...state.testedModels.filter(t => !defaultModels.map((m: any) => m.endpoint).includes(t.id))
661+
...defaultModels.map((m: ModelConfig) => {return {id: m.id, status: 'ok'}}) ,
662+
...state.testedModels.filter(t => !defaultModels.map((m: ModelConfig) => m.id).includes(t.id))
663663
]
664664

665665
if (state.selectedModelId == undefined && defaultModels.length > 0) {
666666
state.selectedModelId = defaultModels[0].id;
667667
}
668-
669-
console.log("fetched models");
670-
console.log(action.payload);
668+
669+
console.log("load model complete");
670+
console.log("state.models", state.models);
671+
console.log("state.selectedModelId", state.selectedModelId);
672+
console.log("state.testedModels", state.testedModels);
671673
})
672674
.addCase(fetchCodeExpl.fulfilled, (state, action) => {
673675
let codeExpl = action.payload;
@@ -684,7 +686,7 @@ export const dataFormulatorSlice = createSlice({
684686

685687
export const dfSelectors = {
686688
getActiveModel: (state: DataFormulatorState) : ModelConfig => {
687-
return state.models.find(m => m.id == state.selectedModelId) || {'endpoint': 'default', model: 'gpt-4o', id: 'default-gpt-4o'}
689+
return state.models.find(m => m.id == state.selectedModelId) || state.models[0];
688690
}
689691
}
690692

src/views/ModelSelectionDialog.tsx

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT License.
33

4-
import React, { useState } from 'react';
4+
import React, { useEffect, useState } from 'react';
55
import '../scss/App.scss';
66

77
import { useDispatch, useSelector } from "react-redux";
@@ -90,16 +90,16 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
9090
return testedModels.find(t => (t.id == id))?.status || 'unknown';
9191
}
9292

93-
const [newEndpoint, setNewEndpoint] = useState<string>(""); // openai, azure_openai, ollama etc
93+
const [newEndpoint, setNewEndpoint] = useState<string>(""); // openai, azure, ollama etc
9494
const [newModel, setNewModel] = useState<string>("");
9595
const [newApiKey, setNewApiKey] = useState<string | undefined>(undefined);
9696
const [newApiBase, setNewApiBase] = useState<string | undefined>(undefined);
9797
const [newApiVersion, setNewApiVersion] = useState<string | undefined>(undefined);
9898

99-
let disableApiKey = newEndpoint == "default" || newEndpoint == "" || newEndpoint == "ollama";
100-
let disableModel = newEndpoint == "default" || newEndpoint == "";
101-
let disableApiBase = newEndpoint != "azure_openai";
102-
let disableApiVersion = newEndpoint != "azure_openai";
99+
let disableApiKey = newEndpoint == "" || newEndpoint == "ollama";
100+
let disableModel = newEndpoint == "";
101+
let disableApiBase = newEndpoint != "azure";
102+
let disableApiVersion = newEndpoint != "azure";
103103

104104
let modelExists = models.some(m => m.endpoint == newEndpoint && m.model == newModel && m.api_base == newApiBase && m.api_key == newApiKey && m.api_version == newApiVersion);
105105

@@ -123,13 +123,10 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
123123
}
124124

125125
let readyToTest = false;
126-
if (newEndpoint != "default") {
127-
readyToTest = true;
128-
}
129126
if (newEndpoint == "openai") {
130127
readyToTest = newModel != "";
131128
}
132-
if (newEndpoint == "azure_openai") {
129+
if (newEndpoint == "azure") {
133130
readyToTest = newModel != "" && newApiBase != "";
134131
}
135132
if (newEndpoint == "ollama") {
@@ -153,11 +150,11 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
153150
if (newModel == "" && newValue == "openai") {
154151
setNewModel("gpt-4o");
155152
}
156-
if (!newApiVersion && newValue == "azure_openai") {
153+
if (!newApiVersion && newValue == "azure") {
157154
setNewApiVersion("2024-02-15");
158155
}
159156
}}
160-
options={['openai', 'azure_openai', 'ollama', 'gemini', 'anthropic']}
157+
options={['openai', 'azure', 'ollama', 'gemini', 'anthropic']}
161158
renderOption={(props, option) => (
162159
<Typography {...props} onClick={() => setNewEndpoint(option)} sx={{fontSize: "0.875rem"}}>
163160
{option}
@@ -203,7 +200,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
203200
disabled={disableModel}
204201
onChange={(event: any, newValue: string | null) => { setNewModel(newValue || ""); }}
205202
value={newModel}
206-
options={['gpt-35-turbo', 'gpt-4', 'gpt-4o', 'llama3.2']}
203+
options={['gpt-4o-mini', 'gpt-4', 'llama3.2']}
207204
renderOption={(props, option) => {
208205
return <Typography {...props} onClick={()=>{ setNewModel(option); }} sx={{fontSize: "small"}}>{option}</Typography>
209206
}}
@@ -241,7 +238,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
241238
value={newApiBase} onChange={(event: any) => { setNewApiBase(event.target.value); }}
242239
autoComplete='off'
243240
disabled={disableApiBase}
244-
required={newEndpoint == "azure_openai"}
241+
required={newEndpoint == "azure"}
245242
/>
246243
</TableCell>
247244
<TableCell align="right">
@@ -419,7 +416,8 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
419416
{newModelEntry}
420417
<TableRow>
421418
<TableCell colSpan={8} align="left" sx={{fontSize: "0.625rem"}}>
422-
model configuration based on LiteLLM, check out supported endpoint / model configurations <a href="https://docs.litellm.ai/docs/" target="_blank" rel="noopener noreferrer">here.</a>
419+
Model configuration based on LiteLLM, <a href="https://docs.litellm.ai/docs/" target="_blank" rel="noopener noreferrer">check out supported endpoint / models here</a>.
420+
Models with limited code generation capabilities (e.g., llama3.2) may fail frequently to derive new data.
423421
</TableCell>
424422
</TableRow>
425423
</TableBody>

0 commit comments

Comments (0)