11import logging
22
3- from typing import TYPE_CHECKING
3+ from typing import TYPE_CHECKING , Optional
44
55from langchain .chat_models import init_chat_model
66from langchain .output_parsers import (
@@ -30,7 +30,7 @@ class ChatModelLLM:
3030
3131 def __init__ (
3232 self ,
33- model_name : str ,
33+ model_name : Optional [ str ] = None ,
3434 response_schemas : list [ResponseSchema ] = None ,
3535 output_parser = StructuredOutputParser ,
3636 prompt_template = ChatPromptTemplate ,
@@ -39,9 +39,14 @@ def __init__(
3939 * args ,
4040 ** kwargs ,
4141 ):
42- self .model : BaseChatModel = init_chat_model (
43- model_name , max_retries = self .MAX_RETRIES , rate_limiter = self ._get_rate_limiter (), ** config
44- ) # llm model
42+ if settings .CUSTOM_LLM_INSTANCE :
43+ self .model : "BaseChatModel" = settings .CUSTOM_LLM_INSTANCE
44+ elif model_name :
45+ self .model : "BaseChatModel" = init_chat_model (
46+ model_name , max_retries = self .MAX_RETRIES , rate_limiter = self ._get_rate_limiter (), ** config
47+ )
48+ else :
49+ raise ValueError ("Either 'settings.CUSTOM_LLM_INSTANCE' must be set or 'LLM_PROVIDER' must be provided." )
4550
4651 self .parser : StructuredOutputParser = output_parser # the output parser
4752
@@ -135,6 +140,8 @@ def invoke(self, *args, **kwargs):
135140
136141 @classmethod
137142 def get_llm (cls , model_name : str , llm_config : dict = {}):
143+ if settings .CUSTOM_LLM_INSTANCE :
144+ return settings .CUSTOM_LLM_INSTANCE
138145 return init_chat_model (
139146 model_name , max_retries = cls .MAX_RETRIES , rate_limiter = cls ._get_rate_limiter (), ** llm_config
140147 )
0 commit comments