1
+ import time , sys
2
+ st = time .time ()
3
+ import os
4
+ import yaml
5
+ import requests
6
+ from typing import List
7
+ from loguru import logger
8
+ import tqdm
9
+ from concurrent .futures import ThreadPoolExecutor
10
+
11
+ print (time .time ()- st )
12
+ from langchain .llms .base import LLM
13
+ from langchain .embeddings .base import Embeddings
14
+ print (time .time ()- st )
15
# Resolve the directory three levels above this file and append it to the
# import search path (presumably so the in-repo `muagent` package imported
# below can be found -- confirm against the repository layout).
src_dir = os.path.abspath(__file__)
for _ in range(3):
    src_dir = os.path.dirname(src_dir)
print(src_dir)
sys.path.append(src_dir)
20
+
21
+ from muagent .schemas .db import *
22
+ from muagent .llm_models .llm_config import EmbedConfig , LLMConfig
23
+ from muagent .service .ekg_construct .ekg_construct_base import EKGConstructService
24
+
25
+ from pydantic import BaseModel
26
+
27
# LLM configuration
class CustomLLM(LLM, BaseModel):
    """LangChain ``LLM`` wrapper that sends prompts to a configurable backend.

    ``model_type`` selects the backend:

    * ``"ollama"``: plain HTTP POST of the prompt to ``url``.
    * ``"openai"`` / ``"lingyiwangwu"``: a chat model built by muagent's
      ``getChatModelFromConfig`` from an ``LLMConfig``.

    Any other value yields an empty completion.
    """

    url: str = "http://localhost:11434/api/generate"
    model_name: str = "qwen2:1b"
    model_type: str = "ollama"
    api_key: str = ""
    stop: str = ""
    temperature: float = 0.3
    top_k: int = 50
    top_p: float = 0.95

    def params(self):
        """Return the connection and sampling settings stored on this instance.

        Returns:
            A dict restricted to the keys ``url``, ``model_name``,
            ``model_type``, ``api_key``, ``stop``, ``temperature``, ``top_k``
            and ``top_p`` that are present in the instance ``__dict__``.
        """
        keys = (
            "url", "model_name", "model_type", "api_key",
            "stop", "temperature", "top_k", "top_p",
        )
        return {k: v for k, v in self.__dict__.items() if k in keys}

    def update_params(self, **kwargs):
        """Set each keyword argument as an attribute on this instance.

        Args:
            **kwargs: attribute name / new value pairs.
        """
        logger.debug(f"{kwargs}")
        # Update the attributes one by one, logging each change.
        for key, value in kwargs.items():
            logger.debug(f"{key}, {value}")
            setattr(self, key, value)

    def _llm_type(self, *args):
        """Return the LLM type identifier (an empty string for this wrapper)."""
        return ""

    def predict(self, prompt: str, stop=None) -> str:
        """Generate a completion for ``prompt``; thin alias of :meth:`_call`.

        Args:
            prompt: text sent to the model.
            stop: optional stop sequence(s); falls back to ``self.stop``.

        Returns:
            The generated text.
        """
        return self._call(prompt, stop)

    def _call(self, prompt: str, stop=None) -> str:
        """Send ``prompt`` to the backend selected by ``model_type``.

        Args:
            prompt: text sent to the model.
            stop: optional stop sequence(s). When falsy, ``self.stop`` is
                used instead.

        Returns:
            The generated text, or ``""`` when ``model_type`` is not one of
            the supported backends.

        Raises:
            requests.HTTPError: if the Ollama endpoint answers with an HTTP
                error status.
        """
        stop = stop or self.stop

        if self.model_type == "ollama":
            payload = {
                "model": self.model_name,
                "prompt": prompt,
                # Ollama streams newline-delimited JSON chunks by default,
                # which `Response.json()` cannot parse; request one object.
                "stream": False,
            }
            r = requests.post(self.url, json=payload)
            # Surface HTTP failures instead of returning an error payload.
            r.raise_for_status()
            # The generated text is carried in the "response" field.
            return r.json()["response"]

        if self.model_type in ("openai", "lingyiwangwu"):
            # Imported lazily, as in the original, so the Ollama path does
            # not need this module.
            from muagent.llm_models.openai_model import getChatModelFromConfig

            # Both engines share the same configuration; only the engine
            # name differs and it equals `model_type`.
            llm_config = LLMConfig(
                model_name=self.model_name,
                model_engine=self.model_type,
                api_key=self.api_key,
                api_base_url=self.url,
                temperature=self.temperature,
                stop=self.stop,
            )
            model = getChatModelFromConfig(llm_config)
            # Honour the caller-supplied stop value (previously ignored).
            return model.predict(prompt, stop=stop)

        # Unknown backend: keep the best-effort empty result, but say so.
        logger.warning(f"Unsupported model_type: {self.model_type!r}")
        return ""
100
+
101
+
102
class CustomEmbeddings(Embeddings):
    """LangChain ``Embeddings`` wrapper backed by Ollama or an OpenAI-style service.

    ``embedding_type`` selects the backend:

    * ``"ollama"``: plain HTTP POST of the text to ``url``.
    * ``"openai"``: muagent's ``get_embedding`` helper driven by an
      ``EmbedConfig``.

    Any other value yields an empty vector.
    """

    # Ollama embeddings endpoint used by default.
    url = "http://localhost:11434/api/embeddings"
    # Backend selector, see the class docstring.
    embedding_type = "ollama"
    model_name = ""
    api_key = ""

    def params(self):
        """Return the current connection settings as a dict."""
        return {
            "url": self.url,
            "model_name": self.model_name,
            "embedding_type": self.embedding_type,
            "api_key": self.api_key,
        }

    def update_params(self, **kwargs):
        """Set each keyword argument as an attribute on this instance.

        Args:
            **kwargs: attribute name / new value pairs.
        """
        logger.debug(f"{kwargs}")
        # Update the attributes one by one, logging each change.
        for key, value in kwargs.items():
            logger.debug(f"{key}, {value}")
            setattr(self, key, value)

    def _get_sentence_emb(self, sentence: str) -> List[float]:
        """Call the sentence-embedding service for a single text.

        Args:
            sentence: the text to embed.

        Returns:
            The embedding vector, or ``[]`` when ``embedding_type`` is not a
            supported backend.

        Raises:
            requests.HTTPError: if the Ollama endpoint answers with an HTTP
                error status.
        """
        if self.embedding_type == "ollama":
            payload = {
                "model": self.model_name,
                "prompt": sentence,
            }
            r = requests.post(self.url, json=payload)
            # Surface HTTP failures instead of returning an error payload.
            r.raise_for_status()
            # The endpoint wraps the vector as {"embedding": [...]}; return
            # the vector itself so every backend yields the same shape.
            return r.json()["embedding"]

        if self.embedding_type == "openai":
            # Imported lazily, as in the original, so the Ollama path does
            # not need this module.
            from muagent.llm_models.get_embedding import get_embedding

            embed_config = EmbedConfig(
                embed_engine="openai",
                api_key=self.api_key,
                api_base_url=self.url,
            )
            # The helper's result is indexed by the input text.
            text2vector_dict = get_embedding(
                "openai", [sentence], embed_config=embed_config
            )
            return text2vector_dict[sentence]

        return []

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts concurrently, preserving input order.

        Args:
            texts: the texts to embed.

        Returns:
            One embedding per input text, in the same order as ``texts``.
        """
        with ThreadPoolExecutor() as executor:
            # `tqdm` is imported as a module in this file, so the progress
            # bar class has to be referenced as `tqdm.tqdm`.
            embeddings = list(
                tqdm.tqdm(
                    executor.map(self._get_sentence_emb, texts),
                    total=len(texts),
                    desc="Embedding documents",
                )
            )

        print("向量个数" + str(len(embeddings)))
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text with the configured backend.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        logger.info("提问query: " + str(text))
        embedding = self._get_sentence_emb(text)
        logger.info("提问向量:" + str(embedding))
        return embedding
177
+
178
+
179
+
180
# Directory that contains this script; the YAML config is read from it.
cur_dir = os.path.dirname(__file__)
print(cur_dir)

# Name of the YAML file to open (resolved relative to `cur_dir`).
file_path = 'ekg.yaml'

# The `with` statement guarantees the file is closed. The encoding is pinned
# to UTF-8 so parsing does not depend on the platform's default locale.
with open(os.path.join(cur_dir, file_path), 'r', encoding='utf-8') as file:
    # Parse the YAML content with the safe loader.
    config_data = yaml.safe_load(file)
190
+
191
+
192
+
193
+ # gb_config = GBConfig(
194
+ # gb_type="GeaBaseHandler",
195
+ # extra_kwargs={
196
+ # 'metaserver_address': config_data["gbase_config"]['metaserver_address'],
197
+ # 'project': config_data["gbase_config"]['project'],
198
+ # 'city': config_data["gbase_config"]['city'],
199
+ # 'lib_path': config_data["gbase_config"]['lib_path'],
200
+ # }
201
+ # )
202
+
203
+
204
+ # gb_config = GBConfig(
205
+ # gb_type="NebulaHandler",
206
+ # extra_kwargs={}
207
+ # )
208
+
209
+
210
# Build the TbaseHandler configuration from the "tbase_config" section of the
# loaded YAML. The connection fields are passed both as direct arguments and,
# together with `definition_value`, inside `extra_kwargs`.
tbase_settings = config_data["tbase_config"]
tb_config = TBConfig(
    tb_type="TbaseHandler",
    index_name="muagent_test",
    host=tbase_settings["host"],
    port=tbase_settings["port"],
    username=tbase_settings["username"],
    password=tbase_settings["password"],
    extra_kwargs={
        key: tbase_settings[key]
        for key in ("host", "port", "username", "password", "definition_value")
    },
)
226
+
227
# Default-configured LLM wrapper, plus the muagent config object carrying it.
llm = CustomLLM()
llm_config = LLMConfig(llm=llm)

# Default-configured embedding wrapper, plus the muagent config object
# carrying it.
embeddings = CustomEmbeddings()
embed_config = EmbedConfig(embed_model="default", langchain_embeddings=embeddings)
238
+
239
+
240
+ # ekg_construct_service = EKGConstructService(
241
+ # embed_config=embed_config,
242
+ # llm_config=llm_config,
243
+ # tb_config=tb_config,
244
+ # gb_config=gb_config,
245
+ # )
246
+
247
# Hand the custom LLM and embedding wrappers to muagent's EKG-construction
# HTTP API entry point (presumably this starts the service -- confirm in
# muagent.httpapis.ekg_construct).
import muagent.httpapis.ekg_construct as ekg_construct_api

ekg_construct_api.create_api(llm, embeddings)
0 commit comments