Skip to content

Commit 62be825

Browse files
committed
add modelcache mm ability
1 parent c7bedff commit 62be825

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+1593
-607
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,4 @@ dmypy.json
138138
**/modelcache_serving.py
139139
**/maya_embedding_service
140140

141-
#*.ini
141+
*.ini

modelcache/manager/scalar_data/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def create(self):
9595
pass
9696

9797
@abstractmethod
98-
def batch_insert(self, all_data: List[CacheData]):
98+
def batch_iat_insert(self, all_data: List[CacheData]):
9999
pass
100100

101101
@abstractmethod

modelcache_mm/__init__.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,5 @@
11
# -*- coding: utf-8 -*-
2-
"""
3-
Alipay.com Inc.
4-
Copyright (c) 2004-2023 All Rights Reserved.
5-
------------------------------------------------------
6-
File Name : __init__.py.py
7-
Author : fuhui.phe
8-
Create Time : 2024/4/17 10:53
9-
Description : description what the main function of this file
10-
Change Activity:
11-
version0 : 2024/4/17 10:53 by fuhui.phe init
12-
"""
2+
from modelcache_mm.core import Cache
3+
from modelcache_mm.core import cache
4+
from modelcache_mm.config import Config
5+
import modelcache_mm.adapter

modelcache_mm/adapter_mm/adapter.py renamed to modelcache_mm/adapter/adapter.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# -*- coding: utf-8 -*-
22
import logging
33

4-
from modelcache_mm.adapter_mm.adapter_query import adapt_query
5-
from modelcache_mm.adapter_mm.adapter_insert import adapt_insert
6-
from modelcache_mm.adapter_mm.adapter_remove import adapt_remove
7-
from modelcache_mm.adapter_mm.adapter_register import adapt_register
4+
from modelcache_mm.adapter.adapter_query import adapt_query
5+
from modelcache_mm.adapter.adapter_insert import adapt_insert
6+
from modelcache_mm.adapter.adapter_remove import adapt_remove
7+
from modelcache_mm.adapter.adapter_register import adapt_register
88

99

1010
class ChatCompletion(object):
1111
"""Openai ChatCompletion Wrapper"""
1212
@classmethod
13-
def create_mm_query(cls, *args, **kwargs):
13+
def create_query(cls, *args, **kwargs):
1414
def cache_data_convert(cache_data, cache_query):
1515
return construct_resp_from_cache(cache_data, cache_query)
1616
try:
@@ -20,10 +20,10 @@ def cache_data_convert(cache_data, cache_query):
2020
**kwargs
2121
)
2222
except Exception as e:
23-
return str(e)
24-
23+
# return str(e)
24+
raise e
2525
@classmethod
26-
def create_mm_insert(cls, *args, **kwargs):
26+
def create_insert(cls, *args, **kwargs):
2727
try:
2828
return adapt_insert(
2929
*args,
@@ -34,18 +34,17 @@ def create_mm_insert(cls, *args, **kwargs):
3434
raise e
3535

3636
@classmethod
37-
def create_mm_remove(cls, *args, **kwargs):
37+
def create_remove(cls, *args, **kwargs):
3838
try:
3939
return adapt_remove(
4040
*args,
4141
**kwargs
4242
)
4343
except Exception as e:
44-
logging.info('adapt_remove_e: {}'.format(e))
45-
return str(e)
44+
raise e
4645

4746
@classmethod
48-
def create_mm_register(cls, *args, **kwargs):
47+
def create_register(cls, *args, **kwargs):
4948
try:
5049
return adapt_register(
5150
*args,

modelcache_mm/adapter_mm/adapter_insert.py renamed to modelcache_mm/adapter/adapter_insert.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import requests
44
import base64
55
import numpy as np
6-
from modelcache import cache
7-
from modelcache.utils.error import NotInitError
8-
from modelcache.utils.time import time_cal
6+
from modelcache_mm import cache
7+
from modelcache_mm.utils.error import NotInitError
8+
from modelcache_mm.utils.time import time_cal
99

1010

1111
def adapt_insert(*args, **kwargs):
@@ -19,7 +19,7 @@ def adapt_insert(*args, **kwargs):
1919
cache_enable = chat_cache.cache_enable_func(*args, **kwargs)
2020
context = kwargs.pop("cache_context", {})
2121
embedding_data = None
22-
pre_embedding_data_dict = chat_cache.mm_insert_pre_embedding_func(
22+
pre_embedding_data_dict = chat_cache.insert_pre_embedding_func(
2323
kwargs,
2424
extra_param=context.get("pre_embedding_func", None),
2525
prompts=chat_cache.config.prompts,
@@ -84,7 +84,6 @@ def adapt_insert(*args, **kwargs):
8484
mm_type = 'text'
8585
else:
8686
raise ValueError('maya embedding service return both empty list, please check!')
87-
8887
print('embedding_data: {}'.format(embedding_data))
8988
chat_cache.data_manager.save(
9089
pre_embedding_text,
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
# -*- coding: utf-8 -*-
2+
import time
3+
import requests
4+
import numpy as np
5+
import base64
6+
from modelcache_mm import cache
7+
from modelcache_mm.utils.error import NotInitError
8+
from modelcache_mm.utils.error import MultiTypeError
9+
from modelcache_mm.utils.time import time_cal
10+
11+
12+
def adapt_query(cache_data_convert, *args, **kwargs):
    """Multimodal cache lookup: embed the query (text and/or image), search the
    vector store, rank candidates against similarity thresholds, and return a
    response built by ``cache_data_convert`` on a hit, or ``None`` on a miss.

    :param cache_data_convert: callable(answers, query_dict) that builds the
        final response from the cached answer(s) and the matched query info.
    :param kwargs: must carry ``scope`` (dict with key 'model'); optional
        ``cache_obj``, ``cache_context``, ``cache_factor``, ``top_k``.
    :raises NotInitError: if the cache object was never initialized.
    :raises ValueError: if both or neither of imageUrl/imageRaw are given for
        an IMG_TEXT query, or the embedding service returns two empty lists.
    :raises MultiTypeError: if multiType is neither 'IMG_TEXT' nor 'TEXT'.
    """
    chat_cache = kwargs.pop("cache_obj", cache)
    scope = kwargs.pop("scope", None)
    model = scope['model']
    if not chat_cache.has_init:
        raise NotInitError()

    cache_enable = chat_cache.cache_enable_func(*args, **kwargs)
    context = kwargs.pop("cache_context", {})
    cache_factor = kwargs.pop("cache_factor", 1.0)

    pre_embedding_data_dict = chat_cache.query_pre_embedding_func(
        kwargs,
        extra_param=context.get("pre_embedding_func", None),
        prompts=chat_cache.config.prompts,
    )

    pre_embedding_text = '###'.join(pre_embedding_data_dict['text'])
    pre_embedding_image_raw = pre_embedding_data_dict['imageRaw']
    pre_embedding_image_url = pre_embedding_data_dict['imageUrl']
    pre_multi_type = pre_embedding_data_dict['multiType']

    # Build the payload for the embedding service. For IMG_TEXT exactly one
    # of imageUrl / imageRaw must be provided; a URL is fetched and base64'd.
    if pre_multi_type == 'IMG_TEXT':
        if pre_embedding_image_raw and pre_embedding_image_url:
            raise ValueError(
                "Both pre_embedding_imageUrl and pre_embedding_imageRaw cannot be non-empty at the same time.")
        if pre_embedding_image_url:
            url_start_time = time.time()
            response = requests.get(pre_embedding_image_url)
            image_data = response.content
            pre_embedding_image = base64.b64encode(image_data).decode('utf-8')
            get_image_time = '{}s'.format(round(time.time() - url_start_time, 2))
            print('get_image_time: {}'.format(get_image_time))
        elif pre_embedding_image_raw:
            pre_embedding_image = pre_embedding_image_raw
        else:
            raise ValueError(
                "Both pre_embedding_imageUrl and pre_embedding_imageRaw are empty. Please provide at least one.")
        data_dict = {'text': [pre_embedding_text], 'image': pre_embedding_image}
    elif pre_multi_type == 'TEXT':
        data_dict = {'text': [pre_embedding_text], 'image': None}
    else:
        raise MultiTypeError

    embedding_data = None
    mm_type = None  # NOTE(review): set below but never read afterwards; kept for parity with insert path
    if cache_enable:
        # IMG_TEXT uses the concurrent embedding entry point; TEXT the plain one.
        if pre_multi_type == 'IMG_TEXT':
            embedding_data_resp = time_cal(
                chat_cache.embedding_concurrent_func,
                func_name="iat_embedding",
                report_func=chat_cache.report.embedding,
            )(data_dict)
        else:
            embedding_data_resp = time_cal(
                chat_cache.embedding_func,
                func_name="iat_embedding",
                report_func=chat_cache.report.embedding,
            )(data_dict)
        image_embeddings = embedding_data_resp['image_embedding']
        text_embeddings = embedding_data_resp['text_embeddings']

        # BUGFIX: original tested len(image_embeddings) twice, so the 'mm'
        # branch was taken whenever an image embedding existed even with no
        # text embedding, crashing on text_embeddings[0].
        if len(image_embeddings) > 0 and len(text_embeddings) > 0:
            image_embedding = np.array(image_embeddings[0])
            text_embedding = np.array(text_embeddings[0])
            embedding_data = np.concatenate((image_embedding, text_embedding))
            mm_type = 'mm'
        elif len(image_embeddings) > 0:
            image_embedding = np.array(image_embeddings[0])
            embedding_data = image_embedding
            mm_type = 'image'
        elif len(text_embeddings) > 0:
            text_embedding = np.array(text_embeddings[0])
            embedding_data = text_embedding
            mm_type = 'text'
        else:
            raise ValueError('maya embedding service return both empty list, please check!')

    if cache_enable:
        cache_data_list = time_cal(
            chat_cache.data_manager.search,
            func_name="vector_search",
            report_func=chat_cache.report.search,
        )(
            embedding_data,
            extra_param=context.get("search_func", None),
            top_k=kwargs.pop("top_k", -1),
            model=model,
            mm_type=pre_multi_type,
        )

        cache_answers = []
        cache_questions = []
        cache_image_urls = []
        cache_image_ids = []
        cache_ids = []
        similarity_threshold = chat_cache.config.similarity_threshold
        similarity_threshold_long = chat_cache.config.similarity_threshold_long

        # Scale thresholds into the evaluator's rank range and clamp them.
        min_rank, max_rank = chat_cache.similarity_evaluation.range()
        rank_threshold = (max_rank - min_rank) * similarity_threshold * cache_factor
        rank_threshold_long = (max_rank - min_rank) * similarity_threshold_long * cache_factor
        rank_threshold = max(min_rank, min(max_rank, rank_threshold))
        rank_threshold_long = max(min_rank, min(max_rank, rank_threshold_long))

        # Quick reject: if even the best raw search result is below threshold,
        # skip the per-candidate evaluation entirely.
        if cache_data_list is None or len(cache_data_list) == 0:
            rank_pre = -1.0
        else:
            cache_data_dict = {'search_result': cache_data_list[0]}
            rank_pre = chat_cache.similarity_evaluation.evaluation(
                None,
                cache_data_dict,
                extra_param=context.get("evaluation_func", None),
            )

        print('rank_pre: {}'.format(rank_pre))
        print('rank_threshold: {}'.format(rank_threshold))
        if rank_pre < rank_threshold:
            return

        for cache_data in cache_data_list:
            print('cache_data: {}'.format(cache_data))
            primary_id = cache_data[1]
            ret = chat_cache.data_manager.get_scalar_data(
                cache_data, extra_param=context.get("get_scalar_data", None)
            )
            if ret is None:
                continue

            if "deps" in context and hasattr(ret.question, "deps"):
                eval_query_data = {
                    "question": context["deps"][0]["data"],
                    "embedding": None
                }
                eval_cache_data = {
                    "question": ret.question.deps[0].data,
                    "answer": ret.answers[0].answer,
                    "search_result": cache_data,
                    "embedding": None,
                }
            else:
                eval_query_data = {
                    "question": pre_embedding_text,
                    "embedding": embedding_data,
                }
                # ret layout assumed: (question, image_url, image_raw, answer)
                # -- TODO confirm against data_manager.get_scalar_data
                eval_cache_data = {
                    "question": ret[0],
                    "image_url": ret[1],
                    "image_raw": ret[2],
                    "answer": ret[3],
                    "search_result": cache_data,
                    "embedding": None
                }
            rank = chat_cache.similarity_evaluation.evaluation(
                eval_query_data,
                eval_cache_data,
                extra_param=context.get("evaluation_func", None),
            )
            print('rank_threshold: {}'.format(rank_threshold))
            print('rank_threshold_long: {}'.format(rank_threshold_long))
            print('rank: {}'.format(rank))

            # Short queries (<= 50 chars) use the stricter threshold; long
            # queries use the long-text threshold.
            threshold = rank_threshold if len(pre_embedding_text) <= 50 else rank_threshold_long
            if threshold <= rank:
                cache_answers.append((rank, ret[3]))
                cache_image_urls.append((rank, ret[1]))
                cache_image_ids.append((rank, ret[2]))
                cache_questions.append((rank, ret[0]))
                cache_ids.append((rank, primary_id))

        cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
        cache_image_urls = sorted(cache_image_urls, key=lambda x: x[0], reverse=True)
        cache_image_ids = sorted(cache_image_ids, key=lambda x: x[0], reverse=True)
        cache_questions = sorted(cache_questions, key=lambda x: x[0], reverse=True)
        cache_ids = sorted(cache_ids, key=lambda x: x[0], reverse=True)

        print('cache_answers: {}'.format(cache_answers))

        if len(cache_answers) != 0:
            return_message = chat_cache.post_process_messages_func(
                [t[1] for t in cache_answers]
            )
            return_image_url = chat_cache.post_process_messages_func(
                [t[1] for t in cache_image_urls]
            )
            return_image_id = chat_cache.post_process_messages_func(
                [t[1] for t in cache_image_ids]
            )
            return_query = chat_cache.post_process_messages_func(
                [t[1] for t in cache_questions]
            )
            return_id = chat_cache.post_process_messages_func(
                [t[1] for t in cache_ids]
            )
            # Update hit count; best-effort, a failure must not break the hit.
            try:
                chat_cache.data_manager.update_hit_count(return_id)
            except Exception:
                print('update_hit_count except, please check!')

            chat_cache.report.hint_cache()
            return_query_dict = {"image_url": return_image_url, "image_id": return_image_id, "question": return_query}
            return cache_data_convert(return_message, return_query_dict)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# -*- coding: utf-8 -*-
2-
from modelcache import cache
2+
from modelcache_mm import cache
33

44

55
def adapt_register(*args, **kwargs):
66
chat_cache = kwargs.pop("cache_obj", cache)
77
model = kwargs.pop("model", None)
8-
mm_type = kwargs.pop("mm_type", None)
8+
type = kwargs.pop("type", None)
99
if model is None or len(model) == 0:
1010
return ValueError('')
1111

12-
print('mm_type: {}'.format(mm_type))
12+
print('type: {}'.format(type))
1313
print('model: {}'.format(model))
14-
register_resp = chat_cache.data_manager.create_index(model, mm_type)
14+
register_resp = chat_cache.data_manager.create_index(model, type)
1515
print('register_resp: {}'.format(register_resp))
1616
return register_resp

0 commit comments

Comments
 (0)