@@ -24,11 +24,44 @@
 from requests.auth import HTTPBasicAuth
 
 from hugegraph_llm.config import huge_settings, llm_settings
+from hugegraph_llm.models.embeddings.litellm import LiteLLMEmbedding
+from hugegraph_llm.models.llms.litellm import LiteLLMClient
 from hugegraph_llm.utils.log import log
 
 current_llm = "chat"
 
 
+def test_litellm_embedding(api_key, api_base, model_name) -> int:
+    llm_client = LiteLLMEmbedding(
+        api_key=api_key,
+        api_base=api_base,
+        model_name=model_name,
+    )
+    try:
+        response = llm_client.get_text_embedding("test")
+        assert len(response) > 0
+    except Exception as e:
+        raise gr.Error(f"Error in litellm embedding call: {e}") from e
+    gr.Info("Test connection successful~")
+    return 200
+
+
+def test_litellm_chat(api_key, api_base, model_name, max_tokens: int) -> int:
+    try:
+        llm_client = LiteLLMClient(
+            api_key=api_key,
+            api_base=api_base,
+            model_name=model_name,
+            max_tokens=max_tokens,
+        )
+        response = llm_client.generate(messages=[{"role": "user", "content": "hi"}])
+        assert len(response) > 0
+    except Exception as e:
+        raise gr.Error(f"Error in litellm chat call: {e}") from e
+    gr.Info("Test connection successful~")
+    return 200
+
+
 def test_api_connection(url, method="GET", headers=None, params=None, body=None, auth=None, origin_call=None) -> int:
     # TODO: use fastapi.request / starlette instead?
     log.debug("Request URL: %s", url)
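Both new helpers follow the same probe pattern as `test_api_connection`: make one cheap real call, surface failures as `gr.Error`, and return 200 on success. A minimal sketch of exercising them outside the Gradio UI follows; the key and model names are placeholders, and only the constructor keywords visible in this diff are assumed.

```python
# Hedged smoke test for the two LiteLLM wrappers added above.
# Assumptions: placeholder api_key/model names; the constructors take exactly
# the keywords shown in the diff (api_key, api_base, model_name, max_tokens).
from hugegraph_llm.models.embeddings.litellm import LiteLLMEmbedding
from hugegraph_llm.models.llms.litellm import LiteLLMClient

chat = LiteLLMClient(
    api_key="sk-...",                 # placeholder provider key
    api_base=None,                    # blank/None -> provider default endpoint
    model_name="openai/gpt-4o-mini",  # hypothetical; see docs.litellm.ai/docs/providers
    max_tokens=256,
)
print(chat.generate(messages=[{"role": "user", "content": "hi"}]))

emb = LiteLLMEmbedding(
    api_key="sk-...",
    api_base=None,
    model_name="openai/text-embedding-3-small",  # hypothetical embedding model
)
print(len(emb.get_text_embedding("test")))       # non-empty vector on success
```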
@@ -97,6 +130,11 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int:
         llm_settings.ollama_embedding_port = int(arg2)
         llm_settings.ollama_embedding_model = arg3
         status_code = test_api_connection(f"http://{arg1}:{arg2}", origin_call=origin_call)
+    elif embedding_option == "litellm":
+        llm_settings.litellm_embedding_api_key = arg1
+        llm_settings.litellm_embedding_api_base = arg2
+        llm_settings.litellm_embedding_model = arg3
+        status_code = test_litellm_embedding(arg1, arg2, arg3)
     llm_settings.update_env()
     gr.Info("Configured!")
     return status_code
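Note the asymmetry: the ollama branch only pings the HTTP endpoint via `test_api_connection`, while the litellm branch issues a real embedding request, so a wrong model name fails at configuration time instead of first use. A hypothetical direct call mirroring what the Apply button does (assuming the dropdown handler has already set the embedding type to `"litellm"`, which is where `embedding_option` comes from):

```python
# Hypothetical invocation; argument order follows the assignments above.
status = apply_embedding_config(
    "sk-...",                         # arg1 -> litellm_embedding_api_key (placeholder)
    "",                               # arg2 -> litellm_embedding_api_base (blank = default)
    "openai/text-embedding-3-small",  # arg3 -> litellm_embedding_model (hypothetical)
)
assert status == 200
```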
@@ -173,7 +211,6 @@ def apply_llm_config(current_llm_config, arg1, arg2, arg3, arg4, origin_call=None) -> int:
         setattr(llm_settings, f"openai_{current_llm_config}_tokens", int(arg4))
 
         test_url = getattr(llm_settings, f"openai_{current_llm_config}_api_base") + "/chat/completions"
-        log.debug("Type of OpenAI %s max_token is %s", current_llm_config, type(arg4))
         data = {
             "model": arg3,
             "temperature": 0.0,
@@ -192,6 +229,14 @@ def apply_llm_config(current_llm_config, arg1, arg2, arg3, arg4, origin_call=None) -> int:
         setattr(llm_settings, f"ollama_{current_llm_config}_language_model", arg3)
         status_code = test_api_connection(f"http://{arg1}:{arg2}", origin_call=origin_call)
 
+    elif llm_option == "litellm":
+        setattr(llm_settings, f"litellm_{current_llm_config}_api_key", arg1)
+        setattr(llm_settings, f"litellm_{current_llm_config}_api_base", arg2)
+        setattr(llm_settings, f"litellm_{current_llm_config}_language_model", arg3)
+        setattr(llm_settings, f"litellm_{current_llm_config}_tokens", int(arg4))
+
+        status_code = test_litellm_chat(arg1, arg2, arg3, int(arg4))
+
     gr.Info("Configured!")
     llm_settings.update_env()
     return status_code
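One branch covers all three LLM roles because the settings attributes differ only in the role segment of their names; `current_llm_config` is interpolated into the attribute name. A quick illustration of the naming scheme:

```python
# The f-string attribute names resolved per role (illustrative):
for role in ("chat", "extract", "text2gql"):
    print(f"litellm_{role}_api_key",
          f"litellm_{role}_api_base",
          f"litellm_{role}_language_model",
          f"litellm_{role}_tokens")
```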
@@ -218,7 +263,7 @@ def create_configs_block() -> list:
     with gr.Accordion("2. Set up the LLM.", open=False):
         gr.Markdown("> Tips: the openai option also supports OpenAI-compatible APIs from other providers.")
         with gr.Tab(label='chat'):
-            chat_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
+            chat_llm_dropdown = gr.Dropdown(choices=["openai", "litellm", "qianfan_wenxin", "ollama/local"],
                                             value=getattr(llm_settings, "chat_llm_type"), label="type")
             apply_llm_config_with_chat_op = partial(apply_llm_config, "chat")
 
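`partial` pins the `current_llm_config` argument, so the single `apply_llm_config` backs every tab while each click handler passes only the four textbox values:

```python
from functools import partial

apply_llm_config_with_chat_op = partial(apply_llm_config, "chat")
# Behaves like: lambda a1, a2, a3, a4: apply_llm_config("chat", a1, a2, a3, a4)
```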
@@ -249,13 +294,23 @@ def chat_llm_settings(llm_type):
                     gr.Textbox(value=getattr(llm_settings, "qianfan_chat_language_model"), label="model_name"),
                     gr.Textbox(value="", visible=False),
                 ]
+            elif llm_type == "litellm":
+                llm_config_input = [
+                    gr.Textbox(value=getattr(llm_settings, "litellm_chat_api_key"), label="api_key",
+                               type="password"),
+                    gr.Textbox(value=getattr(llm_settings, "litellm_chat_api_base"), label="api_base",
+                               info="If you want to use the default api_base, please keep it blank"),
+                    gr.Textbox(value=getattr(llm_settings, "litellm_chat_language_model"), label="model_name",
+                               info="Please refer to https://docs.litellm.ai/docs/providers"),
+                    gr.Textbox(value=getattr(llm_settings, "litellm_chat_tokens"), label="max_token"),
+                ]
             else:
                 llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)]
             llm_config_button = gr.Button("Apply configuration")
             llm_config_button.click(apply_llm_config_with_chat_op, inputs=llm_config_input)
 
         with gr.Tab(label='mini_tasks'):
-            extract_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
+            extract_llm_dropdown = gr.Dropdown(choices=["openai", "litellm", "qianfan_wenxin", "ollama/local"],
                                                value=getattr(llm_settings, "extract_llm_type"), label="type")
             apply_llm_config_with_extract_op = partial(apply_llm_config, "extract")
 
@@ -286,12 +341,22 @@ def extract_llm_settings(llm_type):
                     gr.Textbox(value=getattr(llm_settings, "qianfan_extract_language_model"), label="model_name"),
                     gr.Textbox(value="", visible=False),
                 ]
+            elif llm_type == "litellm":
+                llm_config_input = [
+                    gr.Textbox(value=getattr(llm_settings, "litellm_extract_api_key"), label="api_key",
+                               type="password"),
+                    gr.Textbox(value=getattr(llm_settings, "litellm_extract_api_base"), label="api_base",
+                               info="If you want to use the default api_base, please keep it blank"),
+                    gr.Textbox(value=getattr(llm_settings, "litellm_extract_language_model"), label="model_name",
+                               info="Please refer to https://docs.litellm.ai/docs/providers"),
+                    gr.Textbox(value=getattr(llm_settings, "litellm_extract_tokens"), label="max_token"),
+                ]
             else:
                 llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)]
             llm_config_button = gr.Button("Apply configuration")
             llm_config_button.click(apply_llm_config_with_extract_op, inputs=llm_config_input)
         with gr.Tab(label='text2gql'):
-            text2gql_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
+            text2gql_llm_dropdown = gr.Dropdown(choices=["openai", "litellm", "qianfan_wenxin", "ollama/local"],
                                                 value=getattr(llm_settings, "text2gql_llm_type"), label="type")
             apply_llm_config_with_text2gql_op = partial(apply_llm_config, "text2gql")
 
@@ -322,14 +387,25 @@ def text2gql_llm_settings(llm_type):
                     gr.Textbox(value=getattr(llm_settings, "qianfan_text2gql_language_model"), label="model_name"),
                     gr.Textbox(value="", visible=False),
                 ]
+            elif llm_type == "litellm":
+                llm_config_input = [
+                    gr.Textbox(value=getattr(llm_settings, "litellm_text2gql_api_key"), label="api_key",
+                               type="password"),
+                    gr.Textbox(value=getattr(llm_settings, "litellm_text2gql_api_base"), label="api_base",
+                               info="If you want to use the default api_base, please keep it blank"),
+                    gr.Textbox(value=getattr(llm_settings, "litellm_text2gql_language_model"), label="model_name",
+                               info="Please refer to https://docs.litellm.ai/docs/providers"),
+                    gr.Textbox(value=getattr(llm_settings, "litellm_text2gql_tokens"), label="max_token"),
+                ]
             else:
                 llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)]
             llm_config_button = gr.Button("Apply configuration")
             llm_config_button.click(apply_llm_config_with_text2gql_op, inputs=llm_config_input)
 
     with gr.Accordion("3. Set up the Embedding.", open=False):
         embedding_dropdown = gr.Dropdown(
-            choices=["openai", "qianfan_wenxin", "ollama/local"], value=llm_settings.embedding_type, label="Embedding"
+            choices=["openai", "litellm", "qianfan_wenxin", "ollama/local"], value=llm_settings.embedding_type,
+            label="Embedding"
         )
 
         @gr.render(inputs=[embedding_dropdown])
@@ -357,6 +433,16 @@ def embedding_settings(embedding_type):
                                    type="password"),
                         gr.Textbox(value=llm_settings.qianfan_embedding_model, label="model_name"),
                     ]
+            elif embedding_type == "litellm":
+                with gr.Row():
+                    embedding_config_input = [
+                        gr.Textbox(value=getattr(llm_settings, "litellm_embedding_api_key"), label="api_key",
+                                   type="password"),
+                        gr.Textbox(value=getattr(llm_settings, "litellm_embedding_api_base"), label="api_base",
+                                   info="If you want to use the default api_base, please keep it blank"),
+                        gr.Textbox(value=getattr(llm_settings, "litellm_embedding_model"), label="model_name",
+                                   info="Please refer to https://docs.litellm.ai/docs/embedding/supported_embedding"),
+                    ]
             else:
                 embedding_config_input = [
                     gr.Textbox(value="", visible=False),
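All the per-provider forms above rely on the same Gradio 4 mechanism: `@gr.render` re-runs its function whenever the watched dropdown changes and rebuilds the components created inside it. A minimal self-contained sketch of that pattern (generic labels, not taken from this file):

```python
import gradio as gr

with gr.Blocks() as demo:
    provider = gr.Dropdown(choices=["openai", "litellm"], value="openai", label="type")

    @gr.render(inputs=[provider])
    def settings(choice):
        # Re-executed on every dropdown change; components created here
        # replace the previous render's output.
        gr.Textbox(value="", label=f"api_key for {choice}", type="password")

demo.launch()
```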