
Commit aa4b0cd

[add test][test llm and embedding api]
1 parent cc8e44a commit aa4b0cd

File tree

3 files changed (+71, -12 lines):
  examples/ekg_examples/start.py
  muagent/httpapis/ekg_construct/api.py
  muagent/schemas/apis/ekg_api_schema.py

examples/ekg_examples/start.py

Lines changed: 15 additions & 10 deletions
@@ -1,23 +1,30 @@
 import time, sys
-st = time.time()
 import os
 import yaml
 import requests
 from typing import List
 from loguru import logger
-import tqdm
+from tqdm import tqdm
 from concurrent.futures import ThreadPoolExecutor

-print(time.time()-st)
 from langchain.llms.base import LLM
 from langchain.embeddings.base import Embeddings
-print(time.time()-st)
 src_dir = os.path.join(
     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 )
-print(src_dir)
 sys.path.append(src_dir)

+try:
+    import os, sys
+    src_dir = os.path.join(
+        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    )
+    sys.path.append(src_dir)
+    import test_config
+except Exception as e:
+    # set your config
+    logger.error(f"{e}")
+
 from muagent.schemas.db import *
 from muagent.llm_models.llm_config import EmbedConfig, LLMConfig
 from muagent.service.ekg_construct.ekg_construct_base import EKGConstructService
@@ -30,7 +37,7 @@ class CustomLLM(LLM, BaseModel):
     model_name: str = "qwen2:1b"
     model_type: str = "ollama"
     api_key: str = ""
-    stop: str = ""
+    stop: str = None
     temperature: float = 0.3
     top_k: int = 50
     top_p: float = 0.95
@@ -43,10 +50,8 @@ def params(self):
                 if k in keys}

     def update_params(self, **kwargs):
-        logger.debug(f"{kwargs}")
         # update attributes
         for key, value in kwargs.items():
-            logger.debug(f"{key}, {value}")
             setattr(self, key, value)

     def _llm_type(self, *args):
@@ -114,10 +119,8 @@ def params(self):
         }

     def update_params(self, **kwargs):
-        logger.debug(f"{kwargs}")
         # update attributes
         for key, value in kwargs.items():
-            logger.debug(f"{key}, {value}")
             setattr(self, key, value)

     def _get_sentence_emb(self, sentence: str) -> dict:
@@ -133,6 +136,8 @@ def _get_sentence_emb(self, sentence: str) -> dict:
             return r.json()
         elif self.embedding_type == "openai":
             from muagent.llm_models.get_embedding import get_embedding
+            os.environ["OPENAI_API_KEY"] = self.api_key
+            os.environ["API_BASE_URL"] = self.url
             embed_config = EmbedConfig(
                 embed_engine="openai",
                 api_key=self.api_key,
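
The new try/except block imports an optional local test_config module to populate credentials; if it is missing, the script only logs the error and leaves the configuration to you. For reference, a minimal sketch of how the reworked CustomLLM parameters could be exercised; the constructor call and the values below are illustrative, not part of this commit:

# Illustrative only: update_params() simply setattr()s each keyword onto the model,
# and stop now defaults to None rather than "".
llm = CustomLLM()
llm.update_params(model_name="qwen2:1b", temperature=0.2, stop=None)
print(llm.params())  # params() returns only the whitelisted fields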

muagent/httpapis/ekg_construct/api.py

Lines changed: 36 additions & 0 deletions
@@ -2,6 +2,7 @@
 from typing import Dict
 import asyncio
 import uvicorn
+from loguru import logger

 from muagent.service.ekg_construct.ekg_construct_base import EKGConstructService
 from muagent.schemas.apis.ekg_api_schema import *
@@ -18,6 +19,24 @@ def init_app(llm, embeddings):
     async def llm_params():
         return llm.params()

+    # ~/llm/generate
+    @app.post("/llm/generate", response_model=LLMResponse)
+    async def llm_predict(request: LLMRequest):
+        # prediction logic
+        errorMessage = "ok"
+        successCode = True
+        try:
+            answer = llm.predict(request.text, request.stop)
+        except Exception as e:
+            errorMessage = str(e)
+            successCode = False
+            answer = "error"
+
+        return LLMResponse(
+            successCode=successCode, errorMessage=errorMessage,
+            answer=answer
+        )
+
     # ~/llm/params/update
     @app.post("/llm/params/update", response_model=EKGResponse)
     async def update_llm_params(kwargs: Dict):
@@ -55,6 +74,23 @@ async def update_embedding_params(kwargs: Dict):
             successCode=successCode, errorMessage=errorMessage,
         )

+    @app.post("/embeddings/generate", response_model=EmbeddingsResponse)
+    async def embedding_predict(request: EmbeddingsRequest):
+        # prediction logic
+        errorMessage = "ok"
+        successCode = True
+        try:
+            embeddings_list = embeddings.embed_documents(request.texts)
+        except Exception as e:
+            logger.exception(e)
+            errorMessage = str(e)
+            successCode = False
+            embeddings_list = []
+
+        return EmbeddingsResponse(
+            successCode=successCode, errorMessage=errorMessage,
+            embeddings=embeddings_list
+        )
     # # ~/ekg/text2graph
     # @app.post("/ekg/text2graph", response_model=EKGGraphResponse)
     # async def text2graph(request: EKGT2GRequest):
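
Both new endpoints wrap their handlers in try/except and always return a well-formed response carrying successCode and errorMessage. A hedged client-side sketch of calling them with requests; the host and port are assumptions, not taken from this commit:

import requests

base_url = "http://localhost:3737"  # assumed host/port; adjust to your deployment

# POST /llm/generate expects an LLMRequest body: {"text": ..., "stop": ...}
resp = requests.post(f"{base_url}/llm/generate",
                     json={"text": "hello", "stop": None}).json()
print(resp["successCode"], resp["errorMessage"], resp["answer"])

# POST /embeddings/generate expects an EmbeddingsRequest body: {"texts": [...]}
resp = requests.post(f"{base_url}/embeddings/generate",
                     json={"texts": ["hello", "world"]}).json()
print(len(resp["embeddings"]))  # one embedding vector per input text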

muagent/schemas/apis/ekg_api_schema.py

Lines changed: 20 additions & 2 deletions
@@ -1,5 +1,5 @@
 from pydantic import BaseModel
-from typing import List, Dict
+from typing import List, Dict, Optional
 from enum import Enum

 from muagent.schemas.common import GNode, GEdge
@@ -12,6 +12,24 @@ class EKGResponse(BaseModel):
     errorMessage: str


+# embeddings
+class EmbeddingsResponse(EKGResponse):
+    successCode: int
+    errorMessage: str
+    embeddings: List[List[float]]
+
+class EmbeddingsRequest(BaseModel):
+    texts: List[str]
+
+class LLMRequest(BaseModel):
+    text: str
+    stop: Optional[str]
+
+class LLMResponse(EKGResponse):
+    successCode: int
+    errorMessage: str
+    answer: str
+
 # text2graph
 class EKGT2GRequest(BaseModel):
     text: str
@@ -81,7 +99,7 @@ class LLMParamsResponse(BaseModel):
     model_name: str
     model_type: str
     api_key: str
-    stop: str
+    stop: Optional[str] = None
     temperature: float
     top_k: int
     top_p: float
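
A quick sketch of the new request/response models in use, assuming muagent is importable on the path; the field values are made up for illustration:

from muagent.schemas.apis.ekg_api_schema import LLMRequest, EmbeddingsResponse

# stop is Optional[str]; passing it explicitly keeps the example agnostic to the pydantic version
req = LLMRequest(text="describe this node", stop=None)
resp = EmbeddingsResponse(successCode=1, errorMessage="ok",
                          embeddings=[[0.1, 0.2], [0.3, 0.4]])
print(req.text, len(resp.embeddings))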
