Skip to content

Commit c4f9253

Browse files
Added handling for bulk insert
Added a flag for normalize because the model we use already normalizes Co-authored-by: olgaoznovich <[email protected]> Co-authored-by: Yuval-Roth <[email protected]>
1 parent 82ba867 commit c4f9253

File tree

9 files changed

+51
-38
lines changed

9 files changed

+51
-38
lines changed

docker-compose.yaml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ services:
1414
- ./data/mysql/db:/var/lib/mysql
1515
- ./data/mysql/my.cnf:/etc/mysql/conf.d/my.cnf
1616
- ./data/mysql/init:/docker-entrypoint-initdb.d
17-
restart: on-failure
17+
# restart: on-failure
1818
networks:
1919
- modelcache
2020

@@ -36,15 +36,15 @@ services:
3636
- 19530:19530
3737
- 9091:9091
3838
- 2379:2379
39-
healthcheck:
40-
test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
41-
interval: 30s
42-
start_period: 90s
43-
timeout: 20s
44-
retries: 3
39+
# healthcheck:
40+
# test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
41+
# interval: 30s
42+
# start_period: 90s
43+
# timeout: 20s
44+
# retries: 3
4545
networks:
4646
- modelcache
47-
restart: on-failure
47+
# restart: on-failure
4848
command: milvus run standalone
4949

5050
# modelcache:

flask4modelcache.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77
from modelcache import cache
88
from modelcache.adapter import adapter
9-
from modelcache.manager import CacheBase, VectorBase, get_data_manager
9+
from modelcache.manager import CacheBase, VectorBase, get_data_manager, data_manager
1010
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
1111
from modelcache.processor.pre import query_multi_splicing
1212
from modelcache.processor.pre import insert_multi_splicing
@@ -30,8 +30,10 @@ def save_query_info(result, model, query, delta_time_log):
3030
def response_hitquery(cache_resp):
3131
return cache_resp['hitQuery']
3232

33-
3433
data2vec = Data2VecAudio()
34+
embedding_func = data2vec.to_embeddings
35+
dimension = data2vec.dimension
36+
3537
mysql_config = configparser.ConfigParser()
3638
mysql_config.read('modelcache/config/mysql_config.ini')
3739

@@ -48,7 +50,7 @@ def response_hitquery(cache_resp):
4850
# chromadb_config.read('modelcache/config/chromadb_config.ini')
4951

5052
data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
51-
VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
53+
VectorBase("milvus", dimension=dimension, milvus_config=milvus_config))
5254

5355

5456
# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
@@ -57,9 +59,8 @@ def response_hitquery(cache_resp):
5759
# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
5860
# VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config))
5961

60-
6162
cache.init(
62-
embedding_func=data2vec.to_embeddings,
63+
embedding_func=embedding_func,
6364
data_manager=data_manager,
6465
similarity_evaluation=SearchDistanceEvaluation(),
6566
query_pre_embedding_func=query_multi_splicing,

model/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*.tflite
2+
text2vec-base-chinese/*

model/clone_model_repository.bat

Lines changed: 0 additions & 2 deletions
This file was deleted.

model/download_bert_embedder.bat

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
curl -o embedder.tflite https://storage.googleapis.com/mediapipe-models/text_embedder/bert_embedder/float32/1/bert_embedder.tflite

modelcache/adapter/adapter_insert.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,31 @@ def adapt_insert(*args, **kwargs):
1414
raise NotInitError()
1515
cache_enable = chat_cache.cache_enable_func(*args, **kwargs)
1616
context = kwargs.pop("cache_context", {})
17-
embedding_data = None
18-
pre_embedding_data = chat_cache.insert_pre_embedding_func(
19-
kwargs,
20-
extra_param=context.get("pre_embedding_func", None),
21-
prompts=chat_cache.config.prompts,
22-
)
2317
chat_info = kwargs.pop("chat_info", [])
24-
llm_data = chat_info[-1]['answer']
2518

26-
if cache_enable:
19+
pre_embedding_data_list = []
20+
embedding_data_list = []
21+
llm_data_list = []
22+
23+
for row in chat_info:
24+
pre_embedding_data = chat_cache.insert_pre_embedding_func(
25+
row,
26+
extra_param=context.get("pre_embedding_func", None),
27+
prompts=chat_cache.config.prompts,
28+
)
29+
pre_embedding_data_list.append(pre_embedding_data)
30+
llm_data_list.append(row['answer'])
2731
embedding_data = time_cal(
2832
chat_cache.embedding_func,
2933
func_name="embedding",
3034
report_func=chat_cache.report.embedding,
3135
)(pre_embedding_data)
36+
embedding_data_list.append(embedding_data)
3237

3338
chat_cache.data_manager.save(
34-
pre_embedding_data,
35-
llm_data,
36-
embedding_data,
39+
pre_embedding_data_list,
40+
llm_data_list,
41+
embedding_data_list,
3742
model=model,
3843
extra_param=context.get("save_func", None)
3944
)

modelcache/manager/data_manager.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import cachetools
88
from abc import abstractmethod, ABCMeta
99
from typing import List, Any, Optional, Union
10+
11+
from numpy import ndarray
12+
1013
from modelcache.manager.scalar_data.base import (
1114
CacheStorage,
1215
CacheData,
@@ -21,6 +24,7 @@
2124
from modelcache.manager.eviction_manager import EvictionManager
2225
from modelcache.utils.log import modelcache_log
2326

27+
NORMALIZE = True
2428

2529
class DataManager(metaclass=ABCMeta):
2630
"""DataManager manage the cache data, including save and search"""
@@ -158,9 +162,9 @@ def __init__(
158162
self.v = v
159163
self.o = o
160164

161-
def save(self, question, answer, embedding_data, **kwargs):
165+
def save(self, questions: List[any], answers: List[any], embedding_datas: List[any], **kwargs):
162166
model = kwargs.pop("model", None)
163-
self.import_data([question], [answer], [embedding_data], model)
167+
self.import_data(questions, answers, embedding_datas, model)
164168

165169
def save_query_resp(self, query_resp_dict, **kwargs):
166170
save_query_start_time = time.time()
@@ -197,9 +201,10 @@ def import_data(
197201
raise ParamError("Make sure that all parameters have the same length")
198202
cache_datas = []
199203

200-
embedding_datas = [
201-
normalize(embedding_data) for embedding_data in embedding_datas
202-
]
204+
if NORMALIZE:
205+
embedding_datas = [
206+
normalize(embedding_data) for embedding_data in embedding_datas
207+
]
203208

204209
for i, embedding_data in enumerate(embedding_datas):
205210
if self.o is not None:
@@ -212,11 +217,9 @@ def import_data(
212217
cache_datas.append([ans, question, embedding_data, model])
213218

214219
ids = self.s.batch_insert(cache_datas)
220+
datas_ = [VectorData(id=ids[i], data=embedding_data.astype("float32")) for i, embedding_data in enumerate(embedding_datas)]
215221
self.v.mul_add(
216-
[
217-
VectorData(id=ids[i], data=embedding_data)
218-
for i, embedding_data in enumerate(embedding_datas)
219-
],
222+
datas_,
220223
model
221224

222225
)
@@ -235,7 +238,8 @@ def hit_cache_callback(self, res_data, **kwargs):
235238

236239
def search(self, embedding_data, **kwargs):
237240
model = kwargs.pop("model", None)
238-
embedding_data = normalize(embedding_data)
241+
if NORMALIZE:
242+
embedding_data = normalize(embedding_data)
239243
top_k = kwargs.get("top_k", -1)
240244
return self.v.search(data=embedding_data, top_k=top_k, model=model)
241245

modelcache/processor/pre.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def query_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
6464

6565

6666
def insert_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
67-
insert_query_list = data.get("chat_info")[-1]['query']
67+
insert_query_list = data['query']
6868
return multi_splicing(insert_query_list)
6969

7070

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,6 @@ chromadb==0.5.23
1818
elasticsearch==7.10.0
1919
snowflake-id==1.0.2
2020
flagembedding==1.3.4
21-
cryptography==45.0.2
21+
cryptography==45.0.2
22+
mediapipe==0.10.21
23+
protobuf==4.25.8

0 commit comments

Comments
 (0)