|
11 | 11 | from typing import Dict, Optional
|
12 | 12 |
|
13 | 13 | from langchain_anthropic import ChatAnthropic
|
| 14 | +from langchain_mistralai import ChatMistralAI |
14 | 15 | from langchain_google_genai import ChatGoogleGenerativeAI
|
15 | 16 | from langchain_ollama import ChatOllama
|
16 | 17 | from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
@@ -40,6 +41,22 @@ def get_llm_model(provider: str, **kwargs):
|
40 | 41 | base_url=base_url,
|
41 | 42 | api_key=api_key,
|
42 | 43 | )
|
| 44 | + elif provider == 'mistral': |
| 45 | + if not kwargs.get("base_url", ""): |
| 46 | + base_url = os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1") |
| 47 | + else: |
| 48 | + base_url = kwargs.get("base_url") |
| 49 | + if not kwargs.get("api_key", ""): |
| 50 | + api_key = os.getenv("MISTRAL_API_KEY", "") |
| 51 | + else: |
| 52 | + api_key = kwargs.get("api_key") |
| 53 | + |
| 54 | + return ChatMistralAI( |
| 55 | + model=kwargs.get("model_name", "mistral-large-latest"), |
| 56 | + temperature=kwargs.get("temperature", 0.0), |
| 57 | + base_url=base_url, |
| 58 | + api_key=api_key, |
| 59 | + ) |
43 | 60 | elif provider == "openai":
|
44 | 61 | if not kwargs.get("base_url", ""):
|
45 | 62 | base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
|
@@ -117,7 +134,8 @@ def get_llm_model(provider: str, **kwargs):
|
117 | 134 | "deepseek": ["deepseek-chat"],
|
118 | 135 | "gemini": ["gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-1219" ],
|
119 | 136 | "ollama": ["qwen2.5:7b", "llama2:7b"],
|
120 |
| - "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"] |
| 137 | + "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"], |
| 138 | + "mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"] |
121 | 139 | }
|
122 | 140 |
|
123 | 141 | # Callback to update the model name dropdown based on the selected provider
|
|
0 commit comments