From 4b3e36dbd477333ffe50c71d40fa8bc3779c256a Mon Sep 17 00:00:00 2001
From: Vishwanath Martur <64204611+vishwamartur@users.noreply.github.com>
Date: Wed, 12 Feb 2025 08:25:57 +0530
Subject: [PATCH] Add support for third-party LLM endpoints
Related to #49
Add support for third-party LLM endpoints, such as Ollama, in the data formulator.
* **Backend Changes:**
- Update `get_client` function in `py-src/data_formulator/agents/client_utils.py` to handle third-party LLM endpoints using LiteLLM.
- Add import for LiteLLM in `py-src/data_formulator/agents/client_utils.py`.
* **Frontend Changes:**
- Add UI options for third-party LLM endpoints in `src/views/ModelSelectionDialog.tsx`.
- Update the endpoint display logic in `src/views/ModelSelectionDialog.tsx` to include Ollama.
* **Documentation:**
- Add instructions for setting up third-party LLM endpoints in `DEVELOPMENT.md`.
---
DEVELOPMENT.md | 17 ++++++
py-src/data_formulator/agents/client_utils.py | 54 +++++++++++--------
src/views/ModelSelectionDialog.tsx | 7 +--
3 files changed, 50 insertions(+), 28 deletions(-)
diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
index 2b24bde..85d2613 100644
--- a/DEVELOPMENT.md
+++ b/DEVELOPMENT.md
@@ -84,6 +84,23 @@ How to set up your local machine.
Open [http://localhost:5000](http://localhost:5000) to view it in the browser.
+## Third-Party LLM Endpoints
+
+To use third-party LLM endpoints, such as Ollama, follow these steps:
+
+1. **Set Environment Variables**
+ Set the following environment variables to configure the third-party LLM endpoint:
+ ```bash
+ export LLM_ENDPOINT="http://localhost:11434" # Example for Ollama
+ export LLM_MODEL="llama2" # Default model
+ export LLM_API_KEY="" # API key if required
+ ```
+
+2. **Update `client_utils.py`**
+ Ensure that the `get_client` function in `py-src/data_formulator/agents/client_utils.py` is updated to handle third-party LLM endpoints using LiteLLM.
+
+3. **Frontend Configuration**
+ Update the frontend UI in `src/views/ModelSelectionDialog.tsx` to provide options for third-party LLM endpoints.
## Usage
See the [Usage section on the README.md page](README.md#usage).
diff --git a/py-src/data_formulator/agents/client_utils.py b/py-src/data_formulator/agents/client_utils.py
index fe854c7..c9e3114 100644
--- a/py-src/data_formulator/agents/client_utils.py
+++ b/py-src/data_formulator/agents/client_utils.py
@@ -6,30 +6,38 @@
import sys
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
-
+from litellm import completion # Import LiteLLM
def get_client(endpoint, key):
- endpoint = os.getenv("ENDPOINT") if endpoint == "default" else endpoint
+ endpoint = os.getenv("ENDPOINT") if endpoint == "default" else endpoint
- if key is None or key == "":
- # using azure keyless access method
- token_provider = get_bearer_token_provider(
- DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
- )
- print(token_provider)
- print(endpoint)
- client = openai.AzureOpenAI(
- api_version="2024-02-15-preview",
- azure_endpoint=endpoint,
- azure_ad_token_provider=token_provider
- )
- elif endpoint == 'openai':
- client = openai.OpenAI(api_key=key)
- else:
- client = openai.AzureOpenAI(
- azure_endpoint = endpoint,
- api_key=key,
- api_version="2024-02-15-preview"
- )
- return client
\ No newline at end of file
+ if key is None or key == "":
+ # using azure keyless access method
+ token_provider = get_bearer_token_provider(
+ DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
+ )
+ print(token_provider)
+ print(endpoint)
+ client = openai.AzureOpenAI(
+ api_version="2024-02-15-preview",
+ azure_endpoint=endpoint,
+ azure_ad_token_provider=token_provider
+ )
+ elif endpoint == 'openai':
+ client = openai.OpenAI(api_key=key)
+ elif "ollama" in endpoint.lower():
+ model = os.getenv("LLM_MODEL", "llama2")
+ client = completion(
+ model=f"ollama/{model}",
+ api_base=endpoint,
+ api_key=key,
+ custom_llm_provider="ollama"
+ )
+ else:
+ client = openai.AzureOpenAI(
+ azure_endpoint=endpoint,
+ api_key=key,
+ api_version="2024-02-15-preview"
+ )
+ return client
diff --git a/src/views/ModelSelectionDialog.tsx b/src/views/ModelSelectionDialog.tsx
index c878cbd..c1317ce 100644
--- a/src/views/ModelSelectionDialog.tsx
+++ b/src/views/ModelSelectionDialog.tsx
@@ -1,7 +1,3 @@
-
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
-
import React, { useState } from 'react';
import '../scss/App.scss';
@@ -132,6 +128,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
>
+
@@ -254,7 +251,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
- {oaiModel.endpoint == 'openai' ? 'openai' : 'azure openai'}
+ {oaiModel.endpoint == 'openai' ? 'openai' : (oaiModel.endpoint.includes('ollama') ? 'ollama' : 'azure openai')}
{oaiModel.endpoint}