Add support for third-party LLM endpoints #78

Open · wants to merge 1 commit into `main`
17 changes: 17 additions & 0 deletions DEVELOPMENT.md
@@ -84,6 +84,23 @@ How to set up your local machine.

Open [http://localhost:5000](http://localhost:5000) to view it in the browser.

## Third-Party LLM Endpoints

To use third-party LLM endpoints, such as Ollama, follow these steps:

1. **Set Environment Variables**
Set the following environment variables to configure the third-party LLM endpoint:
```bash
export LLM_ENDPOINT="http://localhost:11434" # Example for Ollama
export LLM_MODEL="llama2" # Default model
export LLM_API_KEY="" # API key if required
```

2. **Update `client_utils.py`**
   Update the `get_client` function in `py-src/data_formulator/agents/client_utils.py` so that it routes requests to third-party LLM endpoints through LiteLLM.

3. **Frontend Configuration**
Update the frontend UI in `src/views/ModelSelectionDialog.tsx` to provide options for third-party LLM endpoints.
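
   Taken together, the backend side of these steps can be sketched as follows. This is a sketch only: `resolve_llm_config` is a hypothetical helper name, not part of this PR; the defaults mirror the environment variables shown in step 1.

   ```python
   import os

   # Hypothetical helper (not part of this PR) showing how the backend
   # might consume the three environment variables from step 1.
   def resolve_llm_config():
       return {
           "endpoint": os.getenv("LLM_ENDPOINT", "http://localhost:11434"),
           "model": os.getenv("LLM_MODEL", "llama2"),
           "api_key": os.getenv("LLM_API_KEY", ""),  # empty string means "no key required"
       }
   ```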

## Usage
See the [Usage section on the README.md page](README.md#usage).
54 changes: 31 additions & 23 deletions py-src/data_formulator/agents/client_utils.py
```diff
@@ -6,30 +6,38 @@
 import sys
 
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+
+from litellm import completion  # Import LiteLLM
 
 def get_client(endpoint, key):
 
     endpoint = os.getenv("ENDPOINT") if endpoint == "default" else endpoint
 
-    if key is None or key == "":
-        # using azure keyless access method
-        token_provider = get_bearer_token_provider(
-            DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
-        )
-        print(token_provider)
-        print(endpoint)
-        client = openai.AzureOpenAI(
-            api_version="2024-02-15-preview",
-            azure_endpoint=endpoint,
-            azure_ad_token_provider=token_provider
-        )
-    elif endpoint == 'openai':
-        client = openai.OpenAI(api_key=key)
-    else:
-        client = openai.AzureOpenAI(
-            azure_endpoint = endpoint,
-            api_key=key,
-            api_version="2024-02-15-preview"
-        )
-    return client
+    if key is None or key == "":
+        # using azure keyless access method
+        token_provider = get_bearer_token_provider(
+            DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
+        )
+        print(token_provider)
+        print(endpoint)
+        client = openai.AzureOpenAI(
+            api_version="2024-02-15-preview",
+            azure_endpoint=endpoint,
+            azure_ad_token_provider=token_provider
+        )
+    elif endpoint == 'openai':
+        client = openai.OpenAI(api_key=key)
+    elif "ollama" in endpoint.lower():
+        model = os.getenv("LLM_MODEL", "llama2")
+        client = completion(
+            model=f"ollama/{model}",
+            api_base=endpoint,
+            api_key=key,
+            custom_llm_provider="ollama"
+        )
+    else:
+        client = openai.AzureOpenAI(
+            azure_endpoint=endpoint,
+            api_key=key,
+            api_version="2024-02-15-preview"
+        )
+    return client
```
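
One caveat when reading the `client_utils.py` change above: `litellm.completion()` issues a request immediately rather than returning a client object, so callers of `get_client` that expect an OpenAI-style client would need a thin adapter that defers the call. A sketch of what that could look like — `OllamaClient` is a hypothetical name, not part of this PR, and it assumes `litellm` is installed:

```python
from dataclasses import dataclass

@dataclass
class OllamaClient:
    """Hypothetical adapter (not part of this PR) that defers the
    LiteLLM call until a completion is actually requested."""
    api_base: str = "http://localhost:11434"
    model: str = "llama2"

    def complete(self, messages):
        # Imported lazily so constructing the adapter does not require litellm.
        from litellm import completion
        return completion(
            model=f"ollama/{self.model}",
            messages=messages,
            api_base=self.api_base,
        )
```

With such an adapter, `get_client` could keep returning an object in every branch instead of a completed response in the Ollama case.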
7 changes: 2 additions & 5 deletions src/views/ModelSelectionDialog.tsx
```diff
@@ -1,7 +1,3 @@
-
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
-
 import React, { useState } from 'react';
 import '../scss/App.scss';
```

Comment on lines -1 to -4: Why are you removing the copyright notice?

```diff
@@ -132,6 +128,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
                 >
                     <MenuItem sx={{fontSize: '0.875rem' }} value="openai">openai</MenuItem>
                     <MenuItem sx={{fontSize: '0.875rem' }} value="azureopenai">azure openai</MenuItem>
+                    <MenuItem sx={{fontSize: '0.875rem' }} value="ollama">ollama</MenuItem>
                 </Select>
             </FormControl>
         </TableCell>
@@ -254,7 +251,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
             <Radio checked={isItemSelected} name="radio-buttons" inputProps={{'aria-label': 'Select this model'}} />
         </TableCell>
         <TableCell align="left" sx={{ borderBottom: noBorderStyle }}>
-            {oaiModel.endpoint == 'openai' ? 'openai' : 'azure openai'}
+            {oaiModel.endpoint == 'openai' ? 'openai' : (oaiModel.endpoint.includes('ollama') ? 'ollama' : 'azure openai')}
         </TableCell>
         <TableCell component="th" scope="row" sx={{ borderBottom: borderStyle }}>
             {oaiModel.endpoint}
```