|
| 1 | +import os |
| 2 | + |
| 3 | +import vertexai |
| 4 | +from langchain_aws import ChatBedrock |
| 5 | +from langchain_cohere import ChatCohere |
| 6 | +from langchain_google_vertexai import ChatVertexAI |
| 7 | +from langchain_mistralai import ChatMistralAI |
| 8 | +from langchain_openai import AzureChatOpenAI, ChatOpenAI |
| 9 | + |
# Provider selector, read once at import time; defaults to "openai".
LLM_TYPE = os.getenv("LLM_TYPE", "openai")
| 11 | + |
| 12 | + |
def init_openai_chat(temperature):
    """Return a streaming OpenAI chat model named by the CHAT_MODEL env var."""
    model_name = os.getenv("CHAT_MODEL")
    return ChatOpenAI(model=model_name, streaming=True, temperature=temperature)
| 15 | + |
| 16 | + |
def init_vertex_chat(temperature):
    """Initialise the Vertex AI SDK, then return a streaming chat model.

    Reads VERTEX_PROJECT_ID and VERTEX_REGION (default "us-central1") from
    the environment for the SDK-level `vertexai.init` call.
    """
    project = os.getenv("VERTEX_PROJECT_ID")
    region = os.getenv("VERTEX_REGION", "us-central1")
    # SDK must be initialised before constructing the model.
    vertexai.init(project=project, location=region)
    return ChatVertexAI(streaming=True, temperature=temperature)
| 22 | + |
| 23 | + |
def init_azure_chat(temperature):
    """Return a streaming Azure OpenAI chat model for the CHAT_DEPLOYMENT env var."""
    deployment = os.getenv("CHAT_DEPLOYMENT")
    return AzureChatOpenAI(model=deployment, streaming=True, temperature=temperature)
| 26 | + |
| 27 | + |
def init_bedrock(temperature):
    """Return a streaming Bedrock chat model for the CHAT_MODEL env var.

    Bedrock takes temperature through `model_kwargs` rather than a
    top-level constructor argument.
    """
    model_kwargs = {"temperature": temperature}
    model_id = os.getenv("CHAT_MODEL")
    return ChatBedrock(model_id=model_id, streaming=True, model_kwargs=model_kwargs)
| 30 | + |
| 31 | + |
def init_mistral_chat(temperature):
    """Return a Mistral chat model configured from environment variables.

    Reads:
        MISTRAL_API_KEY: API key forwarded to the client.
        MISTRAL_API_ENDPOINT: Optional custom endpoint; only passed when
            non-empty.
        MISTRAL_MODEL: Model name; defaults to "Mistral-large".

    Args:
        temperature: Sampling temperature forwarded to ``ChatMistralAI``.
    """
    endpoint = os.getenv("MISTRAL_API_ENDPOINT")
    # Use `or` instead of a getenv default so an env var set to the empty
    # string also falls back to "Mistral-large"; previously that case
    # silently omitted the model kwarg. This also makes the old
    # always-true `if MISTRAL_MODEL:` guard unnecessary.
    model = os.getenv("MISTRAL_MODEL") or "Mistral-large"
    kwargs = {
        "mistral_api_key": os.getenv("MISTRAL_API_KEY"),
        "temperature": temperature,
        "model": model,
    }
    if endpoint:
        kwargs["endpoint"] = endpoint
    return ChatMistralAI(**kwargs)
| 45 | + |
| 46 | + |
def init_cohere_chat(temperature):
    """Return a Cohere chat model configured from COHERE_API_KEY / COHERE_MODEL.

    NOTE(review): unlike the other providers, no streaming flag is set here —
    presumably intentional, but worth confirming.
    """
    api_key = os.getenv("COHERE_API_KEY")
    model = os.getenv("COHERE_MODEL")
    return ChatCohere(cohere_api_key=api_key, model=model, temperature=temperature)
| 51 | + |
| 52 | + |
# Registry mapping an LLM_TYPE value to its chat-model factory. Each factory
# takes a single `temperature` argument and returns a configured LangChain
# chat model for that provider.
MAP_LLM_TYPE_TO_CHAT_MODEL = {
    "azure": init_azure_chat,
    "bedrock": init_bedrock,
    "openai": init_openai_chat,
    "vertex": init_vertex_chat,
    "mistral": init_mistral_chat,
    "cohere": init_cohere_chat,
}
| 61 | + |
| 62 | + |
def get_llm(temperature=0):
    """Return a chat model for the provider selected by the LLM_TYPE env var.

    Args:
        temperature: Sampling temperature forwarded to the provider factory.

    Returns:
        A configured LangChain chat model instance.

    Raises:
        ValueError: If LLM_TYPE is not a key of MAP_LLM_TYPE_TO_CHAT_MODEL.
            (ValueError subclasses Exception, so existing `except Exception`
            callers are unaffected.)
    """
    try:
        factory = MAP_LLM_TYPE_TO_CHAT_MODEL[LLM_TYPE]
    except KeyError:
        supported = ", ".join(MAP_LLM_TYPE_TO_CHAT_MODEL)
        # Message text kept identical to the original for log compatibility.
        raise ValueError(
            f"LLM type not found. Please set LLM_TYPE to one of: {supported}."
        ) from None
    return factory(temperature=temperature)