Commit 5127d77

Added efficiency fixes to Milvus and MySQL
Milvus: load each collection into memory once and keep all loaded collections cached, instead of reloading the collection from scratch on every operation.

MySQL: added support for bulk insert. This required changing id generation from AUTO_INCREMENT to UUIDs generated before insertion, because MySQL does not support returning all of the ids generated by auto increment for a bulk insert.

Co-authored-by: olgaoznovich <[email protected]>
Co-authored-by: Yuval-Roth <[email protected]>
1 parent c4f9253 commit 5127d77
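
Why bulk insert forced the id change, as a minimal sketch (not code from this commit; host, credentials, and sample rows are illustrative, and it assumes the `modelcache_llm_answer` schema from `data/mysql/init/init.sql`): with a multi-row `INSERT`, MySQL's `LAST_INSERT_ID()` reports only the first auto-generated id of the batch, so the driver cannot recover one id per row. Generating UUIDs client-side makes every id known before the insert.

```python
# Minimal sketch: bulk insert with client-generated UUID keys.
# Assumes a reachable MySQL with the modelcache_llm_answer table;
# connection parameters and sample rows below are illustrative.
import uuid
import pymysql

rows = [("q1", "a1"), ("q2", "a2"), ("q3", "a3")]  # (question, answer)

# Ids are generated before insertion, so all of them are known up front.
# With AUTO_INCREMENT, executemany() exposes at most one generated id
# (MySQL's LAST_INSERT_ID() reports only the first id of a multi-row insert).
ids = [str(uuid.uuid4()) for _ in rows]

conn = pymysql.connect(host="localhost", user="root", password="root", database="modelcache")
try:
    with conn.cursor() as cursor:
        cursor.executemany(
            "INSERT INTO modelcache_llm_answer "
            "(id, question, answer, answer_type, model, embedding_data, is_deleted) "
            "VALUES (%s, %s, %s, %s, %s, %s, %s)",
            [(i, q, a, 0, "model_demo", b"", 0) for i, (q, a) in zip(ids, rows)],
        )
    conn.commit()
finally:
    conn.close()

print(ids)  # every primary key, with no extra round trip to the database
```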

File tree

6 files changed: +132 -56 lines changed


data/mysql/init/init.sql

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ CREATE DATABASE IF NOT EXISTS `modelcache`;
 USE `modelcache`;
 
 CREATE TABLE IF NOT EXISTS `modelcache_llm_answer` (
-  `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment 'primary key',
+  `id` CHAR(36) comment 'primary key',
   `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment 'creation time',
   `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment 'modification time',
   `question` text NOT NULL comment 'question',

docker-compose.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name: "modelcache"
 services:
   mysql:
-    image: mysql:9.3.0
+    image: mysql:8.0.23
     container_name: mysql
     environment:
       MYSQL_ROOT_PASSWORD: 'root'

modelcache/adapter/adapter_query.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 
 def adapt_query(cache_data_convert, *args, **kwargs):
     chat_cache = kwargs.pop("cache_obj", cache)
-    scope = kwargs.pop("scope", None)
+    scope = kwargs.pop("scope")
     model = scope['model']
     if not chat_cache.has_init:
         raise NotInitError()
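
The `scope` change above is a fail-fast fix: with a `None` default, a request arriving without a scope only failed one line later at `scope['model']` with an unhelpful `TypeError`; without the default, `kwargs.pop("scope")` raises `KeyError` at the actual point of misuse. A short illustration:

```python
kwargs = {}  # request arrived without a scope
scope = kwargs.pop("scope", None)  # old behavior: scope is silently None ...
# scope['model']                   # ... TypeError: 'NoneType' object is not subscriptable
kwargs.pop("scope")                # new behavior: KeyError: 'scope', raised immediately
```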

modelcache/manager/scalar_data/sql_storage.py

Lines changed: 91 additions & 28 deletions
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 import os
 import time
+import uuid
 
 import pymysql
 import json
@@ -42,26 +43,59 @@ def _insert(self, data: List):
         answer_type = 0
         embedding_data = embedding_data.tobytes()
         is_deleted = 0
+        _id = str(uuid.uuid4())
 
         table_name = "modelcache_llm_answer"
-        insert_sql = "INSERT INTO {} (question, answer, answer_type, model, embedding_data, is_deleted) VALUES (%s, %s, %s, %s, _binary%s, %s)".format(table_name)
+        insert_sql = f"""
+            INSERT INTO {table_name}
+            (id, question, answer, answer_type, model, embedding_data, is_deleted)
+            VALUES (%s, %s, %s, %s, %s, _binary%s, %s)
+        """
         conn = self.pool.connection()
         try:
             with conn.cursor() as cursor:
                 # execute the insert
-                values = (question, answer, answer_type, model, embedding_data, is_deleted)
+                values = (_id, question, answer, answer_type, model, embedding_data, is_deleted)
                 cursor.execute(insert_sql, values)
                 conn.commit()
-                id = cursor.lastrowid
         finally:
             # close the connection and return it to the pool
             conn.close()
-        return id
+        return _id
 
-    def batch_insert(self, all_data: List[CacheData]):
+    def batch_insert(self, all_data: List[List]):
+        table_name = "modelcache_llm_answer"
+        insert_sql = f"""
+            INSERT INTO {table_name}
+            (id, question, answer, answer_type, model, embedding_data, is_deleted)
+            VALUES (%s, %s, %s, %s, %s, %s, %s)
+        """
+
+        values_list = []
         ids = []
+
         for data in all_data:
-            ids.append(self._insert(data))
+            answer = data[0]
+            question = data[1]
+            embedding_data = data[2].tobytes()
+            model = data[3]
+            answer_type = 0
+            is_deleted = 0
+            _id = str(uuid.uuid4())
+            ids.append(_id)
+
+            values_list.append((
+                _id, question, answer, answer_type, model, embedding_data, is_deleted
+            ))
+
+        conn = self.pool.connection()
+        try:
+            with conn.cursor() as cursor:
+                cursor.executemany(insert_sql, values_list)
+                conn.commit()
+        finally:
+            conn.close()
+
         return ids
 
     def insert_query_resp(self, query_resp, **kwargs):
@@ -78,7 +112,11 @@ def insert_query_resp(self, query_resp, **kwargs):
         hit_query = json.dumps(hit_query, ensure_ascii=False)
 
         table_name = "modelcache_query_log"
-        insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)".format(table_name)
+        insert_sql = f"""
+            INSERT INTO {table_name}
+            (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer)
+            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
+        """
         conn = self.pool.connection()
         try:
             with conn.cursor() as cursor:
@@ -92,15 +130,16 @@ def insert_query_resp(self, query_resp, **kwargs):
 
     def get_data_by_id(self, key: int):
         table_name = "modelcache_llm_answer"
-        query_sql = "select question, answer, embedding_data, model from {} where id={}".format(table_name, key)
-        conn_start = time.time()
+        query_sql = f"""
+            SELECT question, answer, embedding_data, model
+            FROM {table_name}
+            WHERE id = %s
+        """
         conn = self.pool.connection()
-
-        search_start = time.time()
         try:
             with conn.cursor() as cursor:
                 # execute the database operation
-                cursor.execute(query_sql)
+                cursor.execute(query_sql, (key,))
                 resp = cursor.fetchone()
         finally:
             # close the connection and return it to the pool
@@ -113,14 +152,18 @@ def get_data_by_id(self, key: int):
 
     def update_hit_count_by_id(self, primary_id: int):
         table_name = "modelcache_llm_answer"
-        update_sql = "UPDATE {} SET hit_count = hit_count+1 WHERE id={}".format(table_name, primary_id)
+        update_sql = f"""
+            UPDATE {table_name}
+            SET hit_count = hit_count+1
+            WHERE id = %s
+        """
         conn = self.pool.connection()
 
         # execute the update using the connection
         try:
             with conn.cursor() as cursor:
                 # execute the update
-                cursor.execute(update_sql)
+                cursor.execute(update_sql, (primary_id,))
                 conn.commit()
         finally:
             # close the connection and return it to the pool
@@ -129,12 +172,16 @@ def update_hit_count_by_id(self, primary_id: int):
     def get_ids(self, deleted=True):
         table_name = "modelcache_llm_answer"
         state = 1 if deleted else 0
-        query_sql = "Select id FROM {} WHERE is_deleted = {}".format(table_name, state)
+        query_sql = f"""
+            SELECT id
+            FROM {table_name}
+            WHERE is_deleted = %s
+        """
 
         conn = self.pool.connection()
         try:
             with conn.cursor() as cursor:
-                cursor.execute(query_sql)
+                cursor.execute(query_sql, (state,))
                 ids = [row[0] for row in cursor.fetchall()]
         finally:
             conn.close()
@@ -143,37 +190,45 @@ def get_ids(self, deleted=True):
 
     def mark_deleted(self, keys):
         table_name = "modelcache_llm_answer"
-        mark_sql = " update {} set is_deleted=1 WHERE id in ({})".format(table_name, ",".join([str(i) for i in keys]))
+        placeholders = ",".join(["%s"] * len(keys))
+        mark_sql = f"""
+            UPDATE {table_name}
+            SET is_deleted=1
+            WHERE id in ({placeholders})
+        """
 
-        # get a connection from the pool
         conn = self.pool.connection()
         try:
             with conn.cursor() as cursor:
-                # execute the delete
-                cursor.execute(mark_sql)
+                cursor.execute(mark_sql, keys)
                 delete_count = cursor.rowcount
                 conn.commit()
         finally:
-            # close the connection and return it to the pool
             conn.close()
         return delete_count
 
     def model_deleted(self, model_name):
         table_name = "modelcache_llm_answer"
-        delete_sql = "Delete from {} WHERE model='{}'".format(table_name, model_name)
+        delete_sql = f"""
+            Delete from {table_name}
+            WHERE model = %s
+        """
 
         table_log_name = "modelcache_query_log"
-        delete_log_sql = "Delete from {} WHERE model='{}'".format(table_log_name, model_name)
+        delete_log_sql = f"""
+            Delete from {table_log_name}
+            WHERE model = %s
+        """
 
         conn = self.pool.connection()
         # execute the deletes using the connection
         try:
             with conn.cursor() as cursor:
                 # execute the delete
-                resp = cursor.execute(delete_sql)
+                resp = cursor.execute(delete_sql, (model_name,))
                 conn.commit()
                 # delete the logs for this model; resp_log row count is not returned
-                resp_log = cursor.execute(delete_log_sql)
+                resp_log = cursor.execute(delete_log_sql, (model_name,))
                 conn.commit()  # commit each transaction separately
         finally:
             # close the connection and return it to the pool
@@ -182,7 +237,10 @@ def model_deleted(self, model_name):
 
     def clear_deleted_data(self):
         table_name = "modelcache_llm_answer"
-        delete_sql = "DELETE FROM {} WHERE is_deleted = 1".format(table_name)
+        delete_sql = f"""
+            DELETE FROM {table_name}
+            WHERE is_deleted = 1
+        """
 
         conn = self.pool.connection()
         try:
@@ -197,10 +255,15 @@ def clear_deleted_data(self):
 
     def count(self, state: int = 0, is_all: bool = False):
         table_name = "modelcache_llm_answer"
+
+        # we're not using prepared statements here, so we need to ensure state is an integer
+        if not isinstance(state, int):
+            raise ValueError("'state' must be an integer.")
+
         if is_all:
-            count_sql = "SELECT COUNT(*) FROM {}".format(table_name)
+            count_sql = f"SELECT COUNT(*) FROM {table_name}"
         else:
-            count_sql = "SELECT COUNT(*) FROM {} WHERE is_deleted = {}".format(table_name,state)
+            count_sql = f"SELECT COUNT(*) FROM {table_name} WHERE is_deleted = {state}"
 
         conn = self.pool.connection()
         try:
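
One detail of the `mark_deleted` change above: since the `IN` list length varies per call, the parameterized form has to build one `%s` placeholder per key at runtime and pass the keys separately, so the driver escapes the values. A self-contained sketch of just that pattern (the table name is reused from the diff for illustration):

```python
def build_mark_deleted_sql(table_name: str, num_keys: int) -> str:
    """Build an UPDATE with one %s placeholder per key; the key values are
    passed to cursor.execute() separately, so the driver escapes them
    instead of interpolating raw ids into the SQL string."""
    placeholders = ",".join(["%s"] * num_keys)
    return f"UPDATE {table_name} SET is_deleted=1 WHERE id in ({placeholders})"

print(build_mark_deleted_sql("modelcache_llm_answer", 3))
# UPDATE modelcache_llm_answer SET is_deleted=1 WHERE id in (%s,%s,%s)
```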

modelcache/manager/vector_data/milvus.py

Lines changed: 31 additions & 21 deletions
@@ -66,6 +66,8 @@ def __init__(
         self.search_params = (
             search_params or self.SEARCH_PARAM[self.index_params["index_type"]]
         )
+        self.collections = dict()
+
 
     def _connect(self, host, port, user, password, secure):
         try:
@@ -87,12 +89,14 @@ def _connect(self, host, port, user, password, secure):
                 timeout=10
             )
 
+
     def _create_collection(self, collection_name):
         if not utility.has_collection(collection_name, using=self.alias):
             schema = [
                 FieldSchema(
                     name="id",
-                    dtype=DataType.INT64,
+                    dtype=DataType.VARCHAR,
+                    max_length=36,
                     is_primary=True,
                     auto_id=False,
                 ),
@@ -101,67 +105,71 @@ def _create_collection(self, collection_name):
                 ),
             ]
             schema = CollectionSchema(schema)
-            self.col = Collection(
+
+            new_collection = Collection(
                 collection_name,
                 schema=schema,
                 consistency_level="Session",
                 using=self.alias,
             )
         else:
             modelcache_log.warning("The %s collection already exists, and it will be used directly.", collection_name)
-            self.col = Collection(
+            new_collection = Collection(
                 collection_name, consistency_level="Session", using=self.alias
             )
 
-        if len(self.col.indexes) == 0:
+        self.collections[collection_name] = new_collection
+
+        if len(new_collection.indexes) == 0:
             try:
                 modelcache_log.info("Attempting creation of Milvus index.")
-                self.col.create_index("embedding", index_params=self.index_params)
+                new_collection.create_index("embedding", index_params=self.index_params)
                 modelcache_log.info("Creation of Milvus index successful.")
             except MilvusException as e:
                 modelcache_log.warning("Error with building index: %s, and attempting creation of default index.", e)
                 i_p = {"metric_type": "L2", "index_type": "AUTOINDEX", "params": {}}
-                self.col.create_index("embedding", index_params=i_p)
+                new_collection.create_index("embedding", index_params=i_p)
                 self.index_params = i_p
         else:
-            self.index_params = self.col.indexes[0].to_dict()["index_param"]
+            self.index_params = new_collection.indexes[0].to_dict()["index_param"]
+
+        new_collection.load()
 
-        self.col.load()
 
     def _get_collection(self, collection_name):
-        self.col = Collection(
-            collection_name, consistency_level="Session", using=self.alias
-        )
-        self.col.load()
+        if collection_name not in self.collections:
+            self._create_collection(collection_name)
+        return self.collections[collection_name]
 
     def mul_add(self, datas: List[VectorData], model=None):
         collection_name_model = self.collection_name + '_' + model
-        self._create_collection(collection_name_model)
-
+        col = self._get_collection(collection_name_model)
         data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
         np_data = np.array(data_array).astype("float32")
         entities = [id_array, np_data]
-        self.col.insert(entities)
+        col.insert(entities)
+
 
     def search(self, data: np.ndarray, top_k: int = -1, model=None):
         if top_k == -1:
             top_k = self.top_k
         collection_name_model = self.collection_name + '_' + model
-        self._create_collection(collection_name_model)
-        search_result = self.col.search(
+        col = self._get_collection(collection_name_model)
+        search_result = col.search(
             data=data.reshape(1, -1).tolist(),
             anns_field="embedding",
            param=self.search_params,
            limit=top_k,
        )
        return list(zip(search_result[0].distances, search_result[0].ids))
 
+
     def delete(self, ids, model=None):
         collection_name_model = self.collection_name + '_' + model
-        self._get_collection(collection_name_model)
+        col = self._get_collection(collection_name_model)
 
         del_ids = ",".join([str(x) for x in ids])
-        resp = self.col.delete(f"id in [{del_ids}]")
+        resp = col.delete(f"id in [{del_ids}]")
         delete_count = resp.delete_count
         return delete_count
 
@@ -178,10 +186,12 @@ def rebuild_col(self, model):
         logging.info('create_collection: {}'.format(e))
 
     def rebuild(self, ids=None):  # pylint: disable=unused-argument
-        self.col.compact()
+        for col in self.collections.values():
+            col.compact()
 
     def flush(self):
-        self.col.flush(_async=True)
+        for col in self.collections.values():
+            col.flush(_async=True)
 
     def close(self):
         self.flush()
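
The Milvus change amounts to memoizing collection loads: `_get_collection` loads a collection on first use, keeps the handle in `self.collections`, and returns the cached handle afterwards, so queries no longer pay the load cost on every call. The same idea reduced to its core, as a generic sketch (not the ModelCache API; `loader` stands in for `_create_collection`):

```python
from typing import Callable, Dict, TypeVar

T = TypeVar("T")

class CollectionCache:
    """Memoize expensive collection loads: load once, then reuse."""

    def __init__(self, loader: Callable[[str], T]):
        self._loader = loader            # loads a collection into memory by name
        self._cache: Dict[str, T] = {}   # name -> loaded collection

    def get(self, name: str) -> T:
        if name not in self._cache:      # only the first request pays the load cost
            self._cache[name] = self._loader(name)
        return self._cache[name]         # later requests are a dict lookup

# Illustrative usage:
cache = CollectionCache(loader=lambda name: f"<collection {name}>")
print(cache.get("answer_gpt4"))  # loads and caches
print(cache.get("answer_gpt4"))  # served from the cache
```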
