@@ -33,14 +33,16 @@ class DistanceOp(NamedTuple):
33
33
# Lookup table of the supported distance methods, keyed by the name accepted as
# `distance_method`. Each entry's first field (`function_name`) is the pgvector
# operator class used when the index is created; the remaining two fields appear
# to be the SQL distance operator and the expression that converts a raw
# distance into a score -- NOTE(review): confirm against the DistanceOp definition.
DISTANCE_OPS = {
    # Dense `vector` operator classes.
    "cosine": DistanceOp("vector_cosine_ops", "<=>", "1 - distance"),
    "l2": DistanceOp("vector_l2_ops", "<->", "distance * -1"),
    # Half-precision variant of L2.
    "halfvec_l2": DistanceOp("halfvec_l2_ops", "<->", "distance * -1"),
    "l1": DistanceOp("vector_l1_ops", "<+>", "distance * -1"),
    "ip": DistanceOp("vector_ip_ops", "<#>", "distance * -1"),
    # Binary vector operator classes.
    "bit_hamming": DistanceOp("bit_hamming_ops", "<~>", "distance * -1"),
    "bit_jaccard": DistanceOp("bit_jaccard_ops", "<%>", "distance * -1"),
    # Sparse vector operator class.
    "sparsevec_l2": DistanceOp("sparsevec_l2_ops", "<->", "distance * -1"),
}

# Vector size above which the table column may be declared as HALFVEC instead of
# VECTOR (only when a halfvec distance method is selected).
MAX_VECTOR_SIZE = 2000
45
+
44
46
45
47
class PgVectorStore (VectorStoreWithEmbedder [VectorStoreOptions ]):
46
48
"""
@@ -57,7 +59,8 @@ def __init__(
57
59
vector_size : int | None = None ,
58
60
embedding_type : EmbeddingType = EmbeddingType .TEXT ,
59
61
distance_method : str | None = None ,
60
- hnsw_params : dict | None = None ,
62
+ is_hnsw : bool = True ,
63
+ params : dict | None = None ,
61
64
default_options : VectorStoreOptions | None = None ,
62
65
) -> None :
63
66
"""
@@ -71,7 +74,8 @@ def __init__(
71
74
embedding_type: Which part of the entry to embed, either text or image. The other part will be ignored.
72
75
distance_method: The distance method to use, default is "cosine" for dense vectors
73
76
and "sparsevec_l2" for sparse vectors.
74
- hnsw_params: The parameters for the HNSW index. If None, the default parameters will be used.
77
+ is_hnsw: Whether to build an HNSW index (True) or an IVFFlat index (False).
78
+ params: The index parameters: 'm' and 'ef_construction' for HNSW, or 'lists' for IVFFlat. If None, the default parameters for the selected index type will be used.
75
79
default_options: The default options for querying the vector store.
76
80
"""
77
81
(
@@ -87,16 +91,22 @@ def __init__(
87
91
if vector_size is not None and (not isinstance (vector_size , int ) or vector_size <= 0 ):
88
92
raise ValueError ("Vector size must be a positive integer." )
89
93
90
# Resolve and validate the index parameters for the selected index type:
# HNSW when `is_hnsw` is true, IVFFlat otherwise.
if params is None:
    # Nothing supplied: fall back to the defaults of the selected index type.
    params = {"m": 4, "ef_construction": 10} if is_hnsw else {"lists": 100}
elif not isinstance(params, dict):
    raise ValueError("params must be a dictionary.")
elif is_hnsw:
    # The HNSW checks are nested under `is_hnsw` rather than written as
    # `a or b and is_hnsw`: `and` binds tighter than `or`, so the flat form ran
    # the "m" checks for IVFFlat as well and rejected every explicit IVFFlat
    # configuration that did not also carry HNSW keys.
    if "m" not in params or "ef_construction" not in params:
        raise ValueError("params must contain 'm' and 'ef_construction' keys for hnsw indexing.")
    if not isinstance(params["m"], int) or params["m"] <= 0:
        raise ValueError("m must be a positive integer for hnsw indexing.")
    if not isinstance(params["ef_construction"], int) or params["ef_construction"] <= 0:
        raise ValueError("ef_construction must be a positive integer for hnsw indexing.")
    # NOTE(review): create_table selects IVFFlat whenever a "lists" key is present,
    # so an HNSW params dict that also contains "lists" would still build an
    # IVFFlat index -- confirm whether such input should be rejected here.
else:
    if "lists" not in params:
        raise ValueError("params must contain 'lists' key for IVFFlat indexing.")
    if not isinstance(params["lists"], int) or params["lists"] <= 0:
        raise ValueError("lists must be a positive integer for IVFFlat indexing.")
100
110
101
111
if distance_method is None :
102
112
distance_method = "sparsevec_l2" if isinstance (embedder , SparseEmbedder ) else "cosine"
@@ -105,7 +115,7 @@ def __init__(
105
115
self ._vector_size = vector_size
106
116
self ._vector_size_info : VectorSize | None = None
107
117
self ._distance_method = distance_method
108
- self ._hnsw_params = hnsw_params
118
+ self ._indexing_params = params
109
119
110
120
def __reduce__ (self ) -> tuple :
111
121
"""
@@ -264,14 +274,15 @@ async def _check_table_exists(self) -> bool:
264
274
265
275
async def create_table (self ) -> None :
266
276
"""
267
- Create a pgVector table with an HNSW index for given similarity.
277
+ Create a pgVector table with an HNSW/IVFFlat index for given similarity.
268
278
"""
269
279
vector_size = await self ._get_vector_size ()
280
+
270
281
with trace (
271
282
table_name = self ._table_name ,
272
283
distance_method = self ._distance_method ,
273
284
vector_size = vector_size ,
274
- hnsw_index_parameters = self ._hnsw_params ,
285
+ hnsw_index_parameters = self ._indexing_params ,
275
286
):
276
287
distance = DISTANCE_OPS [self ._distance_method ].function_name
277
288
create_vector_extension = "CREATE EXTENSION IF NOT EXISTS vector;"
@@ -280,18 +291,38 @@ async def create_table(self) -> None:
280
291
# and it is a valid vector size.
281
292
282
293
is_sparse = isinstance (self ._embedder , SparseEmbedder )
283
- vector_func = "VECTOR" if not is_sparse else "SPARSEVEC"
294
+
295
+ # Check vector size
296
+ # if greater than MAX_VECTOR_SIZE (2000) and a halfvec distance method is selected, choose type HALFVEC
297
+ # More info: https://github.com/pgvector/pgvector
298
+ vector_func = (
299
+ "HALFVEC"
300
+ if vector_size > MAX_VECTOR_SIZE and re .search ("halfvec" , distance )
301
+ else "VECTOR"
302
+ if not is_sparse
303
+ else "SPARSEVEC"
304
+ )
284
305
285
306
create_table_query = f"""
286
307
CREATE TABLE { self ._table_name }
287
308
(id UUID, text TEXT, image_bytes BYTEA, vector { vector_func } ({ vector_size } ), metadata JSONB);
288
309
"""
289
- # _hnsw_params has been validated in the class constructor, and it is valid dict[str,int].
310
+ # _indexing_params has been validated in the class constructor, and it is a valid dict[str, int].
311
+ if "lists" in self ._indexing_params :
312
+ index_type = "ivfflat"
313
+ index_params = f"(lists = { self ._indexing_params ['lists' ]} );"
314
+ else :
315
+ index_type = "hnsw"
316
+ index_params = (
317
+ f"(m = { self ._indexing_params ['m' ]} , ef_construction = { self ._indexing_params ['ef_construction' ]} );"
318
+ )
319
+
290
320
create_index_query = f"""
291
- CREATE INDEX { self ._table_name + "_hnsw_idx" } ON { self ._table_name }
292
- USING hnsw (vector { distance } )
293
- WITH (m = { self ._hnsw_params ["m" ]} , ef_construction = { self ._hnsw_params ["ef_construction" ]} );
294
- """
321
+ CREATE INDEX { self ._table_name + "_" + index_type + "_idx" } ON { self ._table_name }
322
+ USING { index_type } (vector { distance } )
323
+ WITH { index_params }
324
+ """
325
+
295
326
if await self ._check_table_exists ():
296
327
print (f"Table { self ._table_name } already exist!" )
297
328
return
0 commit comments