
Commit 59def83

add redis vector
1 parent 5af9cf8

File tree

2 files changed: +185, -1 lines changed


flask4modelcache.py

Lines changed: 4 additions & 1 deletion
@@ -38,8 +38,11 @@ def response_hitquery(cache_resp):
 mysql_config.read('modelcache/config/mysql_config.ini')
 milvus_config = configparser.ConfigParser()
 milvus_config.read('modelcache/config/milvus_config.ini')
+# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
+#                                 VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
+
 data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
-                                VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
+                                VectorBase("redis", dimension=data2vec.dimension, milvus_config=milvus_config))


 cache.init(
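
As a hedged sketch of what the "redis" VectorBase presumably dispatches to (the RedisVectorStore added below in this commit; the host/port values here are illustrative and would in practice come from configuration, which the diff still passes via the milvus_config keyword):

    redis_store = RedisVectorStore(
        host="localhost",   # illustrative
        port="6379",
        table_suffix="test",
        dimension=data2vec.dimension,
        top_k=1,
    )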
Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
from typing import List

import numpy as np
from redis.client import Redis
from redis.commands.search.field import NumericField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query

from gptcache.manager.vector_data.base import VectorBase, VectorData
from gptcache.utils import import_redis
from gptcache.utils.collection_util import get_collection_name, get_collection_prefix
from gptcache.utils.log import gptcache_log

import_redis()


class RedisVectorStore(VectorBase):
    def __init__(
        self,
        host: str = "localhost",
        port: str = "6379",
        username: str = "",
        password: str = "",
        table_suffix: str = "",
        dimension: int = 0,
        collection_prefix: str = "gptcache",
        top_k: int = 1,
        namespace: str = "",
    ):
        if dimension <= 0:
            raise ValueError(
                f"invalid `dim` param: {dimension} in the Redis vector store."
            )
        self._client = Redis(
            host=host, port=int(port), username=username, password=password
        )
        self.top_k = top_k
        self.dimension = dimension
        self.collection_prefix = collection_prefix
        self.table_suffix = table_suffix
        self.namespace = namespace
        # Key prefix for documents stored under the configured namespace.
        self.doc_prefix = f"{self.namespace}doc:"

    def _check_index_exists(self, index_name: str) -> bool:
        """Check whether the given RediSearch index already exists."""
        try:
            self._client.ft(index_name).info()
        except Exception:  # pylint: disable=broad-except
            gptcache_log.info("Index does not exist")
            return False
        gptcache_log.info("Index already exists")
        return True

    def create_collection(self, collection_name, index_prefix):
        dimension = self.dimension
        if self._check_index_exists(collection_name):
            gptcache_log.info(
                "The %s already exists, and it will be used directly", collection_name
            )
            return 'already_exists'

        id_field_name = "data_id"
        embedding_field_name = "data_vector"

        id_field = NumericField(name=id_field_name)
        embedding_field = VectorField(
            embedding_field_name,
            "HNSW",
            {
                "TYPE": "FLOAT32",
                "DIM": dimension,
                "DISTANCE_METRIC": "L2",
                "INITIAL_CAP": 1000,
            },
        )
        fields = [id_field, embedding_field]
        # Index HASH documents whose keys start with index_prefix.
        definition = IndexDefinition(prefix=[index_prefix], index_type=IndexType.HASH)

        self._client.ft(collection_name).create_index(
            fields=fields, definition=definition
        )
        return 'create_success'
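
    # A hedged sketch of the equivalent redis-cli command this method issues
    # (the index and prefix names depend on get_collection_name and
    # get_collection_prefix, whose formats are not shown in this commit):
    #
    #   FT.CREATE <collection_name> ON HASH PREFIX 1 <index_prefix> SCHEMA
    #       data_id NUMERIC
    #       data_vector VECTOR HNSW 8 TYPE FLOAT32 DIM <dimension>
    #           DISTANCE_METRIC L2 INITIAL_CAP 1000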

    def mul_add(self, datas: List[VectorData], model=None):
        id_field_name = "data_id"
        embedding_field_name = "data_vector"
        index_prefix = get_collection_prefix(model, self.table_suffix)
        for data in datas:
            data_id: int = data.id
            embedding = data.data.astype(np.float32).tobytes()
            obj = {id_field_name: data_id, embedding_field_name: embedding}
            # Store each vector as a HASH so the prefix-based index picks it up.
            self._client.hset(f"{index_prefix}{data_id}", mapping=obj)
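
    # The commented-out pipeline in the original suggests batched writes were
    # considered; a minimal sketch using a pipeline over the same keys/fields:
    #
    #   pipe = self._client.pipeline()
    #   for data in datas:
    #       pipe.hset(f"{index_prefix}{data.id}", mapping=obj)
    #   pipe.execute()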

    def search(self, data: np.ndarray, top_k: int = -1, model=None):
        collection_name = get_collection_name(model, self.table_suffix)
        id_field_name = "data_id"
        embedding_field_name = "data_vector"
        # Fall back to the store-level default when no top_k is given.
        k = top_k if top_k > 0 else self.top_k

        base_query = f'*=>[KNN {k} @{embedding_field_name} $vector AS distance]'
        query = (
            Query(base_query)
            .sort_by("distance")
            .return_fields(id_field_name, "distance")
            .dialect(2)
        )

        query_params = {"vector": data.astype(np.float32).tobytes()}
        results = (
            self._client.ft(collection_name)
            .search(query, query_params=query_params)
            .docs
        )
        return [(float(result.distance), int(getattr(result, id_field_name))) for result in results]
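
    # With k = 1, base_query above renders as:
    #   *=>[KNN 1 @data_vector $vector AS distance]
    # i.e. a dialect-2 RediSearch KNN query that returns (distance, data_id)
    # pairs sorted by L2 distance.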

    def rebuild(self, ids=None) -> bool:
        pass

    def rebuild_col(self, model):
        if len(self.table_suffix) == 0:
            raise ValueError('table_suffix is none error, please check!')

        collection_name_model = get_collection_name(model, self.table_suffix)
        if self._check_index_exists(collection_name_model):
            try:
                # Drop the index together with its indexed documents.
                self._client.ft(collection_name_model).dropindex(delete_documents=True)
            except Exception as e:
                raise ValueError(str(e)) from e
        try:
            index_prefix = get_collection_prefix(model, self.table_suffix)
            self.create_collection(collection_name_model, index_prefix)
        except Exception as e:
            raise ValueError(str(e)) from e
        return 'rebuild success'
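
    # Usage note (hedged, model name illustrative): rebuild_col("some_model")
    # drops the per-model index and its documents, then recreates an empty
    # index; vectors must be re-inserted via mul_add afterwards.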

    def delete(self, ids) -> None:
        # Deletes keys under self.doc_prefix; note that mul_add stores entries
        # under the prefix from get_collection_prefix, not doc_prefix.
        pipe = self._client.pipeline()
        for data_id in ids:
            pipe.delete(f"{self.doc_prefix}{data_id}")
        pipe.execute()

    def create(self, model=None):
        collection_name = get_collection_name(model, self.table_suffix)
        index_prefix = get_collection_prefix(model, self.table_suffix)
        return self.create_collection(collection_name, index_prefix)

    def get_collection_by_name(self, collection_name, table_suffix):
        pass
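
A minimal usage sketch of the new store (the model name, dimension, and data values below are illustrative; VectorData's id/data fields follow their use in mul_add above):

    import numpy as np

    store = RedisVectorStore(host="localhost", port="6379",
                             table_suffix="test", dimension=4, top_k=1)
    store.create(model="demo_model")  # builds the per-model HNSW index
    store.mul_add([VectorData(id=1, data=np.random.rand(4))], model="demo_model")
    print(store.search(np.random.rand(4), model="demo_model"))  # [(distance, 1)]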
