From 4b3e36dbd477333ffe50c71d40fa8bc3779c256a Mon Sep 17 00:00:00 2001
From: Vishwanath Martur <64204611+vishwamartur@users.noreply.github.com>
Date: Wed, 12 Feb 2025 08:25:57 +0530
Subject: [PATCH] Add support for third-party LLM endpoints

Related to #49

Add support for third-party LLM endpoints, such as Ollama, in the data formulator.

* **Backend Changes:**
  - Update `get_client` function in `py-src/data_formulator/agents/client_utils.py` to handle third-party LLM endpoints using LiteLLM.
  - Add import for LiteLLM in `py-src/data_formulator/agents/client_utils.py`.

* **Frontend Changes:**
  - Add UI options for third-party LLM endpoints in `src/views/ModelSelectionDialog.tsx`.
  - Update the endpoint display logic in `src/views/ModelSelectionDialog.tsx` to include Ollama.

* **Documentation:**
  - Add instructions for setting up third-party LLM endpoints in `DEVELOPMENT.md`.
---
 DEVELOPMENT.md                                 | 17 ++++++
 py-src/data_formulator/agents/client_utils.py | 54 +++++++++++--------
 src/views/ModelSelectionDialog.tsx             |  7 +--
 3 files changed, 50 insertions(+), 28 deletions(-)

diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
index 2b24bde..85d2613 100644
--- a/DEVELOPMENT.md
+++ b/DEVELOPMENT.md
@@ -84,6 +84,23 @@ How to set up your local machine.
 
     Open [http://localhost:5000](http://localhost:5000) to view it in the browser.
 
+## Third-Party LLM Endpoints
+
+To use third-party LLM endpoints, such as Ollama, follow these steps:
+
+1. **Set Environment Variables**
+   Set the following environment variables to configure the third-party LLM endpoint:
+   ```bash
+   export LLM_ENDPOINT="http://localhost:11434" # Example for Ollama
+   export LLM_MODEL="llama2" # Default model
+   export LLM_API_KEY="" # API key if required
+   ```
+
+2. **Update `client_utils.py`**
+   Ensure that the `get_client` function in `py-src/data_formulator/agents/client_utils.py` is updated to handle third-party LLM endpoints using LiteLLM.
+
+3. **Frontend Configuration**
+   Update the frontend UI in `src/views/ModelSelectionDialog.tsx` to provide options for third-party LLM endpoints.
 ## Usage
 
 See the [Usage section on the README.md page](README.md#usage).
diff --git a/py-src/data_formulator/agents/client_utils.py b/py-src/data_formulator/agents/client_utils.py
index fe854c7..c9e3114 100644
--- a/py-src/data_formulator/agents/client_utils.py
+++ b/py-src/data_formulator/agents/client_utils.py
@@ -6,30 +6,38 @@
 import sys
 
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
-
+from litellm import completion # Import LiteLLM
 
 
 def get_client(endpoint, key):
-    endpoint = os.getenv("ENDPOINT") if endpoint == "default" else endpoint
+    endpoint = os.getenv("ENDPOINT") if endpoint == "default" else endpoint
 
-    if key is None or key == "":
-        # using azure keyless access method
-        token_provider = get_bearer_token_provider(
-            DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
-        )
-        print(token_provider)
-        print(endpoint)
-        client = openai.AzureOpenAI(
-            api_version="2024-02-15-preview",
-            azure_endpoint=endpoint,
-            azure_ad_token_provider=token_provider
-        )
-    elif endpoint == 'openai':
-        client = openai.OpenAI(api_key=key)
-    else:
-        client = openai.AzureOpenAI(
-            azure_endpoint = endpoint,
-            api_key=key,
-            api_version="2024-02-15-preview"
-        )
-    return client
\ No newline at end of file
+    if key is None or key == "":
+        # using azure keyless access method
+        token_provider = get_bearer_token_provider(
+            DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
+        )
+        print(token_provider)
+        print(endpoint)
+        client = openai.AzureOpenAI(
+            api_version="2024-02-15-preview",
+            azure_endpoint=endpoint,
+            azure_ad_token_provider=token_provider
+        )
+    elif endpoint == 'openai':
+        client = openai.OpenAI(api_key=key)
+    elif "ollama" in endpoint.lower():
+        model = os.getenv("LLM_MODEL", "llama2")
+        client = completion(
+            model=f"ollama/{model}",
+            api_base=endpoint,
+            api_key=key,
+            custom_llm_provider="ollama"
+        )
+    else:
+        client = openai.AzureOpenAI(
+            azure_endpoint=endpoint,
+            api_key=key,
+            api_version="2024-02-15-preview"
+        )
+    return client
diff --git a/src/views/ModelSelectionDialog.tsx b/src/views/ModelSelectionDialog.tsx
index c878cbd..c1317ce 100644
--- a/src/views/ModelSelectionDialog.tsx
+++ b/src/views/ModelSelectionDialog.tsx
@@ -1,7 +1,3 @@
-
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
-
 import React, { useState } from 'react';
 
 import '../scss/App.scss';
@@ -132,6 +128,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
                                 >
                                     openai
                                     azure openai
+                                    ollama
 
 
 
@@ -254,7 +251,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => {
 
 
 
-                        {oaiModel.endpoint == 'openai' ? 'openai' : 'azure openai'}
+                        {oaiModel.endpoint == 'openai' ? 'openai' : (oaiModel.endpoint.includes('ollama') ? 'ollama' : 'azure openai')}
                         {oaiModel.endpoint}
 
 