From 3475bf7cb64b1a129b8b983e011602bfa5e4ed90 Mon Sep 17 00:00:00 2001
From: "sweep-ai[bot]" <128439645+sweep-ai[bot]@users.noreply.github.com>
Date: Sun, 22 Dec 2024 23:09:29 +0000
Subject: [PATCH] Add Support for Local Llama Models via LlamaCpp Integration

---
 .env.sample              |  7 ++++++-
 gpt_all_star/core/llm.py | 28 ++++++++++++++++++++++++++++
 tests/core/test_llm.py   | 38 +++++++++++++++++++++++++++++++++++++-
 3 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/.env.sample b/.env.sample
index 5be0fd4e..11b7c3e1 100644
--- a/.env.sample
+++ b/.env.sample
@@ -1,4 +1,4 @@
-# OPENAI or AZURE or ANTHROPIC
+# OPENAI or AZURE or ANTHROPIC or LLAMA
 ENDPOINT=OPENAI
 
 # USE when ENDPOINT=OPENAI
@@ -16,6 +16,11 @@ AZURE_OPENAI_ENDPOINT=https://.openai.azure.com/
 ANTHROPIC_API_KEY=
 ANTHROPIC_API_MODEL=
 
+# USE when ENDPOINT=LLAMA
+LLAMA_MODEL_PATH=/path/to/llama/model.gguf
+LLAMA_N_CTX=2048
+LLAMA_N_GPU_LAYERS=0
+
 # LangSmith
 LANGCHAIN_TRACING_V2=false
 LANGCHAIN_ENDPOINT=https://api.smith.langchain.com
diff --git a/gpt_all_star/core/llm.py b/gpt_all_star/core/llm.py
index ac1b54be..dd96d448 100644
--- a/gpt_all_star/core/llm.py
+++ b/gpt_all_star/core/llm.py
@@ -5,12 +5,16 @@
 from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
+from langchain_community.llms import LlamaCpp
+from langchain.callbacks.manager import CallbackManager
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 
 class LLM_TYPE(str, Enum):
     OPENAI = "OPENAI"
     AZURE = "AZURE"
     ANTHROPIC = "ANTHROPIC"
+    LLAMA = "LLAMA"
 
 
 def create_llm(llm_name: LLM_TYPE) -> BaseChatModel:
@@ -37,6 +41,13 @@ def create_llm(llm_name: LLM_TYPE) -> BaseChatModel:
             model_name=os.getenv("ANTHROPIC_API_MODEL", "claude-3-opus-20240229"),
             temperature=0.1,
         )
+    elif llm_name == LLM_TYPE.LLAMA:
+        return _create_chat_llama(
+            model_path=os.getenv("LLAMA_MODEL_PATH"),
+            temperature=0.1,
+            n_ctx=int(os.getenv("LLAMA_N_CTX", "2048")),
+            n_gpu_layers=int(os.getenv("LLAMA_N_GPU_LAYERS", "0")),
+        )
     else:
         raise ValueError(f"Unsupported LLM type: {llm_name}")
 
@@ -83,3 +94,20 @@ def _create_chat_anthropic(
         temperature=temperature,
         streaming=True,
     )
+
+
+def _create_chat_llama(
+    model_path: str,
+    temperature: float,
+    n_ctx: int = 2048,
+    n_gpu_layers: int = 0,
+) -> LlamaCpp:
+    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+    return LlamaCpp(
+        model_path=model_path,
+        temperature=temperature,
+        n_ctx=n_ctx,
+        n_gpu_layers=n_gpu_layers,
+        callback_manager=callback_manager,
+        verbose=True,
+    )
diff --git a/tests/core/test_llm.py b/tests/core/test_llm.py
index b4b35aae..ff2337ee 100644
--- a/tests/core/test_llm.py
+++ b/tests/core/test_llm.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from gpt_all_star.core.llm import _create_chat_openai
+from gpt_all_star.core.llm import _create_chat_openai, _create_chat_llama
 
 
 @pytest.fixture
@@ -17,6 +17,18 @@ def mock_chat_openai():
         yield mock
 
 
+@pytest.fixture
+def mock_llamacpp():
+    with patch("gpt_all_star.core.llm.LlamaCpp") as mock:
+        yield mock
+
+
+@pytest.fixture
+def mock_callback_manager():
+    with patch("gpt_all_star.core.llm.CallbackManager") as mock:
+        yield mock
+
+
 def test_create_chat_openai_with_base_url(mock_openai, mock_chat_openai):
     base_url = "https://custom-openai-api.com/v1"
     _create_chat_openai(model_name="gpt-4", temperature=0.1, base_url=base_url)
@@ -40,3 +52,27 @@ def test_create_chat_openai_without_base_url(mock_openai, mock_chat_openai):
         client=mock_openai.chat.completions,
         openai_api_base=None,
     )
+
+
+def test_create_chat_llama(mock_llamacpp, mock_callback_manager):
+    model_path = "/path/to/model.gguf"
+    temperature = 0.1
+    n_ctx = 2048
+    n_gpu_layers = 0
+
+    _create_chat_llama(
+        model_path=model_path,
+        temperature=temperature,
+        n_ctx=n_ctx,
+        n_gpu_layers=n_gpu_layers,
+    )
+
+    mock_callback_manager.assert_called_once()
+    mock_llamacpp.assert_called_once_with(
+        model_path=model_path,
+        temperature=temperature,
+        n_ctx=n_ctx,
+        n_gpu_layers=n_gpu_layers,
+        callback_manager=mock_callback_manager.return_value,
+        verbose=True,
+    )