# -*- coding: utf-8 -*-
from typing import List

import numpy as np

from modelcache.manager.vector_data.base import VectorBase, VectorData
from modelcache.utils import import_redis
from modelcache.utils.log import gptcache_log

import_redis()

# pylint: disable=C0413
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from redis.commands.search.field import TagField, VectorField
from redis.client import Redis


class RedisVectorStore(VectorBase):
    """vector store: Redis

    :param host: redis host, defaults to "localhost".
    :type host: str
    :param port: redis port, defaults to "6379".
    :type port: str
    :param username: redis username, defaults to "".
    :type username: str
    :param password: redis password, defaults to "".
    :type password: str
    :param dimension: the dimension of the vector, defaults to 0.
    :type dimension: int
    :param collection_name: the name of the index for Redis, defaults to "gptcache".
    :type collection_name: str
    :param top_k: the number of vector results to return, defaults to 1.
    :type top_k: int
    :param namespace: optional namespace used to prefix stored keys, defaults to "".
    :type namespace: str

    Example:
        .. code-block:: python

            from modelcache.manager import VectorBase

            vector_base = VectorBase("redis", dimension=10)
    """
    def __init__(
        self,
        host: str = "localhost",
        port: str = "6379",
        username: str = "",
        password: str = "",
        dimension: int = 0,
        collection_name: str = "gptcache",
        top_k: int = 1,
        namespace: str = "",
    ):
        self._client = Redis(
            host=host, port=int(port), username=username, password=password
        )
        self.top_k = top_k
        self.dimension = dimension
        self.collection_name = collection_name
        self.namespace = namespace
        self.doc_prefix = f"{self.namespace}doc:"  # prefix keys with the specified namespace
        self._create_collection(collection_name)

    def _check_index_exists(self, index_name: str) -> bool:
        """Check if a RediSearch index with the given name already exists."""
        try:
            self._client.ft(index_name).info()
        except:  # pylint: disable=W0702
            gptcache_log.info("Index does not exist")
            return False
        gptcache_log.info("Index already exists")
        return True

    def _create_collection(self, collection_name):
        if self._check_index_exists(collection_name):
            gptcache_log.info(
                "The %s already exists, and it will be used directly", collection_name
            )
        else:
            schema = (
                TagField("tag"),  # tag field name
                VectorField(
                    "vector",  # vector field name
                    "FLAT",  # vector index type: FLAT or HNSW
                    {
                        "TYPE": "FLOAT32",  # FLOAT32 or FLOAT64
                        "DIM": self.dimension,  # number of vector dimensions
                        "DISTANCE_METRIC": "COSINE",  # vector search distance metric
                    },
                ),
            )
            definition = IndexDefinition(
                prefix=[self.doc_prefix], index_type=IndexType.HASH
            )

            # create the index
            self._client.ft(collection_name).create_index(
                fields=schema, definition=definition
            )

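    # For reference, the schema above amounts to the following RediSearch
    # command (illustrative values only: index "gptcache", dimension 10):
    #
    #   FT.CREATE gptcache ON HASH PREFIX 1 doc:
    #       SCHEMA tag TAG vector VECTOR FLAT 6 TYPE FLOAT32 DIM 10 DISTANCE_METRIC COSINE
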
    def mul_add(self, datas: List[VectorData]):
        pipe = self._client.pipeline()

        for data in datas:
            key: int = data.id
            obj = {
                "vector": data.data.astype(np.float32).tobytes(),
            }
            pipe.hset(f"{self.doc_prefix}{key}", mapping=obj)

        pipe.execute()

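    # Each entry written above becomes a Redis hash keyed by the namespaced id,
    # e.g. for an empty namespace and id 42 (4 bytes per float32 component):
    #
    #   HSET doc:42 vector <dimension * 4 raw bytes>
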
    def search(self, data: np.ndarray, top_k: int = -1):
        query = (
            Query(
                f"*=>[KNN {top_k if top_k > 0 else self.top_k} @vector $vec as score]"
            )
            .sort_by("score")
            .return_fields("id", "score")
            .paging(0, top_k if top_k > 0 else self.top_k)
            .dialect(2)
        )
        query_params = {"vec": data.astype(np.float32).tobytes()}
        results = (
            self._client.ft(self.collection_name)
            .search(query, query_params=query_params)
            .docs
        )
        return [(float(result.score), int(result.id[len(self.doc_prefix):])) for result in results]

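    # search() yields (score, id) tuples. With the COSINE metric configured in
    # _create_collection, RediSearch reports cosine distance (1 - similarity),
    # so smaller scores mean closer matches, e.g. [(0.0, 42)] for an exact hit.
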
    def rebuild(self, ids=None) -> bool:
        # Rebuilding is left as a no-op for this backend.
        pass

    def delete(self, ids) -> None:
        pipe = self._client.pipeline()
        for data_id in ids:
            pipe.delete(f"{self.doc_prefix}{data_id}")
        pipe.execute()
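
# A minimal usage sketch, assuming a local Redis server with the RediSearch
# module loaded on localhost:6379. The ids and vectors are illustrative, and
# VectorData is assumed to be the id/data container imported above.
if __name__ == "__main__":
    store = RedisVectorStore(dimension=4, top_k=2)
    store.mul_add(
        [
            VectorData(id=1, data=np.array([0.1, 0.2, 0.3, 0.4])),
            VectorData(id=2, data=np.array([0.4, 0.3, 0.2, 0.1])),
        ]
    )
    # Prints (score, id) pairs for the two nearest stored vectors.
    print(store.search(np.array([0.1, 0.2, 0.3, 0.4])))
    store.delete([1, 2])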