Skip to content

Commit 48ad5bb

Browse files
authored
Merge pull request #9 from codefuse-ai/modelcache_localDB_dev
version update
2 parents 2367a56 + 3dfb7e0 commit 48ad5bb

File tree

11 files changed

+487
-31
lines changed

11 files changed

+487
-31
lines changed

examples/flask/data_insert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def run():
77
url = 'http://127.0.0.1:5000/modelcache'
88
type = 'insert'
9-
scope = {"model": "CODEGPT-1109"}
9+
scope = {"model": "CODEGPT-1117"}
1010
chat_info = [{"query": [{"role": "system", "content": "你是一个python助手"}, {"role": "user", "content": "hello"}],
1111
"answer": "你好,我是智能助手,请问有什么能帮您!"}]
1212
data = {'type': type, 'scope': scope, 'chat_info': chat_info}

examples/flask/data_query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def run():
77
url = 'http://127.0.0.1:5000/modelcache'
88
type = 'query'
9-
scope = {"model": "CODEGPT-1109"}
9+
scope = {"model": "CODEGPT-1117"}
1010
query = [{"role": "system", "content": "你是一个python助手"}, {"role": "user", "content": "hello"}]
1111
data = {'type': type, 'scope': scope, 'query': query}
1212

flask4modelcache.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from concurrent.futures import ThreadPoolExecutor
1515
from modelcache.utils.model_filter import model_blacklist_filter
1616
from modelcache.embedding import Data2VecAudio
17-
# from modelcache.maya_embedding_service.maya_embedding_service import get_cache_embedding_text2vec
1817

1918

2019
# 创建一个Flask实例
@@ -113,7 +112,8 @@ def user_backend():
113112
if response is None:
114113
result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '',
115114
"answer": ''}
116-
elif response in ['adapt_query_exception']:
115+
# elif response in ['adapt_query_exception']:
116+
elif isinstance(response, str):
117117
result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
118118
"hit_query": '', "answer": ''}
119119
else:
@@ -124,7 +124,7 @@ def user_backend():
124124
delta_time_log = round(time.time() - start_time, 2)
125125
future = executor.submit(save_query_info, result, model, query, delta_time_log)
126126
except Exception as e:
127-
result = {"errorCode": 202, "errorDesc": e, "cacheHit": False, "delta_time": 0,
127+
result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
128128
"hit_query": '', "answer": ''}
129129
logging.info('result: {}'.format(result))
130130

@@ -138,19 +138,16 @@ def user_backend():
138138
chat_info=chat_info
139139
)
140140
except Exception as e:
141-
result = {"errorCode": 303, "errorDesc": e, "writeStatus": "exception"}
141+
result = {"errorCode": 302, "errorDesc": str(e), "writeStatus": "exception"}
142142
return json.dumps(result, ensure_ascii=False)
143143

144-
if response in ['adapt_insert_exception']:
145-
result = {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
146-
elif response == 'success':
144+
if response == 'success':
147145
result = {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
148146
else:
149-
result = {"errorCode": 302, "errorDesc": response,
150-
"writeStatus": "exception"}
147+
result = {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
151148
return json.dumps(result, ensure_ascii=False)
152149
except Exception as e:
153-
result = {"errorCode": 304, "errorDesc": e, "writeStatus": "exception"}
150+
result = {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}
154151
return json.dumps(result, ensure_ascii=False)
155152

156153
if request_type == 'remove':
@@ -162,13 +159,11 @@ def user_backend():
162159
remove_type=remove_type,
163160
id_list=id_list
164161
)
165-
166162
if not isinstance(response, dict):
167163
result = {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}
168164
return json.dumps(result)
169165

170166
state = response.get('status')
171-
172167
if state == 'success':
173168
result = {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
174169
else:

flask4modelcache_demo.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# -*- coding: utf-8 -*-
2+
import time
3+
from flask import Flask, request
4+
import logging
5+
import json
6+
from modelcache import cache
7+
from modelcache.adapter import adapter
8+
from modelcache.manager import CacheBase, VectorBase, get_data_manager
9+
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
10+
from modelcache.processor.pre import query_multi_splicing
11+
from modelcache.processor.pre import insert_multi_splicing
12+
from concurrent.futures import ThreadPoolExecutor
13+
from modelcache.utils.model_filter import model_blacklist_filter
14+
from modelcache.embedding import Data2VecAudio
15+
16+
# 创建一个Flask实例
17+
app = Flask(__name__)
18+
19+
20+
def response_text(cache_resp):
21+
return cache_resp['data']
22+
23+
24+
def save_query_info(result, model, query, delta_time_log):
25+
cache.data_manager.save_query_resp(result, model=model, query=json.dumps(query, ensure_ascii=False),
26+
delta_time=delta_time_log)
27+
28+
29+
def response_hitquery(cache_resp):
30+
return cache_resp['hitQuery']
31+
32+
33+
data2vec = Data2VecAudio()
34+
data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension))
35+
36+
37+
cache.init(
38+
embedding_func=data2vec.to_embeddings,
39+
data_manager=data_manager,
40+
similarity_evaluation=SearchDistanceEvaluation(),
41+
query_pre_embedding_func=query_multi_splicing,
42+
insert_pre_embedding_func=insert_multi_splicing,
43+
)
44+
45+
# cache.set_openai_key()
46+
global executor
47+
executor = ThreadPoolExecutor(max_workers=6)
48+
49+
50+
@app.route('/welcome')
51+
def first_flask(): # 视图函数
52+
return 'hello, modelcache!'
53+
54+
55+
@app.route('/modelcache', methods=['GET', 'POST'])
56+
def user_backend():
57+
try:
58+
if request.method == 'POST':
59+
request_data = request.json
60+
elif request.method == 'GET':
61+
request_data = request.args
62+
param_dict = json.loads(request_data)
63+
except Exception as e:
64+
result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
65+
"answer": ''}
66+
cache.data_manager.save_query_resp(result, model='', query='', delta_time=0)
67+
return json.dumps(result)
68+
69+
# param parsing
70+
try:
71+
request_type = param_dict.get("type")
72+
scope = param_dict.get("scope")
73+
if scope is not None:
74+
model = scope.get('model')
75+
model = model.replace('-', '_')
76+
model = model.replace('.', '_')
77+
query = param_dict.get("query")
78+
chat_info = param_dict.get("chat_info")
79+
if request_type is None or request_type not in ['query', 'insert', 'detox', 'remove']:
80+
result = {"errorCode": 102,
81+
"errorDesc": "type exception, should one of ['query', 'insert', 'detox', 'remove']",
82+
"cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
83+
cache.data_manager.save_query_resp(result, model=model, query='', delta_time=0)
84+
return json.dumps(result)
85+
except Exception as e:
86+
result = {"errorCode": 103, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
87+
"answer": ''}
88+
return json.dumps(result)
89+
90+
# model filter
91+
filter_resp = model_blacklist_filter(model, request_type)
92+
if isinstance(filter_resp, dict):
93+
return json.dumps(filter_resp)
94+
95+
if request_type == 'query':
96+
try:
97+
start_time = time.time()
98+
response = adapter.ChatCompletion.create_query(
99+
scope={"model": model},
100+
query=query
101+
)
102+
delta_time = '{}s'.format(round(time.time() - start_time, 2))
103+
if response is None:
104+
result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '',
105+
"answer": ''}
106+
elif response in ['adapt_query_exception']:
107+
result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
108+
"hit_query": '', "answer": ''}
109+
else:
110+
answer = response_text(response)
111+
hit_query = response_hitquery(response)
112+
result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time,
113+
"hit_query": hit_query, "answer": answer}
114+
delta_time_log = round(time.time() - start_time, 2)
115+
future = executor.submit(save_query_info, result, model, query, delta_time_log)
116+
except Exception as e:
117+
result = {"errorCode": 202, "errorDesc": e, "cacheHit": False, "delta_time": 0,
118+
"hit_query": '', "answer": ''}
119+
logging.info('result: {}'.format(result))
120+
121+
return json.dumps(result, ensure_ascii=False)
122+
123+
if request_type == 'insert':
124+
try:
125+
try:
126+
response = adapter.ChatCompletion.create_insert(
127+
model=model,
128+
chat_info=chat_info
129+
)
130+
except Exception as e:
131+
result = {"errorCode": 303, "errorDesc": e, "writeStatus": "exception"}
132+
return json.dumps(result, ensure_ascii=False)
133+
134+
if response in ['adapt_insert_exception']:
135+
result = {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
136+
elif response == 'success':
137+
result = {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
138+
else:
139+
result = {"errorCode": 302, "errorDesc": response,
140+
"writeStatus": "exception"}
141+
return json.dumps(result, ensure_ascii=False)
142+
except Exception as e:
143+
result = {"errorCode": 304, "errorDesc": e, "writeStatus": "exception"}
144+
return json.dumps(result, ensure_ascii=False)
145+
146+
if request_type == 'remove':
147+
remove_type = param_dict.get("remove_type")
148+
id_list = param_dict.get("id_list", [])
149+
150+
response = adapter.ChatCompletion.create_remove(
151+
model=model,
152+
remove_type=remove_type,
153+
id_list=id_list
154+
)
155+
156+
if not isinstance(response, dict):
157+
result = {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}
158+
return json.dumps(result)
159+
160+
state = response.get('status')
161+
if state == 'success':
162+
result = {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
163+
else:
164+
result = {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
165+
return json.dumps(result)
166+
167+
168+
if __name__ == '__main__':
169+
app.run(host='0.0.0.0', port=5000, debug=True)

modelcache/adapter/adapter.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ def cache_data_convert(cache_data, cache_query):
2121
**kwargs
2222
)
2323
except Exception as e:
24-
return 'adapt_query_exception'
25-
24+
return str(e)
2625

2726
@classmethod
2827
def create_insert(cls, *args, **kwargs):
@@ -32,9 +31,7 @@ def create_insert(cls, *args, **kwargs):
3231
**kwargs
3332
)
3433
except Exception as e:
35-
logging.info('adapt_insert_e: {}'.format(e))
36-
return 'adapt_insert_exception'
37-
34+
return str(e)
3835

3936
@classmethod
4037
def create_remove(cls, *args, **kwargs):
@@ -45,7 +42,7 @@ def create_remove(cls, *args, **kwargs):
4542
)
4643
except Exception as e:
4744
logging.info('adapt_remove_e: {}'.format(e))
48-
return 'adapt_remove_exception'
45+
return str(e)
4946

5047

5148
def construct_resp_from_cache(return_message, return_query):

modelcache/adapter/adapter_remove.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
from modelcache import cache
3-
from modelcache.utils.error import NotInitError
3+
from modelcache.utils.error import NotInitError, RemoveError
44

55

66
def adapt_remove(*args, **kwargs):
@@ -17,11 +17,10 @@ def adapt_remove(*args, **kwargs):
1717
if remove_type == 'delete_by_id':
1818
id_list = kwargs.pop("id_list", [])
1919
resp = chat_cache.data_manager.delete(id_list, model=model)
20-
2120
elif remove_type == 'truncate_by_model':
2221
resp = chat_cache.data_manager.truncate(model)
23-
2422
else:
25-
resp = "remove_type_error"
23+
# resp = "remove_type_error"
24+
raise RemoveError()
2625
return resp
2726

modelcache/manager/scalar_data/manager.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from modelcache.utils import import_sql_client
33
from modelcache.utils.error import NotFoundError
44

5+
SQL_URL = {"sqlite": "./sqlite.db"}
6+
57

68
class CacheBase:
79
"""
@@ -15,12 +17,16 @@ def __init__(self):
1517

1618
@staticmethod
1719
def get(name, **kwargs):
18-
if name in ["sqlite", "mysql"]:
20+
21+
if name in ["mysql", "oceanbase"]:
1922
from modelcache.manager.scalar_data.sql_storage import SQLStorage
2023
config = kwargs.get("config")
21-
# db_name = kwargs.get("db_name")
2224
import_sql_client(name)
2325
cache_base = SQLStorage(db_type=name, config=config)
26+
elif name == 'sqlite':
27+
from modelcache.manager.scalar_data.sql_storage_sqlite import SQLStorage
28+
sql_url = kwargs.get("sql_url", SQL_URL[name])
29+
cache_base = SQLStorage(db_type=name, url=sql_url)
2430
else:
2531
raise NotFoundError("cache store", name)
2632
return cache_base

modelcache/manager/scalar_data/sql_storage.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,14 @@ class SQLStorage(CacheStorage):
1414
def __init__(
1515
self,
1616
db_type: str = "mysql",
17-
config=None,
17+
config=None
1818
):
1919

2020
self.host = config.get('mysql', 'host')
2121
self.port = int(config.get('mysql', 'port'))
2222
self.username = config.get('mysql', 'username')
2323
self.password = config.get('mysql', 'password')
2424
self.database = config.get('mysql', 'database')
25-
2625
self.pool = PooledDB(
2726
creator=pymysql,
2827
host=self.host,
@@ -78,7 +77,7 @@ def insert_query_resp(self, query_resp, **kwargs):
7877
if isinstance(hit_query, list):
7978
hit_query = json.dumps(hit_query, ensure_ascii=False)
8079

81-
table_name = "cache_query_log_info"
80+
table_name = "modelcache_query_log"
8281
insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)".format(table_name)
8382
conn = self.pool.connection()
8483
try:
@@ -90,7 +89,6 @@ def insert_query_resp(self, query_resp, **kwargs):
9089
finally:
9190
# 关闭连接,将连接返回给连接池
9291
conn.close()
93-
return id
9492

9593
def get_data_by_id(self, key: int):
9694
table_name = "cache_codegpt_answer"

0 commit comments

Comments
 (0)