Commit 10c8ae0

multimodel cache
1 parent a23f8c0 commit 10c8ae0

33 files changed: +2390 −81 lines changed

.gitignore

Lines changed: 3 additions & 1 deletion

@@ -136,4 +136,6 @@ dmypy.json
 /flask_server
 *.bin
 **/modelcache_serving.py
-*.ini
+*.ini
+
+**/maya_embedding_service

modelcache/adapter_mm/__init__.py

Lines changed: 1 addition & 0 deletions

# -*- coding: utf-8 -*-

modelcache/adapter_mm/adapter.py

Lines changed: 63 additions & 0 deletions

# -*- coding: utf-8 -*-
import logging

from modelcache.adapter_mm.adapter_query import adapt_query
from modelcache.adapter_mm.adapter_insert import adapt_insert
from modelcache.adapter.adapter_remove import adapt_remove
from modelcache.adapter.adapter_register import adapt_register


class ChatCompletion(object):
    """Openai ChatCompletion Wrapper"""
    @classmethod
    def create_mm_query(cls, *args, **kwargs):
        def cache_data_convert(cache_data, cache_query):
            return construct_resp_from_cache(cache_data, cache_query)
        try:
            return adapt_query(
                cache_data_convert,
                *args,
                **kwargs
            )
        except Exception as e:
            return str(e)

    @classmethod
    def create_mm_insert(cls, *args, **kwargs):
        try:
            return adapt_insert(
                *args,
                **kwargs
            )
        except Exception as e:
            return str(e)

    @classmethod
    def create_mm_remove(cls, *args, **kwargs):
        try:
            return adapt_remove(
                *args,
                **kwargs
            )
        except Exception as e:
            logging.info('adapt_remove_e: {}'.format(e))
            return str(e)

    @classmethod
    def create_mm_register(cls, *args, **kwargs):
        try:
            return adapt_register(
                *args,
                **kwargs
            )
        except Exception as e:
            return str(e)


def construct_resp_from_cache(return_message, return_query):
    return {
        "modelcache": True,
        "hitQuery": return_query,
        "data": return_message,
        "errorCode": 0
    }
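
For context, a hedged usage sketch of this facade (not part of the commit; the `scope` and `query` payload shapes are assumptions inferred from how the adapters below unpack their kwargs, and the model name is illustrative):

# Hypothetical usage sketch of the new multimodal facade.
from modelcache.adapter_mm.adapter import ChatCompletion

# adapt_query pops `scope` and reads scope['model'], so a query needs at least:
resp = ChatCompletion.create_mm_query(
    scope={"model": "multimodal_demo"},
    query={"text": "what is in this image?", "image_url": "https://example.com/cat.png"},
)
if isinstance(resp, dict) and resp.get("modelcache"):
    print("cache hit:", resp["data"], "answering query:", resp["hitQuery"])
else:
    print("cache miss or error:", resp)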
modelcache/adapter_mm/adapter_insert.py

Lines changed: 40 additions & 0 deletions

# -*- coding: utf-8 -*-
from modelcache import cache
from modelcache.utils.error import NotInitError
from modelcache.utils.time import time_cal


def adapt_insert(*args, **kwargs):
    chat_cache = kwargs.pop("cache_obj", cache)
    model = kwargs.pop("model", None)
    require_object_store = kwargs.pop("require_object_store", False)
    if require_object_store:
        assert chat_cache.data_manager.o, "Object store is required for adapter."
    if not chat_cache.has_init:
        raise NotInitError()
    cache_enable = chat_cache.cache_enable_func(*args, **kwargs)
    context = kwargs.pop("cache_context", {})
    embedding_data = None
    pre_embedding_data = chat_cache.insert_pre_embedding_func(
        kwargs,
        extra_param=context.get("pre_embedding_func", None),
        prompts=chat_cache.config.prompts,
    )
    chat_info = kwargs.pop("chat_info", [])
    llm_data = chat_info[-1]['answer']

    if cache_enable:
        embedding_data = time_cal(
            chat_cache.embedding_func,
            func_name="embedding",
            report_func=chat_cache.report.embedding,
        )(pre_embedding_data)

    chat_cache.data_manager.save(
        pre_embedding_data,
        llm_data,
        embedding_data,
        model=model,
        extra_param=context.get("save_func", None)
    )
    return 'success'
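
This implies an insert contract: `chat_info` must be a non-empty list whose last element carries the answer, since `llm_data` is read from `chat_info[-1]['answer']`. A hedged example call (model name and payload are illustrative, not from the diff):

from modelcache.adapter_mm.adapter import ChatCompletion

resp = ChatCompletion.create_mm_insert(
    model="multimodal_demo",  # illustrative model name
    chat_info=[{
        "query": {"text": "what is in this image?"},  # consumed by the pre-embedding func
        "answer": "a red bicycle",                    # becomes llm_data via chat_info[-1]['answer']
    }],
)
print(resp)  # 'success' on the happy path; otherwise the stringified exception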
modelcache/adapter_mm/adapter_query.py

Lines changed: 148 additions & 0 deletions

# -*- coding: utf-8 -*-
import logging
import time
from modelcache import cache
from modelcache.utils.error import NotInitError
from modelcache.utils.time import time_cal
from modelcache.processor.pre import multi_analysis


def adapt_query(cache_data_convert, *args, **kwargs):
    chat_cache = kwargs.pop("cache_obj", cache)
    scope = kwargs.pop("scope", None)
    model = scope['model']
    if not chat_cache.has_init:
        raise NotInitError()
    cache_enable = chat_cache.cache_enable_func(*args, **kwargs)
    context = kwargs.pop("cache_context", {})
    embedding_data = None
    cache_factor = kwargs.pop("cache_factor", 1.0)
    pre_embedding_data = chat_cache.query_pre_embedding_func(
        kwargs,
        extra_param=context.get("pre_embedding_func", None),
        prompts=chat_cache.config.prompts,
    )

    if cache_enable:
        embedding_data = time_cal(
            chat_cache.embedding_func,
            func_name="embedding",
            report_func=chat_cache.report.embedding,
        )(pre_embedding_data)

    if cache_enable:
        cache_data_list = time_cal(
            chat_cache.data_manager.search,
            func_name="milvus_search",
            report_func=chat_cache.report.search,
        )(
            embedding_data,
            extra_param=context.get("search_func", None),
            top_k=kwargs.pop("top_k", -1),
            model=model
        )
        cache_answers = []
        cache_questions = []
        cache_ids = []
        similarity_threshold = chat_cache.config.similarity_threshold
        similarity_threshold_long = chat_cache.config.similarity_threshold_long

        min_rank, max_rank = chat_cache.similarity_evaluation.range()
        rank_threshold = (max_rank - min_rank) * similarity_threshold * cache_factor
        rank_threshold_long = (max_rank - min_rank) * similarity_threshold_long * cache_factor
        rank_threshold = (
            max_rank
            if rank_threshold > max_rank
            else min_rank
            if rank_threshold < min_rank
            else rank_threshold
        )
        rank_threshold_long = (
            max_rank
            if rank_threshold_long > max_rank
            else min_rank
            if rank_threshold_long < min_rank
            else rank_threshold_long
        )

        if cache_data_list is None or len(cache_data_list) == 0:
            rank_pre = -1.0
        else:
            cache_data_dict = {'search_result': cache_data_list[0]}
            rank_pre = chat_cache.similarity_evaluation.evaluation(
                None,
                cache_data_dict,
                extra_param=context.get("evaluation_func", None),
            )
        if rank_pre < rank_threshold:
            return

        for cache_data in cache_data_list:
            primary_id = cache_data[1]
            start_time = time.time()
            ret = chat_cache.data_manager.get_scalar_data(
                cache_data, extra_param=context.get("get_scalar_data", None)
            )
            if ret is None:
                continue

            if "deps" in context and hasattr(ret.question, "deps"):
                eval_query_data = {
                    "question": context["deps"][0]["data"],
                    "embedding": None
                }
                eval_cache_data = {
                    "question": ret.question.deps[0].data,
                    "answer": ret.answers[0].answer,
                    "search_result": cache_data,
                    "embedding": None,
                }
            else:
                eval_query_data = {
                    "question": pre_embedding_data,
                    "embedding": embedding_data,
                }

                eval_cache_data = {
                    "question": ret[0],
                    "answer": ret[1],
                    "search_result": cache_data,
                    "embedding": None
                }
            rank = chat_cache.similarity_evaluation.evaluation(
                eval_query_data,
                eval_cache_data,
                extra_param=context.get("evaluation_func", None),
            )

            if len(pre_embedding_data) <= 256:
                if rank_threshold <= rank:
                    cache_answers.append((rank, ret[1]))
                    cache_questions.append((rank, ret[0]))
                    cache_ids.append((rank, primary_id))
            else:
                if rank_threshold_long <= rank:
                    cache_answers.append((rank, ret[1]))
                    cache_questions.append((rank, ret[0]))
                    cache_ids.append((rank, primary_id))
        cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
        cache_questions = sorted(cache_questions, key=lambda x: x[0], reverse=True)
        cache_ids = sorted(cache_ids, key=lambda x: x[0], reverse=True)
        if len(cache_answers) != 0:
            return_message = chat_cache.post_process_messages_func(
                [t[1] for t in cache_answers]
            )
            return_query = chat_cache.post_process_messages_func(
                [t[1] for t in cache_questions]
            )
            return_id = chat_cache.post_process_messages_func(
                [t[1] for t in cache_ids]
            )
            # update hit count
            try:
                chat_cache.data_manager.update_hit_count(return_id)
            except Exception:
                logging.info('update_hit_count except, please check!')

            chat_cache.report.hint_cache()
            return cache_data_convert(return_message, return_query)
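
The two nested conditional expressions above are clamps: each scaled threshold is forced into the evaluator's [min_rank, max_rank] range before any comparison. An equivalent standalone helper, illustrative only:

def clamp_threshold(similarity_threshold, cache_factor, min_rank, max_rank):
    # Scale the configured threshold by cache_factor, then clamp it into the
    # similarity evaluator's score range, exactly as adapt_query does inline.
    rank_threshold = (max_rank - min_rank) * similarity_threshold * cache_factor
    return max(min_rank, min(max_rank, rank_threshold))

# With a 0..1 evaluator range, a 0.95 threshold and cache_factor=1.2 saturates:
print(clamp_threshold(0.95, 1.2, 0.0, 1.0))  # -> 1.0, clamped to max_rank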
modelcache/adapter_mm/adapter_register.py

Lines changed: 13 additions & 0 deletions

# -*- coding: utf-8 -*-
from modelcache import cache


def adapt_register(*args, **kwargs):
    chat_cache = kwargs.pop("cache_obj", cache)
    model = kwargs.pop("model", None)
    if model is None or len(model) == 0:
        return ValueError('')

    register_resp = chat_cache.data_manager.create_index(model)
    print('register_resp: {}'.format(register_resp))
    return register_resp
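
Worth noting: an empty model name makes adapt_register return (not raise) a ValueError instance, so callers must type-check the response. A hedged sketch, assuming the facade dispatches to an adapt_register with this contract:

from modelcache.adapter_mm.adapter import ChatCompletion

resp = ChatCompletion.create_mm_register(model="multimodal_demo")
if isinstance(resp, ValueError):
    print("register failed: model name was empty")
else:
    print("index created:", resp)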
modelcache/adapter_mm/adapter_remove.py

Lines changed: 26 additions & 0 deletions

# -*- coding: utf-8 -*-
from modelcache import cache
from modelcache.utils.error import NotInitError, RemoveError


def adapt_remove(*args, **kwargs):
    chat_cache = kwargs.pop("cache_obj", cache)
    model = kwargs.pop("model", None)
    remove_type = kwargs.pop("remove_type", None)
    require_object_store = kwargs.pop("require_object_store", False)
    if require_object_store:
        assert chat_cache.data_manager.o, "Object store is required for adapter."
    if not chat_cache.has_init:
        raise NotInitError()

    # delete data
    if remove_type == 'delete_by_id':
        id_list = kwargs.pop("id_list", [])
        resp = chat_cache.data_manager.delete(id_list, model=model)
    elif remove_type == 'truncate_by_model':
        resp = chat_cache.data_manager.truncate(model)
    else:
        # resp = "remove_type_error"
        raise RemoveError()
    return resp
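
A hedged sketch of the two supported removal modes (any other remove_type raises RemoveError, which the facade converts to its string form; ids and model name are illustrative):

from modelcache.adapter_mm.adapter import ChatCompletion

# Delete individual cache entries by primary id:
ChatCompletion.create_mm_remove(
    model="multimodal_demo",
    remove_type="delete_by_id",
    id_list=[101, 102],  # illustrative ids
)

# Drop everything cached for one model:
ChatCompletion.create_mm_remove(
    model="multimodal_demo",
    remove_type="truncate_by_model",
)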

modelcache/core.py

Lines changed: 9 additions & 6 deletions

@@ -17,9 +17,10 @@ class Cache:
     def __init__(self):
         self.has_init = False
         self.cache_enable_func = None
-        self.query_pre_embedding_func = None
-        self.insert_pre_embedding_func = None
+        self.mm_query_pre_embedding_func = None
+        self.mm_insert_pre_embedding_func = None
         self.embedding_func = None
+        self.embedding_concurrent_func = None
         self.data_manager: Optional[DataManager] = None
         self.similarity_evaluation: Optional[SimilarityEvaluation] = None
         self.post_process_messages_func = None
@@ -30,9 +31,10 @@ def __init__(self):
     def init(
         self,
         cache_enable_func=cache_all,
-        query_pre_embedding_func=None,
-        insert_pre_embedding_func=None,
+        mm_query_pre_embedding_func=None,
+        mm_insert_pre_embedding_func=None,
         embedding_func=string_embedding,
+        embedding_concurrent_func=string_embedding,
         data_manager: DataManager = get_data_manager(),
         similarity_evaluation=ExactMatchEvaluation(),
         post_process_messages_func=first,
@@ -41,9 +43,10 @@ def init(
     ):
         self.has_init = True
         self.cache_enable_func = cache_enable_func
-        self.query_pre_embedding_func = query_pre_embedding_func
-        self.insert_pre_embedding_func = insert_pre_embedding_func
+        self.mm_query_pre_embedding_func = mm_query_pre_embedding_func
+        self.mm_insert_pre_embedding_func = mm_insert_pre_embedding_func
         self.embedding_func = embedding_func
+        self.embedding_concurrent_func = embedding_concurrent_func
         self.data_manager: DataManager = data_manager
         self.similarity_evaluation = similarity_evaluation
         self.post_process_messages_func = post_process_messages_func
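
Call sites that passed query_pre_embedding_func / insert_pre_embedding_func must move to the mm_-prefixed names. A minimal init sketch; the two pre-embedding callables here are illustrative placeholders, not part of the diff:

from modelcache import cache

def mm_query_pre(data, **params):
    # placeholder: flatten the multimodal query into one embeddable payload
    return str(data)

def mm_insert_pre(data, **params):
    return str(data)

cache.init(
    mm_query_pre_embedding_func=mm_query_pre,
    mm_insert_pre_embedding_func=mm_insert_pre,
    # embedding_concurrent_func is new alongside embedding_func; both
    # default to string_embedding when omitted.
)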

modelcache/manager/vector_data/redis.py

Lines changed: 10 additions & 4 deletions

@@ -21,19 +21,25 @@ def __init__(
         port: str = "6379",
         username: str = "",
         password: str = "",
-        dimension: int = 0,
+        # dimension: int = 0,
+        mm_dimension: int = 0,
+        i_dimension: int = 0,
+        t_dimension: int = 0,
         top_k: int = 1,
         namespace: str = "",
     ):
-        if dimension <= 0:
+        if mm_dimension <= 0:
             raise ValueError(
-                f"invalid `dim` param: {dimension} in the Milvus vector store."
+                f"invalid `dim` param: {mm_dimension} in the Redis vector store."
             )
         self._client = Redis(
             host=host, port=int(port), username=username, password=password
         )
         self.top_k = top_k
-        self.dimension = dimension
+        # self.dimension = dimension
+        self.mm_dimension = mm_dimension
+        self.i_dimension = i_dimension
+        self.t_dimension = t_dimension
         self.namespace = namespace
         self.doc_prefix = f"{self.namespace}doc:"
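
With the constructor split into per-modality dimensions, construction now looks roughly like this. The class name RedisVectorStore and all dimension values are assumptions; only the __init__ signature appears in this hunk:

from modelcache.manager.vector_data.redis import RedisVectorStore  # assumed class name

store = RedisVectorStore(
    host="127.0.0.1",
    port="6379",
    mm_dimension=1280,  # fused multimodal vectors; must be > 0 or __init__ raises
    i_dimension=1024,   # image-only vectors (illustrative size)
    t_dimension=768,    # text-only vectors (illustrative size)
    top_k=3,
    namespace="mm",
)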

modelcache/manager_mm/__init__.py

Lines changed: 5 additions & 0 deletions

# -*- coding: utf-8 -*-
from modelcache.manager_mm.scalar_data import CacheBase
from modelcache.manager_mm.vector_data import VectorBase
from modelcache.manager_mm.object_data import ObjectBase
from modelcache.manager_mm.factory import get_data_manager
