Skip to content

Commit fa7e168

Browse files
authored
Merge pull request #143 from pad918/main
Features/mistralai integrate
2 parents 80d2d1a + c0c2545 commit fa7e168

File tree

3 files changed

+28
-3
lines changed

3 files changed

+28
-3
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ browser-use==0.1.29
22
pyperclip==1.9.0
33
gradio==5.10.0
44
json-repair
5+
langchain-mistralai==0.2.4

src/utils/utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Dict, Optional
66

77
from langchain_anthropic import ChatAnthropic
8+
from langchain_mistralai import ChatMistralAI
89
from langchain_google_genai import ChatGoogleGenerativeAI
910
from langchain_ollama import ChatOllama
1011
from langchain_openai import AzureChatOpenAI, ChatOpenAI
@@ -46,6 +47,22 @@ def get_llm_model(provider: str, **kwargs):
4647
base_url=base_url,
4748
api_key=api_key,
4849
)
50+
elif provider == 'mistral':
51+
if not kwargs.get("base_url", ""):
52+
base_url = os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1")
53+
else:
54+
base_url = kwargs.get("base_url")
55+
if not kwargs.get("api_key", ""):
56+
api_key = os.getenv("MISTRAL_API_KEY", "")
57+
else:
58+
api_key = kwargs.get("api_key")
59+
60+
return ChatMistralAI(
61+
model=kwargs.get("model_name", "mistral-large-latest"),
62+
temperature=kwargs.get("temperature", 0.0),
63+
base_url=base_url,
64+
api_key=api_key,
65+
)
4966
elif provider == "openai":
5067
if not kwargs.get("base_url", ""):
5168
base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
@@ -127,7 +144,8 @@ def get_llm_model(provider: str, **kwargs):
127144
"deepseek": ["deepseek-chat", "deepseek-reasoner"],
128145
"gemini": ["gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-1219" ],
129146
"ollama": ["qwen2.5:7b", "llama2:7b", "deepseek-r1:14b", "deepseek-r1:32b"],
130-
"azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"]
147+
"azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
148+
"mistral": ["pixtral-large-latest", "mistral-large-latest", "mistral-small-latest", "ministral-8b-latest"]
131149
}
132150

133151
# Callback to update the model name dropdown based on the selected provider

tests/test_llm_api.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def get_env_value(key, provider):
3838
"openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"},
3939
"azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"},
4040
"gemini": {"api_key": "GOOGLE_API_KEY"},
41-
"deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"}
41+
"deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"},
42+
"mistral": {"api_key": "MISTRAL_API_KEY", "base_url": "MISTRAL_ENDPOINT"},
4243
}
4344

4445
if provider in env_mappings and key in env_mappings[provider]:
@@ -116,11 +117,16 @@ def test_deepseek_r1_ollama_model():
116117
config = LLMConfig(provider="ollama", model_name="deepseek-r1:14b")
117118
test_llm(config, "How many 'r's are in the word 'strawberry'?")
118119

120+
def test_mistral_model():
121+
config = LLMConfig(provider="mistral", model_name="pixtral-large-latest")
122+
test_llm(config, "Describe this image", "assets/examples/test.png")
123+
119124
if __name__ == "__main__":
120125
# test_openai_model()
121126
# test_gemini_model()
122127
# test_azure_openai_model()
123-
test_deepseek_model()
128+
#test_deepseek_model()
124129
# test_ollama_model()
125130
# test_deepseek_r1_model()
126131
# test_deepseek_r1_ollama_model()
132+
test_mistral_model()

0 commit comments

Comments
 (0)