Skip to content

Commit 3e0bf46

Browse files
coderzcimbajin
andauthored
feat(llm): support litellm for multi-LLM provider (#178)
* Update README.md --------- Co-authored-by: imbajin <[email protected]>
1 parent e78792f commit 3e0bf46

File tree

11 files changed

+699
-149
lines changed

11 files changed

+699
-149
lines changed

hugegraph-llm/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
src/hugegraph_llm/resources/demo/questions_answers.xlsx
22
src/hugegraph_llm/resources/demo/questions.xlsx
33
src/hugegraph_llm/resources/backup-graph-data-4020/
4+
5+
uv.lock

hugegraph-llm/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ graph systems and large language models.
6767
```bash
6868
python -m hugegraph_llm.config.generate --update
6969
```
70+
Note: `Litellm` support multi-LLM provider, refer [litellm.ai](https://docs.litellm.ai/docs/providers) to config it
7071
7. (__Optional__) You could use
7172
[hugegraph-hubble](https://hugegraph.apache.org/docs/quickstart/hugegraph-hubble/#21-use-docker-convenient-for-testdev)
7273
to visit the graph data, could run it via [Docker/Docker-Compose](https://hub.docker.com/r/hugegraph/hubble)

hugegraph-llm/poetry.lock

Lines changed: 302 additions & 126 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

hugegraph-llm/pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ authors = [
2424
]
2525
readme = "README.md"
2626
license = "Apache-2.0"
27-
requires-python = "^3.10"
27+
requires-python = ">=3.10,<3.12"
2828
maintainers = [
2929
{ name = "Apache HugeGraph Contributors", email = "[email protected]" },
3030
]
@@ -38,7 +38,7 @@ documentation = "https://hugegraph.apache.org/docs/quickstart/hugegraph-ai/"
3838

3939
[tool.poetry.dependencies]
4040
python = "^3.10,<3.12"
41-
openai = "~1.47.1"
41+
openai = "~1.61.0"
4242
ollama = "~0.2.1"
4343
qianfan = "~0.3.18"
4444
retry = "~0.9.2"
@@ -61,6 +61,7 @@ setuptools = "~70.0.0"
6161
urllib3 = "~2.2.2"
6262
rich = "~13.9.4"
6363
apscheduler= "~3.10.4"
64+
litellm = "~1.61.13"
6465
hugegraph-python = { path = "../hugegraph-python-client/", develop = true }
6566

6667
[build-system]

hugegraph-llm/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
openai~=1.47.1
1+
openai~=1.61.0
22
ollama~=0.2.1
33
qianfan~=0.3.18
44
retry~=0.9.2
@@ -16,3 +16,4 @@ pandas~=2.2.2
1616
openpyxl~=3.1.5
1717
pydantic-settings~=2.6.1
1818
apscheduler~=3.10.4
19+
litellm~=1.61.13

hugegraph-llm/src/hugegraph_llm/config/llm_config.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
class LLMConfig(BaseConfig):
2626
"""LLM settings"""
2727

28-
chat_llm_type: Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"] = "openai"
29-
extract_llm_type: Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"] = "openai"
30-
text2gql_llm_type: Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"] = "openai"
31-
embedding_type: Optional[Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"]] = "openai"
28+
chat_llm_type: Literal["openai", "litellm", "ollama/local", "qianfan_wenxin"] = "openai"
29+
extract_llm_type: Literal["openai", "litellm", "ollama/local", "qianfan_wenxin"] = "openai"
30+
text2gql_llm_type: Literal["openai", "litellm", "ollama/local", "qianfan_wenxin"] = "openai"
31+
embedding_type: Optional[Literal["openai", "litellm", "ollama/local", "qianfan_wenxin"]] = "openai"
3232
reranker_type: Optional[Literal["cohere", "siliconflow"]] = None
3333
# 1. OpenAI settings
3434
openai_chat_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
@@ -84,14 +84,19 @@ class LLMConfig(BaseConfig):
8484
qianfan_embed_url: Optional[str] = qianfan_url_prefix + "/embeddings/"
8585
# refer https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu to get more details
8686
qianfan_embedding_model: Optional[str] = "embedding-v1"
87-
# TODO: To be confirmed, whether to configure
88-
# 5. ZhiPu(GLM) settings
89-
zhipu_chat_api_key: Optional[str] = None
90-
zhipu_chat_language_model: Optional[str] = "glm-4"
91-
zhipu_chat_embedding_model: Optional[str] = "embedding-2"
92-
zhipu_extract_api_key: Optional[str] = None
93-
zhipu_extract_language_model: Optional[str] = "glm-4"
94-
zhipu_extract_embedding_model: Optional[str] = "embedding-2"
95-
zhipu_text2gql_api_key: Optional[str] = None
96-
zhipu_text2gql_language_model: Optional[str] = "glm-4"
97-
zhipu_text2gql_embedding_model: Optional[str] = "embedding-2"
87+
# 5. LiteLLM settings
88+
litellm_chat_api_key: Optional[str] = None
89+
litellm_chat_api_base: Optional[str] = None
90+
litellm_chat_language_model: Optional[str] = "openai/gpt-4o"
91+
litellm_chat_tokens: int = 8192
92+
litellm_extract_api_key: Optional[str] = None
93+
litellm_extract_api_base: Optional[str] = None
94+
litellm_extract_language_model: Optional[str] = "openai/gpt-4o"
95+
litellm_extract_tokens: int = 256
96+
litellm_text2gql_api_key: Optional[str] = None
97+
litellm_text2gql_api_base: Optional[str] = None
98+
litellm_text2gql_language_model: Optional[str] = "openai/gpt-4o"
99+
litellm_text2gql_tokens: int = 4096
100+
litellm_embedding_api_key: Optional[str] = None
101+
litellm_embedding_api_base: Optional[str] = None
102+
litellm_embedding_model: Optional[str] = "openai/text-embedding-3-small"

hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,44 @@
2424
from requests.auth import HTTPBasicAuth
2525

2626
from hugegraph_llm.config import huge_settings, llm_settings
27+
from hugegraph_llm.models.embeddings.litellm import LiteLLMEmbedding
28+
from hugegraph_llm.models.llms.litellm import LiteLLMClient
2729
from hugegraph_llm.utils.log import log
2830

2931
current_llm = "chat"
3032

3133

34+
def test_litellm_embedding(api_key, api_base, model_name) -> int:
35+
llm_client = LiteLLMEmbedding(
36+
api_key = api_key,
37+
api_base = api_base,
38+
model_name = model_name,
39+
)
40+
try:
41+
response = llm_client.get_text_embedding("test")
42+
assert len(response) > 0
43+
except Exception as e:
44+
raise gr.Error(f"Error in litellm embedding call: {e}") from e
45+
gr.Info("Test connection successful~")
46+
return 200
47+
48+
49+
def test_litellm_chat(api_key, api_base, model_name, max_tokens: int) -> int:
50+
try:
51+
llm_client = LiteLLMClient(
52+
api_key=api_key,
53+
api_base=api_base,
54+
model_name=model_name,
55+
max_tokens=max_tokens,
56+
)
57+
response = llm_client.generate(messages=[{"role": "user", "content": "hi"}])
58+
assert len(response) > 0
59+
except Exception as e:
60+
raise gr.Error(f"Error in litellm chat call: {e}") from e
61+
gr.Info("Test connection successful~")
62+
return 200
63+
64+
3265
def test_api_connection(url, method="GET", headers=None, params=None, body=None, auth=None, origin_call=None) -> int:
3366
# TODO: use fastapi.request / starlette instead?
3467
log.debug("Request URL: %s", url)
@@ -97,6 +130,11 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int:
97130
llm_settings.ollama_embedding_port = int(arg2)
98131
llm_settings.ollama_embedding_model = arg3
99132
status_code = test_api_connection(f"http://{arg1}:{arg2}", origin_call=origin_call)
133+
elif embedding_option == "litellm":
134+
llm_settings.litellm_embedding_api_key = arg1
135+
llm_settings.litellm_embedding_api_base = arg2
136+
llm_settings.litellm_embedding_model = arg3
137+
status_code = test_litellm_embedding(arg1, arg2, arg3)
100138
llm_settings.update_env()
101139
gr.Info("Configured!")
102140
return status_code
@@ -173,7 +211,6 @@ def apply_llm_config(current_llm_config, arg1, arg2, arg3, arg4, origin_call=Non
173211
setattr(llm_settings, f"openai_{current_llm_config}_tokens", int(arg4))
174212

175213
test_url = getattr(llm_settings, f"openai_{current_llm_config}_api_base") + "/chat/completions"
176-
log.debug("Type of OpenAI %s max_token is %s", current_llm_config, type(arg4))
177214
data = {
178215
"model": arg3,
179216
"temperature": 0.0,
@@ -192,6 +229,14 @@ def apply_llm_config(current_llm_config, arg1, arg2, arg3, arg4, origin_call=Non
192229
setattr(llm_settings, f"ollama_{current_llm_config}_language_model", arg3)
193230
status_code = test_api_connection(f"http://{arg1}:{arg2}", origin_call=origin_call)
194231

232+
elif llm_option == "litellm":
233+
setattr(llm_settings, f"litellm_{current_llm_config}_api_key", arg1)
234+
setattr(llm_settings, f"litellm_{current_llm_config}_api_base", arg2)
235+
setattr(llm_settings, f"litellm_{current_llm_config}_language_model", arg3)
236+
setattr(llm_settings, f"litellm_{current_llm_config}_tokens", int(arg4))
237+
238+
status_code = test_litellm_chat(arg1, arg2, arg3, int(arg4))
239+
195240
gr.Info("Configured!")
196241
llm_settings.update_env()
197242
return status_code
@@ -218,7 +263,7 @@ def create_configs_block() -> list:
218263
with gr.Accordion("2. Set up the LLM.", open=False):
219264
gr.Markdown("> Tips: the openai option also support openai style api from other providers.")
220265
with gr.Tab(label='chat'):
221-
chat_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
266+
chat_llm_dropdown = gr.Dropdown(choices=["openai", "litellm", "qianfan_wenxin", "ollama/local"],
222267
value=getattr(llm_settings, "chat_llm_type"), label="type")
223268
apply_llm_config_with_chat_op = partial(apply_llm_config, "chat")
224269

@@ -249,13 +294,23 @@ def chat_llm_settings(llm_type):
249294
gr.Textbox(value=getattr(llm_settings, "qianfan_chat_language_model"), label="model_name"),
250295
gr.Textbox(value="", visible=False),
251296
]
297+
elif llm_type == "litellm":
298+
llm_config_input = [
299+
gr.Textbox(value=getattr(llm_settings, "litellm_chat_api_key"), label="api_key",
300+
type="password"),
301+
gr.Textbox(value=getattr(llm_settings, "litellm_chat_api_base"), label="api_base",
302+
info="If you want to use the default api_base, please keep it blank"),
303+
gr.Textbox(value=getattr(llm_settings, "litellm_chat_language_model"), label="model_name",
304+
info="Please refer to https://docs.litellm.ai/docs/providers"),
305+
gr.Textbox(value=getattr(llm_settings, "litellm_chat_tokens"), label="max_token"),
306+
]
252307
else:
253308
llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)]
254309
llm_config_button = gr.Button("Apply configuration")
255310
llm_config_button.click(apply_llm_config_with_chat_op, inputs=llm_config_input)
256311

257312
with gr.Tab(label='mini_tasks'):
258-
extract_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
313+
extract_llm_dropdown = gr.Dropdown(choices=["openai", "litellm", "qianfan_wenxin", "ollama/local"],
259314
value=getattr(llm_settings, "extract_llm_type"), label="type")
260315
apply_llm_config_with_extract_op = partial(apply_llm_config, "extract")
261316

@@ -286,12 +341,22 @@ def extract_llm_settings(llm_type):
286341
gr.Textbox(value=getattr(llm_settings, "qianfan_extract_language_model"), label="model_name"),
287342
gr.Textbox(value="", visible=False),
288343
]
344+
elif llm_type == "litellm":
345+
llm_config_input = [
346+
gr.Textbox(value=getattr(llm_settings, "litellm_extract_api_key"), label="api_key",
347+
type="password"),
348+
gr.Textbox(value=getattr(llm_settings, "litellm_extract_api_base"), label="api_base",
349+
info="If you want to use the default api_base, please keep it blank"),
350+
gr.Textbox(value=getattr(llm_settings, "litellm_extract_language_model"), label="model_name",
351+
info="Please refer to https://docs.litellm.ai/docs/providers"),
352+
gr.Textbox(value=getattr(llm_settings, "litellm_extract_tokens"), label="max_token"),
353+
]
289354
else:
290355
llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)]
291356
llm_config_button = gr.Button("Apply configuration")
292357
llm_config_button.click(apply_llm_config_with_extract_op, inputs=llm_config_input)
293358
with gr.Tab(label='text2gql'):
294-
text2gql_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
359+
text2gql_llm_dropdown = gr.Dropdown(choices=["openai", "litellm", "qianfan_wenxin", "ollama/local"],
295360
value=getattr(llm_settings, "text2gql_llm_type"), label="type")
296361
apply_llm_config_with_text2gql_op = partial(apply_llm_config, "text2gql")
297362

@@ -322,14 +387,25 @@ def text2gql_llm_settings(llm_type):
322387
gr.Textbox(value=getattr(llm_settings, "qianfan_text2gql_language_model"), label="model_name"),
323388
gr.Textbox(value="", visible=False),
324389
]
390+
elif llm_type == "litellm":
391+
llm_config_input = [
392+
gr.Textbox(value=getattr(llm_settings, "litellm_text2gql_api_key"), label="api_key",
393+
type="password"),
394+
gr.Textbox(value=getattr(llm_settings, "litellm_text2gql_api_base"), label="api_base",
395+
info="If you want to use the default api_base, please keep it blank"),
396+
gr.Textbox(value=getattr(llm_settings, "litellm_text2gql_language_model"), label="model_name",
397+
info="Please refer to https://docs.litellm.ai/docs/providers"),
398+
gr.Textbox(value=getattr(llm_settings, "litellm_text2gql_tokens"), label="max_token"),
399+
]
325400
else:
326401
llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)]
327402
llm_config_button = gr.Button("Apply configuration")
328403
llm_config_button.click(apply_llm_config_with_text2gql_op, inputs=llm_config_input)
329404

330405
with gr.Accordion("3. Set up the Embedding.", open=False):
331406
embedding_dropdown = gr.Dropdown(
332-
choices=["openai", "qianfan_wenxin", "ollama/local"], value=llm_settings.embedding_type, label="Embedding"
407+
choices=["openai", "litellm", "qianfan_wenxin", "ollama/local"], value=llm_settings.embedding_type,
408+
label="Embedding"
333409
)
334410

335411
@gr.render(inputs=[embedding_dropdown])
@@ -357,6 +433,16 @@ def embedding_settings(embedding_type):
357433
type="password"),
358434
gr.Textbox(value=llm_settings.qianfan_embedding_model, label="model_name"),
359435
]
436+
elif embedding_type == "litellm":
437+
with gr.Row():
438+
embedding_config_input = [
439+
gr.Textbox(value=getattr(llm_settings, "litellm_embedding_api_key"), label="api_key",
440+
type="password"),
441+
gr.Textbox(value=getattr(llm_settings, "litellm_embedding_api_base"), label="api_base",
442+
info="If you want to use the default api_base, please keep it blank"),
443+
gr.Textbox(value=getattr(llm_settings, "litellm_embedding_model"), label="model_name",
444+
info="Please refer to https://docs.litellm.ai/docs/embedding/supported_embedding"),
445+
]
360446
else:
361447
embedding_config_input = [
362448
gr.Textbox(value="", visible=False),

hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from hugegraph_llm.models.embeddings.openai import OpenAIEmbedding
2020
from hugegraph_llm.models.embeddings.ollama import OllamaEmbedding
2121
from hugegraph_llm.models.embeddings.qianfan import QianFanEmbedding
22+
from hugegraph_llm.models.embeddings.litellm import LiteLLMEmbedding
2223
from hugegraph_llm.config import llm_settings
2324

2425

@@ -45,5 +46,11 @@ def get_embedding(self):
4546
api_key=llm_settings.qianfan_embedding_api_key,
4647
secret_key=llm_settings.qianfan_embedding_secret_key
4748
)
49+
if self.embedding_type == "litellm":
50+
return LiteLLMEmbedding(
51+
model_name=llm_settings.litellm_embedding_model,
52+
api_key=llm_settings.litellm_embedding_api_key,
53+
api_base=llm_settings.litellm_embedding_api_base
54+
)
4855

4956
raise Exception("embedding type is not supported !")
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from typing import List, Optional
19+
20+
from litellm import embedding, RateLimitError, APIError, APIConnectionError, aembedding
21+
from tenacity import (
22+
retry,
23+
stop_after_attempt,
24+
wait_exponential,
25+
retry_if_exception_type,
26+
)
27+
28+
from hugegraph_llm.models.embeddings.base import BaseEmbedding
29+
from hugegraph_llm.utils.log import log
30+
31+
32+
class LiteLLMEmbedding(BaseEmbedding):
33+
"""Wrapper for LiteLLM Embedding that supports multiple LLM providers."""
34+
35+
def __init__(
36+
self,
37+
api_key: Optional[str] = None,
38+
api_base: Optional[str] = None,
39+
model_name: str = "openai/text-embedding-3-small", # Can be any embedding model supported by LiteLLM
40+
) -> None:
41+
self.api_key = api_key
42+
self.api_base = api_base
43+
self.model = model_name
44+
45+
@retry(
46+
stop=stop_after_attempt(3),
47+
wait=wait_exponential(multiplier=1, min=4, max=10),
48+
retry=retry_if_exception_type((RateLimitError, APIConnectionError, APIError)),
49+
)
50+
def get_text_embedding(self, text: str) -> List[float]:
51+
"""Get embedding for a single text."""
52+
try:
53+
response = embedding(
54+
model=self.model,
55+
input=text,
56+
api_key=self.api_key,
57+
api_base=self.api_base,
58+
)
59+
log.info("Token usage: %s", response.usage)
60+
return response.data[0]["embedding"]
61+
except (RateLimitError, APIConnectionError, APIError) as e:
62+
log.error("Error in LiteLLM embedding call: %s", e)
63+
raise
64+
65+
def get_texts_embeddings(self, texts: List[str]) -> List[List[float]]:
66+
"""Get embeddings for multiple texts."""
67+
try:
68+
response = embedding(
69+
model=self.model,
70+
input=texts,
71+
api_key=self.api_key,
72+
api_base=self.api_base,
73+
)
74+
log.info("Token usage: %s", response.usage)
75+
return [data["embedding"] for data in response.data]
76+
except (RateLimitError, APIConnectionError, APIError) as e:
77+
log.error("Error in LiteLLM batch embedding call: %s", e)
78+
raise
79+
80+
async def async_get_text_embedding(self, text: str) -> List[float]:
81+
"""Get embedding for a single text asynchronously."""
82+
try:
83+
response = await aembedding(
84+
model=self.model,
85+
input=text,
86+
api_key=self.api_key,
87+
api_base=self.api_base,
88+
)
89+
log.info("Token usage: %s", response.usage)
90+
return response.data[0]["embedding"]
91+
except (RateLimitError, APIConnectionError, APIError) as e:
92+
log.error("Error in async LiteLLM embedding call: %s", e)
93+
raise

0 commit comments

Comments
 (0)