Skip to content

Commit 1b89285

Browse files
feat: vectors longer than 2000 tokens supported (#777)
1 parent f7c920a commit 1b89285

File tree

4 files changed

+63
-24
lines changed

4 files changed

+63
-24
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,7 @@ qdrant/
105105

106106
.DS_Store
107107
node_modules/
108+
109+
lazygit
110+
111+
lazygit.tar.gz

packages/ragbits-core/CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22

33
## Unreleased
44

5+
6+
- Feat: added support for IVFFlat indexing and Halfvec datatype
7+
58
- Added Lazy loading of dependencies in local.py and during importing of LiteLLM
69
- Add tool_choice parameter to LLM interface (#738)
710
- Fix Prompt consumes same iterator twice leading to no data added to chat (#768)
811

12+
913
## 1.2.2 (2025-08-08)
1014

1115
- Fix: rendering iterator arguments in Prompt (#768)
@@ -233,4 +237,4 @@
233237
- LiteLLM integration.
234238
- ChromaDB integration.
235239
- Prompts lab.
236-
- Prompts autodiscovery.
240+
- Prompts autodiscovery.

packages/ragbits-core/src/ragbits/core/vector_stores/pgvector.py

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,16 @@ class DistanceOp(NamedTuple):
3333
DISTANCE_OPS = {
3434
"cosine": DistanceOp("vector_cosine_ops", "<=>", "1 - distance"),
3535
"l2": DistanceOp("vector_l2_ops", "<->", "distance * -1"),
36+
"halfvec_l2": DistanceOp("halfvec_l2_ops", "<->", "distance * -1"),
3637
"l1": DistanceOp("vector_l1_ops", "<+>", "distance * -1"),
3738
"ip": DistanceOp("vector_ip_ops", "<#>", "distance * -1"),
3839
"bit_hamming": DistanceOp("bit_hamming_ops", "<~>", "distance * -1"),
3940
"bit_jaccard": DistanceOp("bit_jaccard_ops", "<%>", "distance * -1"),
4041
"sparsevec_l2": DistanceOp("sparsevec_l2_ops", "<->", "distance * -1"),
41-
"halfvec_l2": DistanceOp("halfvec_l2_ops", "<->", "distance * -1"),
4242
}
4343

44+
MAX_VECTOR_SIZE = 2000
45+
4446

4547
class PgVectorStore(VectorStoreWithEmbedder[VectorStoreOptions]):
4648
"""
@@ -57,7 +59,8 @@ def __init__(
5759
vector_size: int | None = None,
5860
embedding_type: EmbeddingType = EmbeddingType.TEXT,
5961
distance_method: str | None = None,
60-
hnsw_params: dict | None = None,
62+
is_hnsw: bool = True,
63+
params: dict | None = None,
6164
default_options: VectorStoreOptions | None = None,
6265
) -> None:
6366
"""
@@ -71,7 +74,8 @@ def __init__(
7174
embedding_type: Which part of the entry to embed, either text or image. The other part will be ignored.
7275
distance_method: The distance method to use, default is "cosine" for dense vectors
7376
and "sparsevec_l2" for sparse vectors.
74-
hnsw_params: The parameters for the HNSW index. If None, the default parameters will be used.
77+
is_hnsw: if hnsw or ivfflat indexing should be used
78+
params: The parameters for the HNSW index. If None, the default parameters will be used.
7579
default_options: The default options for querying the vector store.
7680
"""
7781
(
@@ -87,16 +91,22 @@ def __init__(
8791
if vector_size is not None and (not isinstance(vector_size, int) or vector_size <= 0):
8892
raise ValueError("Vector size must be a positive integer.")
8993

90-
if hnsw_params is None:
91-
hnsw_params = {"m": 4, "ef_construction": 10}
92-
elif not isinstance(hnsw_params, dict):
93-
raise ValueError("hnsw_params must be a dictionary.")
94-
elif "m" not in hnsw_params or "ef_construction" not in hnsw_params:
95-
raise ValueError("hnsw_params must contain 'm' and 'ef_construction' keys.")
96-
elif not isinstance(hnsw_params["m"], int) or hnsw_params["m"] <= 0:
97-
raise ValueError("m must be a positive integer.")
98-
elif not isinstance(hnsw_params["ef_construction"], int) or hnsw_params["ef_construction"] <= 0:
99-
raise ValueError("ef_construction must be a positive integer.")
94+
if params is None and is_hnsw:
95+
params = {"m": 4, "ef_construction": 10}
96+
elif params is None and not is_hnsw:
97+
params = {"lists": 100}
98+
elif not isinstance(params, dict):
99+
raise ValueError("params must be a dictionary.")
100+
elif "m" not in params or "ef_construction" not in params and is_hnsw:
101+
raise ValueError("params must contain 'm' and 'ef_construction' keys for hnsw indexing.")
102+
elif not isinstance(params["m"], int) or params["m"] <= 0 and is_hnsw:
103+
raise ValueError("m must be a positive integer for hnsw indexing.")
104+
elif not isinstance(params["ef_construction"], int) or params["ef_construction"] <= 0 and is_hnsw:
105+
raise ValueError("ef_construction must be a positive integer for hnsw indexing.")
106+
elif "lists" not in params and not is_hnsw:
107+
raise ValueError("params must contain 'lists' key for IVFFlat indexing.")
108+
elif not isinstance(params["lists"], int) or params["lists"] <= 0 and not is_hnsw:
109+
raise ValueError("lists must be a positive integer for IVFFlat indexing.")
100110

101111
if distance_method is None:
102112
distance_method = "sparsevec_l2" if isinstance(embedder, SparseEmbedder) else "cosine"
@@ -105,7 +115,7 @@ def __init__(
105115
self._vector_size = vector_size
106116
self._vector_size_info: VectorSize | None = None
107117
self._distance_method = distance_method
108-
self._hnsw_params = hnsw_params
118+
self._indexing_params = params
109119

110120
def __reduce__(self) -> tuple:
111121
"""
@@ -264,14 +274,15 @@ async def _check_table_exists(self) -> bool:
264274

265275
async def create_table(self) -> None:
266276
"""
267-
Create a pgVector table with an HNSW index for given similarity.
277+
Create a pgVector table with an HNSW/IVFFlat index for given similarity.
268278
"""
269279
vector_size = await self._get_vector_size()
280+
270281
with trace(
271282
table_name=self._table_name,
272283
distance_method=self._distance_method,
273284
vector_size=vector_size,
274-
hnsw_index_parameters=self._hnsw_params,
285+
hnsw_index_parameters=self._indexing_params,
275286
):
276287
distance = DISTANCE_OPS[self._distance_method].function_name
277288
create_vector_extension = "CREATE EXTENSION IF NOT EXISTS vector;"
@@ -280,18 +291,38 @@ async def create_table(self) -> None:
280291
# and it is a valid vector size.
281292

282293
is_sparse = isinstance(self._embedder, SparseEmbedder)
283-
vector_func = "VECTOR" if not is_sparse else "SPARSEVEC"
294+
295+
# Check vector size
296+
# if greater than 2000 then choose type HALFVEC
297+
# More info: https://github.com/pgvector/pgvector
298+
vector_func = (
299+
"HALFVEC"
300+
if vector_size > MAX_VECTOR_SIZE and re.search("halfvec", distance)
301+
else "VECTOR"
302+
if not is_sparse
303+
else "SPARSEVEC"
304+
)
284305

285306
create_table_query = f"""
286307
CREATE TABLE {self._table_name}
287308
(id UUID, text TEXT, image_bytes BYTEA, vector {vector_func}({vector_size}), metadata JSONB);
288309
"""
289-
# _hnsw_params has been validated in the class constructor, and it is valid dict[str,int].
310+
# _idexing_params has been validated in the class constructor, and it is valid dict[str,int].
311+
if "lists" in self._indexing_params:
312+
index_type = "ivfflat"
313+
index_params = f"(lists = {self._indexing_params['lists']});"
314+
else:
315+
index_type = "hnsw"
316+
index_params = (
317+
f"(m = {self._indexing_params['m']}, ef_construction = {self._indexing_params['ef_construction']});"
318+
)
319+
290320
create_index_query = f"""
291-
CREATE INDEX {self._table_name + "_hnsw_idx"} ON {self._table_name}
292-
USING hnsw (vector {distance})
293-
WITH (m = {self._hnsw_params["m"]}, ef_construction = {self._hnsw_params["ef_construction"]});
294-
"""
321+
CREATE INDEX {self._table_name + "_" + index_type + "_idx"} ON {self._table_name}
322+
USING {index_type} (vector {distance})
323+
WITH {index_params}
324+
"""
325+
295326
if await self._check_table_exists():
296327
print(f"Table {self._table_name} already exist!")
297328
return

packages/ragbits-core/tests/unit/vector_stores/test_pgvector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async def test_invalid_hnsw_raises_error(mock_db_pool: tuple[MagicMock, AsyncMoc
8282
client=mock_pool,
8383
table_name=TEST_TABLE_NAME,
8484
vector_size=3,
85-
hnsw_params=hnsw, # type: ignore
85+
params=hnsw, # type: ignore
8686
embedder=NoopEmbedder(),
8787
)
8888

0 commit comments

Comments
 (0)