Skip to content

Commit 793c6ba

Browse files
committed
[feat][add local fastapi server]
1 parent 1972658 commit 793c6ba

File tree

14 files changed

+652
-358
lines changed

14 files changed

+652
-358
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ build
1515
*egg-info
1616
dist
1717
.ipynb_checkpoints
18-
zdatafront*
18+
zdatafront*
19+
*antgroup*

examples/ekg_examples/ekg.yaml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# geabase config
2+
geabase_config:
3+
metaserver_address: 'deafault'
4+
project: 'deafault'
5+
city: 'deafault'
6+
lib_path: 'deafault'
7+
8+
9+
# nebula config
10+
nebula_config:
11+
host: 'localhost'
12+
port: 7070
13+
username: 'default'
14+
password: 'default'
15+
space_name: 'default'
16+
17+
# tbase config
18+
tbase_config:
19+
host: 'localhost'
20+
port: 321321
21+
username: 'default'
22+
password: ''
23+
definition_value: 'opsgptkg'
24+
25+
26+
# model
27+
llm:
28+
model_type: 'openai'
29+
model_name: 'default'
30+
stop: ''
31+
temperature: 0.3
32+
top_p: 0.95
33+
top_k: 50
34+
url: ''
35+
token: ''
36+
37+
38+
# embedding
39+
embedding:
40+
embedding_type: 'openai'
41+
model_name: 'default'
42+
url: ''
43+
token: ''

examples/ekg_examples/start.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
import time, sys
2+
st = time.time()
3+
import os
4+
import yaml
5+
import requests
6+
from typing import List
7+
from loguru import logger
8+
import tqdm
9+
from concurrent.futures import ThreadPoolExecutor
10+
11+
print(time.time()-st)
12+
from langchain.llms.base import LLM
13+
from langchain.embeddings.base import Embeddings
14+
print(time.time()-st)
15+
src_dir = os.path.join(
16+
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17+
)
18+
print(src_dir)
19+
sys.path.append(src_dir)
20+
21+
from muagent.schemas.db import *
22+
from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
23+
from muagent.service.ekg_construct.ekg_construct_base import EKGConstructService
24+
25+
from pydantic import BaseModel
26+
27+
# llm config
28+
class CustomLLM(LLM, BaseModel):
29+
url: str = "http://localhost:11434/api/generate"
30+
model_name: str = "qwen2:1b"
31+
model_type: str = "ollama"
32+
api_key: str = ""
33+
stop: str = ""
34+
temperature: float = 0.3
35+
top_k: int = 50
36+
top_p: float = 0.95
37+
38+
def params(self):
39+
keys = ["url", "model_name", "model_type", "api_key", "stop", "temperature", "top_k", "top_p"]
40+
return {
41+
k:v
42+
for k,v in self.__dict__.items()
43+
if k in keys}
44+
45+
def update_params(self, **kwargs):
46+
logger.debug(f"{kwargs}")
47+
# 更新属性
48+
for key, value in kwargs.items():
49+
logger.debug(f"{key}, {value}")
50+
setattr(self, key, value)
51+
52+
def _llm_type(self, *args):
53+
return ""
54+
55+
def predict(self, prompt: str, stop = None) -> str:
56+
return self._call(prompt, stop)
57+
58+
def _call(self, prompt: str,
59+
stop = None) -> str:
60+
"""_call
61+
"""
62+
return_str = ""
63+
stop = stop or self.stop
64+
65+
if self.model_type == "ollama":
66+
data = {
67+
"model": self.model_name,
68+
"prompt": prompt
69+
}
70+
r = requests.post(self.url, json=data, )
71+
return r.json()
72+
elif self.model_type == "openai":
73+
from muagent.llm_models.openai_model import getChatModelFromConfig
74+
llm_config = LLMConfig(
75+
model_name=self.model_name,
76+
model_engine="openai",
77+
api_key=self.api_key,
78+
api_base_url=self.url,
79+
temperature=self.temperature,
80+
stop=self.stop
81+
)
82+
model = getChatModelFromConfig(llm_config)
83+
return model.predict(prompt, stop=self.stop)
84+
elif self.model_type == "lingyiwangwu":
85+
from muagent.llm_models.openai_model import getChatModelFromConfig
86+
llm_config = LLMConfig(
87+
model_name=self.model_name,
88+
model_engine="lingyiwangwu",
89+
api_key=self.api_key,
90+
api_base_url=self.url,
91+
temperature=self.temperature,
92+
stop=self.stop
93+
)
94+
model = getChatModelFromConfig(llm_config)
95+
return model.predict(prompt, stop=self.stop)
96+
else:
97+
pass
98+
99+
return return_str
100+
101+
102+
class CustomEmbeddings(Embeddings):
103+
# ollama embeddings
104+
url = "http://localhost:11434/api/embeddings"
105+
#
106+
embedding_type = "ollama"
107+
model_name = ""
108+
api_key = ""
109+
110+
def params(self):
111+
return {
112+
"url": self.url, "model_name": self.model_name,
113+
"embedding_type": self.embedding_type, "api_key": self.api_key
114+
}
115+
116+
def update_params(self, **kwargs):
117+
logger.debug(f"{kwargs}")
118+
# 更新属性
119+
for key, value in kwargs.items():
120+
logger.debug(f"{key}, {value}")
121+
setattr(self, key, value)
122+
123+
def _get_sentence_emb(self, sentence: str) -> dict:
124+
"""
125+
调用句子向量提取服务
126+
"""
127+
if self.embedding_type == "ollama":
128+
data = {
129+
"model": self.model_name,
130+
"prompt": sentence
131+
}
132+
r = requests.post(self.url, json=data, )
133+
return r.json()
134+
elif self.embedding_type == "openai":
135+
from muagent.llm_models.get_embedding import get_embedding
136+
embed_config = EmbedConfig(
137+
embed_engine="openai",
138+
api_key=self.api_key,
139+
api_base_url=self.url,
140+
)
141+
text2vector_dict = get_embedding("openai", [sentence], embed_config=embed_config)
142+
return text2vector_dict[sentence]
143+
else:
144+
pass
145+
146+
return []
147+
148+
def embed_documents(self, texts: List[str]) -> List[List[float]]:
149+
embeddings = []
150+
151+
def process_text(text):
152+
# print("分句:" + str(text) + "\n")
153+
emb_str = self._get_sentence_emb(text)
154+
# print("向量:" + str(emb_str) + "\n")
155+
return emb_str
156+
157+
with ThreadPoolExecutor() as executor:
158+
results = list(tqdm(executor.map(process_text, texts), total=len(texts), desc="Embedding documents"))
159+
160+
embeddings.extend(results)
161+
print("向量个数" + str(len(embeddings)))
162+
return embeddings
163+
164+
def embed_query(self, text: str) -> List[float]:
165+
"""Compute query embeddings using a HuggingFace transformer model.
166+
167+
Args:
168+
text: The text to embed.
169+
170+
Returns:
171+
Embeddings for the text.
172+
"""
173+
logger.info("提问query: " + str(text))
174+
embedding = self._get_sentence_emb(text)
175+
logger.info("提问向量:" + str(embedding))
176+
return embedding
177+
178+
179+
180+
cur_dir = os.path.dirname(__file__)
181+
print(cur_dir)
182+
183+
# 要打开的YAML文件路径
184+
file_path = 'ekg.yaml'
185+
186+
# 使用 'with' 语句确保文件正确关闭
187+
with open(os.path.join(cur_dir, file_path), 'r') as file:
188+
# 加载YAML文件内容
189+
config_data = yaml.safe_load(file)
190+
191+
192+
193+
# gb_config = GBConfig(
194+
# gb_type="GeaBaseHandler",
195+
# extra_kwargs={
196+
# 'metaserver_address': config_data["gbase_config"]['metaserver_address'],
197+
# 'project': config_data["gbase_config"]['project'],
198+
# 'city': config_data["gbase_config"]['city'],
199+
# 'lib_path': config_data["gbase_config"]['lib_path'],
200+
# }
201+
# )
202+
203+
204+
# gb_config = GBConfig(
205+
# gb_type="NebulaHandler",
206+
# extra_kwargs={}
207+
# )
208+
209+
210+
# 初始化 TbaseHandler 实例
211+
tb_config = TBConfig(
212+
tb_type="TbaseHandler",
213+
index_name="muagent_test",
214+
host=config_data["tbase_config"]["host"],
215+
port=config_data["tbase_config"]['port'],
216+
username=config_data["tbase_config"]['username'],
217+
password=config_data["tbase_config"]['password'],
218+
extra_kwargs={
219+
'host': config_data["tbase_config"]['host'],
220+
'port': config_data["tbase_config"]['port'],
221+
'username': config_data["tbase_config"]['username'] ,
222+
'password': config_data["tbase_config"]['password'],
223+
'definition_value': config_data["tbase_config"]['definition_value']
224+
}
225+
)
226+
227+
llm = CustomLLM()
228+
llm_config = LLMConfig(
229+
llm=llm
230+
)
231+
232+
233+
embeddings = CustomEmbeddings()
234+
embed_config = EmbedConfig(
235+
embed_model="default",
236+
langchain_embeddings=embeddings
237+
)
238+
239+
240+
# ekg_construct_service = EKGConstructService(
241+
# embed_config=embed_config,
242+
# llm_config=llm_config,
243+
# tb_config=tb_config,
244+
# gb_config=gb_config,
245+
# )
246+
247+
from muagent.httpapis.ekg_construct import create_api
248+
create_api(llm, embeddings)

muagent/db_handler/graph_db_handler/geabase_handler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
from loguru import logger
44
import json
55

6-
from gdbc2.geabase_client import GeaBaseClient, Node, Edge, MutateBatchOperation, GeaBaseUtil
7-
from gdbc2.geabase_env import GeaBaseEnv
6+
try:
7+
from gdbc2.geabase_client import GeaBaseClient, Node, Edge, MutateBatchOperation, GeaBaseUtil
8+
from gdbc2.geabase_env import GeaBaseEnv
9+
except:
10+
logger.error("ignore this sdk")
811

912
from .base_gb_handler import GBHandler
1013
from muagent.db_handler.utils import deduplicate_dict

muagent/httpapis/__init__.py

Whitespace-only changes.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .api import create_api
2+
3+
4+
__all__ = [
5+
"create_api"
6+
]

0 commit comments

Comments
 (0)