Skip to content

Commit d13d18b

Browse files
committed
Adding an SQLite database; incorporating a local demo Flask service.
1 parent ec17493 commit d13d18b

File tree

2 files changed

+348
-0
lines changed

2 files changed

+348
-0
lines changed

flask4modelcache_demo.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# -*- coding: utf-8 -*-
2+
import time
3+
from flask import Flask, request
4+
import logging
5+
import json
6+
from modelcache import cache
7+
from modelcache.adapter import adapter
8+
from modelcache.manager import CacheBase, VectorBase, get_data_manager
9+
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
10+
from modelcache.processor.pre import query_multi_splicing
11+
from modelcache.processor.pre import insert_multi_splicing
12+
from concurrent.futures import ThreadPoolExecutor
13+
from modelcache.utils.model_filter import model_blacklist_filter
14+
from modelcache.embedding import Data2VecAudio
15+
16+
# Create the Flask application instance for the local demo service.
app = Flask(__name__)
18+
19+
20+
def response_text(cache_resp):
    """Return the answer payload stored under ``'data'`` in a cache response."""
    return cache_resp['data']
22+
23+
24+
def save_query_info(result, model, query, delta_time_log):
    """Persist one query/response log record via the cache's data manager."""
    serialized_query = json.dumps(query, ensure_ascii=False)
    cache.data_manager.save_query_resp(
        result,
        model=model,
        query=serialized_query,
        delta_time=delta_time_log,
    )
27+
28+
29+
def response_hitquery(cache_resp):
    """Return the matched (hit) query text stored under ``'hitQuery'``."""
    return cache_resp['hitQuery']
31+
32+
33+
# Cache wiring: Data2Vec audio embeddings, a SQLite scalar store, and a FAISS
# vector index sized to the embedding dimension.
data2vec = Data2VecAudio()
data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension))


cache.init(
    embedding_func=data2vec.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    query_pre_embedding_func=query_multi_splicing,
    insert_pre_embedding_func=insert_multi_splicing,
)

# cache.set_openai_key()

# Thread pool used to persist query logs off the request path.
# NOTE: the original `global executor` statement was removed — `global` is a
# no-op at module scope; the assignment below already creates a module-level
# name visible to the view functions.
executor = ThreadPoolExecutor(max_workers=6)
48+
49+
50+
@app.route('/welcome')
def first_flask():
    """Health-check view confirming the service is up."""
    return 'hello, modelcache!'
53+
54+
55+
@app.route('/modelcache', methods=['GET', 'POST'])
def user_backend():
    """Single entry point for all cache operations.

    Expects a JSON body (POST) or query args (GET) containing:
      * ``type``:      one of 'query', 'insert', 'detox', 'remove'
      * ``scope``:     {'model': <model name>}
      * ``query``:     query payload (for 'query')
      * ``chat_info``: data to insert (for 'insert')
      * ``remove_type`` / ``id_list`` (for 'remove')

    Always returns a JSON string; ``errorCode`` 0 means success.
    """
    # ---- request parsing -------------------------------------------------
    try:
        if request.method == 'POST':
            request_data = request.json
        elif request.method == 'GET':
            request_data = request.args
        # BUG FIX: request.json already yields a dict and request.args is a
        # MultiDict; json.loads() on either raises TypeError. Only re-parse
        # when the payload actually arrived as a raw JSON string.
        if isinstance(request_data, str):
            param_dict = json.loads(request_data)
        else:
            param_dict = dict(request_data)
    except Exception as e:
        result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
                  "answer": ''}
        cache.data_manager.save_query_resp(result, model='', query='', delta_time=0)
        return json.dumps(result)

    # ---- param parsing ---------------------------------------------------
    model = ''  # BUG FIX: default so the errorCode-102 path below cannot hit an unbound name
    try:
        request_type = param_dict.get("type")
        scope = param_dict.get("scope")
        if scope is not None:
            model = scope.get('model')
            # Normalize model names to an identifier-safe form for storage.
            model = model.replace('-', '_').replace('.', '_')
        query = param_dict.get("query")
        chat_info = param_dict.get("chat_info")
        if request_type is None or request_type not in ['query', 'insert', 'detox', 'remove']:
            result = {"errorCode": 102,
                      "errorDesc": "type exception, should one of ['query', 'insert', 'detox', 'remove']",
                      "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
            cache.data_manager.save_query_resp(result, model=model, query='', delta_time=0)
            return json.dumps(result)
    except Exception as e:
        result = {"errorCode": 103, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
                  "answer": ''}
        return json.dumps(result)

    # ---- model blacklist filter ------------------------------------------
    filter_resp = model_blacklist_filter(model, request_type)
    if isinstance(filter_resp, dict):
        return json.dumps(filter_resp)

    if request_type == 'query':
        try:
            start_time = time.time()
            response = adapter.ChatCompletion.create_query(
                scope={"model": model},
                query=query
            )
            delta_time = '{}s'.format(round(time.time() - start_time, 2))
            if response is None:
                # Cache miss.
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '',
                          "answer": ''}
            elif response in ['adapt_query_exception']:
                result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
                          "hit_query": '', "answer": ''}
            else:
                answer = response_text(response)
                hit_query = response_hitquery(response)
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time,
                          "hit_query": hit_query, "answer": answer}
            delta_time_log = round(time.time() - start_time, 2)
            # Persist the query log off the request path (fire-and-forget).
            executor.submit(save_query_info, result, model, query, delta_time_log)
        except Exception as e:
            # BUG FIX: use str(e) — a raw exception object is not JSON-serializable.
            result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
                      "hit_query": '', "answer": ''}
        logging.info('result: {}'.format(result))
        return json.dumps(result, ensure_ascii=False)

    if request_type == 'insert':
        try:
            try:
                response = adapter.ChatCompletion.create_insert(
                    model=model,
                    chat_info=chat_info
                )
            except Exception as e:
                # BUG FIX: str(e) so json.dumps cannot fail on the exception object.
                result = {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}
                return json.dumps(result, ensure_ascii=False)

            if response in ['adapt_insert_exception']:
                result = {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
            elif response == 'success':
                result = {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
            else:
                result = {"errorCode": 302, "errorDesc": response,
                          "writeStatus": "exception"}
            return json.dumps(result, ensure_ascii=False)
        except Exception as e:
            # BUG FIX: str(e) as above.
            result = {"errorCode": 304, "errorDesc": str(e), "writeStatus": "exception"}
            return json.dumps(result, ensure_ascii=False)

    if request_type == 'remove':
        remove_type = param_dict.get("remove_type")
        id_list = param_dict.get("id_list", [])

        response = adapter.ChatCompletion.create_remove(
            model=model,
            remove_type=remove_type,
            id_list=id_list
        )

        if not isinstance(response, dict):
            result = {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}
            return json.dumps(result)

        state = response.get('status')
        if state == 'success':
            result = {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
        else:
            result = {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
        return json.dumps(result)

    # BUG FIX: unimplemented types (e.g. 'detox') previously fell through and
    # made the view return None (HTTP 500); answer with an explicit error.
    result = {"errorCode": 104, "errorDesc": "type '{}' is not supported".format(request_type),
              "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
    return json.dumps(result)
166+
167+
168+
if __name__ == '__main__':
    # Local demo entry point; debug mode is intended for development only.
    app.run(host='0.0.0.0', port=5000, debug=True)
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# -*- coding: utf-8 -*-
2+
import os
3+
import time
4+
5+
import pymysql
6+
import json
7+
import base64
8+
from typing import List
9+
from modelcache.manager.scalar_data.base import CacheStorage, CacheData
10+
from DBUtils.PooledDB import PooledDB
11+
12+
13+
class SQLStorage(CacheStorage):
    """Scalar cache storage backed by a pooled MySQL/OceanBase connection.

    NOTE(review): only the mysql/oceanbase branch creates ``self.pool``; the
    sqlite branch merely records the URL, so the data-access methods below
    only work against a pooled MySQL-compatible backend — confirm intended
    usage of the sqlite path.
    """

    def __init__(
            self,
            db_type: str = "mysql",
            config=None,
            url="sqlite:///./sqlite.db"
    ):
        if db_type in ["mysql", "oceanbase"]:
            self.host = config.get('mysql', 'host')
            self.port = int(config.get('mysql', 'port'))
            self.username = config.get('mysql', 'username')
            self.password = config.get('mysql', 'password')
            self.database = config.get('mysql', 'database')

            self.pool = PooledDB(
                creator=pymysql,
                host=self.host,
                user=self.username,
                password=self.password,
                port=self.port,
                database=self.database
            )
        elif db_type == 'sqlite':
            self._url = url

    def create(self):
        pass

    def _insert(self, data: List):
        """Insert one (answer, question, embedding, model) row; return its new id."""
        answer = data[0]
        question = data[1]
        embedding_data = data[2]
        model = data[3]
        answer_type = 0
        embedding_data = embedding_data.tobytes()

        table_name = "cache_codegpt_answer"
        insert_sql = "INSERT INTO {} (question, answer, answer_type, model, embedding_data) VALUES (%s, %s, %s, %s, _binary%s)".format(table_name)

        conn = self.pool.connection()
        try:
            with conn.cursor() as cursor:
                values = (question, answer, answer_type, model, embedding_data)
                cursor.execute(insert_sql, values)
                conn.commit()
                row_id = cursor.lastrowid  # renamed: don't shadow builtin `id`
        finally:
            # Return the connection to the pool.
            conn.close()
        return row_id

    def batch_insert(self, all_data: List[CacheData]):
        """Insert many rows; return the new primary keys in input order."""
        return [self._insert(data) for data in all_data]

    def insert_query_resp(self, query_resp, **kwargs):
        """Append one row to the query log table (write-only, returns None)."""
        error_code = query_resp.get('errorCode')
        error_desc = query_resp.get('errorDesc')
        cache_hit = query_resp.get('cacheHit')
        model = kwargs.get('model')
        query = kwargs.get('query')
        delta_time = kwargs.get('delta_time')
        hit_query = query_resp.get('hit_query')
        answer = query_resp.get('answer')

        if isinstance(hit_query, list):
            hit_query = json.dumps(hit_query, ensure_ascii=False)

        table_name = "cache_query_log_info"
        insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)".format(table_name)
        conn = self.pool.connection()
        try:
            with conn.cursor() as cursor:
                values = (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer)
                cursor.execute(insert_sql, values)
                conn.commit()
        finally:
            # Return the connection to the pool.
            conn.close()
        # BUG FIX: the original `return id` returned the *builtin* id function,
        # which was meaningless; this operation has no useful result.

    def get_data_by_id(self, key: int):
        """Fetch (question, answer, embedding_data, model) for a primary key, or None."""
        table_name = "cache_codegpt_answer"
        # BUG FIX: parameterized query instead of string interpolation (SQL injection).
        query_sql = "select question, answer, embedding_data, model from {} where id=%s".format(table_name)
        conn = self.pool.connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(query_sql, (key,))
                resp = cursor.fetchone()
        finally:
            # Return the connection to the pool.
            conn.close()

        if resp is not None and len(resp) == 4:
            return resp
        return None

    def update_hit_count_by_id(self, primary_id: int):
        """Increment the hit counter for one cached answer row."""
        table_name = "cache_codegpt_answer"
        # BUG FIX: parameterized query instead of string interpolation.
        update_sql = "UPDATE {} SET hit_count = hit_count+1 WHERE id=%s".format(table_name)
        conn = self.pool.connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(update_sql, (primary_id,))
                conn.commit()
        finally:
            # Return the connection to the pool.
            conn.close()

    def get_ids(self, deleted=True):
        pass

    def mark_deleted(self, keys):
        """Delete the rows with the given primary keys; return the count removed."""
        # BUG FIX: an empty key list used to build invalid `IN ()` SQL.
        if not keys:
            return 0
        table_name = "cache_codegpt_answer"
        # BUG FIX: parameterized IN-list instead of interpolating raw ids.
        placeholders = ",".join(["%s"] * len(keys))
        delete_sql = "Delete from {} WHERE id in ({})".format(table_name, placeholders)

        conn = self.pool.connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(delete_sql, tuple(keys))
                delete_count = cursor.rowcount
                conn.commit()
        finally:
            # Return the connection to the pool.
            conn.close()
        return delete_count

    def model_deleted(self, model_name):
        """Delete every cached answer belonging to *model_name*; return rows affected."""
        table_name = "cache_codegpt_answer"
        # BUG FIX: parameterized query instead of quoting the model name inline.
        delete_sql = "Delete from {} WHERE model=%s".format(table_name)
        conn = self.pool.connection()
        try:
            with conn.cursor() as cursor:
                resp = cursor.execute(delete_sql, (model_name,))
                conn.commit()
        finally:
            # Return the connection to the pool.
            conn.close()
        return resp

    def clear_deleted_data(self):
        pass

    def count(self, state: int = 0, is_all: bool = False):
        pass

    def close(self):
        pass

    def count_answers(self):
        pass

0 commit comments

Comments
 (0)