Skip to content

Commit a76d8d6

Browse files
committed
add redis class for vector
1 parent b4437f4 commit a76d8d6

File tree

1 file changed

+140
-0
lines changed

1 file changed

+140
-0
lines changed
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,141 @@
11
# -*- coding: utf-8 -*-
from typing import List

import numpy as np

from modelcache.manager.vector_data.base import VectorBase, VectorData
from modelcache.utils import import_redis
from modelcache.utils.log import gptcache_log

import_redis()

# pylint: disable=C0413
from redis.client import Redis
from redis.commands.search.field import TagField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
21+
class RedisVectorStore(VectorBase):
    """vector store: Redis.

    Stores each vector as a Redis hash keyed ``{namespace}doc:{id}`` and
    queries them through a RediSearch index (FLAT, FLOAT32, cosine distance).

    :param host: redis host, defaults to "localhost".
    :type host: str
    :param port: redis port, defaults to "6379".
    :type port: str
    :param username: redis username, defaults to "".
    :type username: str
    :param password: redis password, defaults to "".
    :type password: str
    :param dimension: the dimension of the vector, defaults to 0.
    :type dimension: int
    :param collection_name: the name of the index for Redis, defaults to "gptcache".
    :type collection_name: str
    :param top_k: the number of the vectors results to return, defaults to 1.
    :type top_k: int
    :param namespace: optional key-prefix namespace, defaults to "".
    :type namespace: str

    Example:
        .. code-block:: python

            from gptcache.manager import VectorBase

            vector_base = VectorBase("redis", dimension=10)
    """

    def __init__(
        self,
        host: str = "localhost",
        port: str = "6379",
        username: str = "",
        password: str = "",
        dimension: int = 0,
        collection_name: str = "gptcache",
        top_k: int = 1,
        namespace: str = "",
    ):
        self._client = Redis(
            host=host, port=int(port), username=username, password=password
        )
        self.top_k = top_k
        self.dimension = dimension
        self.collection_name = collection_name
        self.namespace = namespace
        # Prefix every document key with the specified namespace so that
        # multiple stores can share one Redis instance without colliding.
        self.doc_prefix = f"{self.namespace}doc:"
        self._create_collection(collection_name)

    def _check_index_exists(self, index_name: str) -> bool:
        """Return True if the RediSearch index ``index_name`` already exists.

        ``FT.INFO`` raises when the index is missing; treat any failure as
        "does not exist" (was a bare ``except:`` — narrowed to ``Exception``
        so KeyboardInterrupt/SystemExit are no longer swallowed).
        """
        try:
            self._client.ft(index_name).info()
        except Exception:  # pylint: disable=W0703
            gptcache_log.info("Index does not exist")
            return False
        gptcache_log.info("Index already exists")
        return True

    def _create_collection(self, collection_name):
        """Create the RediSearch vector index if it is not already present."""
        if self._check_index_exists(collection_name):
            gptcache_log.info(
                "The %s already exists, and it will be used directly", collection_name
            )
            return
        schema = (
            TagField("tag"),  # Tag Field Name
            VectorField(
                "vector",  # Vector Field Name
                "FLAT",  # Vector Index Type: FLAT or HNSW
                {
                    "TYPE": "FLOAT32",  # FLOAT32 or FLOAT64
                    "DIM": self.dimension,  # Number of Vector Dimensions
                    "DISTANCE_METRIC": "COSINE",  # Vector Search Distance Metric
                },
            ),
        )
        # Only hashes whose key starts with the namespace prefix are indexed.
        definition = IndexDefinition(
            prefix=[self.doc_prefix], index_type=IndexType.HASH
        )
        # create Index
        self._client.ft(collection_name).create_index(
            fields=schema, definition=definition
        )

    def mul_add(self, datas: List[VectorData]):
        """Insert a batch of vectors in one pipelined round trip.

        Each vector is stored as float32 bytes in the ``vector`` hash field,
        matching the FLOAT32 type declared on the index.
        """
        pipe = self._client.pipeline()
        for data in datas:
            key: int = data.id
            obj = {
                "vector": data.data.astype(np.float32).tobytes(),
            }
            pipe.hset(f"{self.doc_prefix}{key}", mapping=obj)
        pipe.execute()

    def search(self, data: np.ndarray, top_k: int = -1):
        """Run a KNN query and return ``[(score, id), ...]``.

        :param data: query vector.
        :param top_k: result count; any non-positive value falls back to the
            instance-level ``self.top_k`` default.
        """
        # Hoisted: the fallback was previously computed twice (KNN clause
        # and paging) — resolve it once.
        k = top_k if top_k > 0 else self.top_k
        query = (
            Query(f"*=>[KNN {k} @vector $vec as score]")
            .sort_by("score")
            .return_fields("id", "score")
            .paging(0, k)
            .dialect(2)  # dialect 2 is required for parameterized KNN syntax
        )
        query_params = {"vec": data.astype(np.float32).tobytes()}
        results = (
            self._client.ft(self.collection_name)
            .search(query, query_params=query_params)
            .docs
        )
        # Strip the namespace prefix from each key to recover the integer id.
        return [
            (float(result.score), int(result.id[len(self.doc_prefix):]))
            for result in results
        ]

    def rebuild(self, ids=None) -> bool:
        """No-op: Redis maintains its index incrementally.

        NOTE(review): annotated ``-> bool`` but falls through returning
        ``None`` (falsy) — confirm callers ignore the result before changing.
        """
        pass

    def delete(self, ids) -> None:
        """Delete the documents for ``ids`` in one pipelined round trip."""
        pipe = self._client.pipeline()
        for data_id in ids:
            pipe.delete(f"{self.doc_prefix}{data_id}")
        pipe.execute()

0 commit comments

Comments
 (0)