Skip to content

Commit 76828fd

Browse files
added custom llm and embeddings instance
1 parent 2950740 commit 76828fd

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

src/intugle/core/llms/chat.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Optional
44

55
from langchain.chat_models import init_chat_model
66
from langchain.output_parsers import (
@@ -30,7 +30,7 @@ class ChatModelLLM:
3030

3131
def __init__(
3232
self,
33-
model_name: str,
33+
model_name: Optional[str] = None,
3434
response_schemas: list[ResponseSchema] = None,
3535
output_parser=StructuredOutputParser,
3636
prompt_template=ChatPromptTemplate,
@@ -39,9 +39,14 @@ def __init__(
3939
*args,
4040
**kwargs,
4141
):
42-
self.model: BaseChatModel = init_chat_model(
43-
model_name, max_retries=self.MAX_RETRIES, rate_limiter=self._get_rate_limiter(), **config
44-
) # llm model
42+
if settings.CUSTOM_LLM_INSTANCE:
43+
self.model: "BaseChatModel" = settings.CUSTOM_LLM_INSTANCE
44+
elif model_name:
45+
self.model: "BaseChatModel" = init_chat_model(
46+
model_name, max_retries=self.MAX_RETRIES, rate_limiter=self._get_rate_limiter(), **config
47+
)
48+
else:
49+
raise ValueError("Either 'settings.CUSTOM_LLM_INSTANCE' must be set or 'LLM_PROVIDER' must be provided.")
4550

4651
self.parser: StructuredOutputParser = output_parser # the output parser
4752

@@ -135,6 +140,8 @@ def invoke(self, *args, **kwargs):
135140

136141
@classmethod
137142
def get_llm(cls, model_name: str, llm_config: dict = {}):
143+
if settings.CUSTOM_LLM_INSTANCE:
144+
return settings.CUSTOM_LLM_INSTANCE
138145
return init_chat_model(
139146
model_name, max_retries=cls.MAX_RETRIES, rate_limiter=cls._get_rate_limiter(), **llm_config
140147
)

src/intugle/core/llms/embeddings.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from langchain.embeddings.base import init_embeddings
1010

11+
from intugle.core import settings
12+
1113

1214
class EmbeddingsType(str, Enum):
1315
DENSE = "dense"
@@ -30,7 +32,10 @@ def __init__(
3032
embeddings_size: Optional[int] = None,
3133
):
3234
self.model_name = model_name
33-
self.model = init_embeddings(model_name, **config)
35+
if settings.CUSTOM_EMBEDDINGS_INSTANCE:
36+
self.model = settings.CUSTOM_EMBEDDINGS_INSTANCE
37+
else:
38+
self.model = init_embeddings(model_name, **config)
3439
self._embed_func: Dict[EmbeddingsType, Callable] = {
3540
EmbeddingsType.DENSE: self.dense,
3641
EmbeddingsType.SPARSE: self.sparse,

src/intugle/core/settings.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from functools import lru_cache
66
from pathlib import Path
7-
from typing import Optional
7+
from typing import Any, Optional
88

99
from dotenv import load_dotenv
1010
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -66,6 +66,8 @@ class Settings(BaseSettings):
6666
MAX_RETRIES: int = 5
6767
SLEEP_TIME: int = 25
6868
ENABLE_RATE_LIMITER: bool = False
69+
CUSTOM_LLM_INSTANCE: Optional[Any] = None
70+
CUSTOM_EMBEDDINGS_INSTANCE: Optional[Any] = None
6971

7072
# LP
7173
HALLUCINATIONS_MAX_RETRY: int = 2

0 commit comments

Comments
 (0)