from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama

+PROVIDER_DISPLAY_NAMES = {
+    "openai": "OpenAI",
+    "azure_openai": "Azure OpenAI",
+    "anthropic": "Anthropic",
+    "deepseek": "DeepSeek",
+    "gemini": "Gemini"
+}
+
def get_llm_model(provider: str, **kwargs):
    """
    Get the LLM model.
    :param provider: provider type
    :param kwargs:
    :return:
    """
+    if provider not in ["ollama"]:
+        env_var = "GOOGLE_API_KEY" if provider == "gemini" else f"{provider.upper()}_API_KEY"
+        api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
+        if not api_key:
+            handle_api_key_error(provider, env_var)
+        kwargs["api_key"] = api_key
+
    if provider == "anthropic":
        if not kwargs.get("base_url", ""):
            base_url = "https://api.anthropic.com"
        else:
            base_url = kwargs.get("base_url")

-        if not kwargs.get("api_key", ""):
-            api_key = os.getenv("ANTHROPIC_API_KEY", "")
-        else:
-            api_key = kwargs.get("api_key")
-
        return ChatAnthropic(
            model_name=kwargs.get("model_name", "claude-3-5-sonnet-20240620"),
            temperature=kwargs.get("temperature", 0.0),
@@ -42,11 +52,6 @@ def get_llm_model(provider: str, **kwargs):
        else:
            base_url = kwargs.get("base_url")

-        if not kwargs.get("api_key", ""):
-            api_key = os.getenv("OPENAI_API_KEY", "")
-        else:
-            api_key = kwargs.get("api_key")
-
        return ChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=kwargs.get("temperature", 0.0),
@@ -59,11 +64,6 @@ def get_llm_model(provider: str, **kwargs):
        else:
            base_url = kwargs.get("base_url")

-        if not kwargs.get("api_key", ""):
-            api_key = os.getenv("DEEPSEEK_API_KEY", "")
-        else:
-            api_key = kwargs.get("api_key")
-
        if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
            return DeepSeekR1ChatOpenAI(
                model=kwargs.get("model_name", "deepseek-reasoner"),
@@ -79,10 +79,6 @@ def get_llm_model(provider: str, **kwargs):
                api_key=api_key,
            )
    elif provider == "gemini":
-        if not kwargs.get("api_key", ""):
-            api_key = os.getenv("GOOGLE_API_KEY", "")
-        else:
-            api_key = kwargs.get("api_key")
        return ChatGoogleGenerativeAI(
            model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
            temperature=kwargs.get("temperature", 0.0),
@@ -114,10 +110,6 @@ def get_llm_model(provider: str, **kwargs):
            base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
        else:
            base_url = kwargs.get("base_url")
-        if not kwargs.get("api_key", ""):
-            api_key = os.getenv("AZURE_OPENAI_API_KEY", "")
-        else:
-            api_key = kwargs.get("api_key")
        return AzureChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=kwargs.get("temperature", 0.0),
@@ -154,7 +146,17 @@ def update_model_dropdown(llm_provider, api_key=None, base_url=None):
        return gr.Dropdown(choices=model_names[llm_provider], value=model_names[llm_provider][0], interactive=True)
    else:
        return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
-
+
+def handle_api_key_error(provider: str, env_var: str):
+    """
+    Handles the missing API key error by raising a gr.Error with a clear message.
+    """
+    provider_display = PROVIDER_DISPLAY_NAMES.get(provider, provider.upper())
+    raise gr.Error(
+        f"💥 {provider_display} API key not found! 🔑 Please set the "
+        f"`{env_var}` environment variable or provide it in the UI."
+    )
+
def encode_image(img_path):
    if not img_path:
        return None
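A quick sanity check of the new behaviour (a sketch, not part of the diff): for any provider other than `ollama`, `get_llm_model` now resolves the key from the `api_key` kwarg or the provider's environment variable (`GOOGLE_API_KEY` for `gemini`, `<PROVIDER>_API_KEY` otherwise) and raises `gr.Error` via `handle_api_key_error` when neither is set. The import path in the snippet is an assumption; adjust it to whichever module contains `get_llm_model` in your checkout.

```python
# Hedged usage sketch; the import path below is an assumption, not shown in this diff.
import os
from src.utils.utils import get_llm_model

# Ensure neither the kwarg nor the environment variable supplies a key.
os.environ.pop("DEEPSEEK_API_KEY", None)

try:
    get_llm_model("deepseek", model_name="deepseek-chat")
except Exception as err:  # gr.Error is an Exception subclass
    print(err)
    # 💥 DeepSeek API key not found! 🔑 Please set the `DEEPSEEK_API_KEY`
    # environment variable or provide it in the UI.
```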