Skip to content

Commit 35acf35

Browse files
authored
Update IVF_PQ to set memory_budget in constructor, support preload feature_vectors and metadata only modes (#518)
1 parent b88d4ba commit 35acf35

File tree

12 files changed

+620
-589
lines changed

12 files changed

+620
-589
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,15 +192,20 @@ def _query_with_driver(
192192
def query_udf(index_type, index_open_kwargs, query_kwargs):
193193
from tiledb.vector_search.flat_index import FlatIndex
194194
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
195+
from tiledb.vector_search.ivf_pq_index import IVFPQIndex
195196
from tiledb.vector_search.vamana_index import VamanaIndex
196197

197198
# Open index
198199
if index_type == "FLAT":
199200
index = FlatIndex(**index_open_kwargs)
200201
elif index_type == "IVF_FLAT":
201202
index = IVFFlatIndex(**index_open_kwargs)
203+
elif index_type == "IVF_PQ":
204+
index = IVFPQIndex(**index_open_kwargs)
202205
elif index_type == "VAMANA":
203206
index = VamanaIndex(**index_open_kwargs)
207+
else:
208+
raise ValueError(f"Unsupported index_type: {index_type}")
204209

205210
# Query index
206211
return index.query(**query_kwargs)

apis/python/src/tiledb/vector_search/ivf_pq_index.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ class IVFPQIndex(index.Index):
3737
If not provided, all index data are loaded in main memory.
3838
Otherwise, no index data are loaded in main memory and this memory budget is
3939
applied during queries.
40+
preload_k_factor_vectors: bool
41+
When using `k_factor` in a query, we first query for `k_factor * k` pq-encoded vectors,
42+
and then re-rank those candidates using the original input vectors to select the final top `k` vectors.
43+
If `True`, we will load all the input vectors in main memory. This can only be used with
44+
`memory_budget` set to `-1`, and is useful when the input vectors are small enough to fit in
45+
memory and you want to speed up re-ranking.
4046
open_for_remote_query_execution: bool
4147
If `True`, do not load any index data in main memory locally, and instead load index data in the TileDB Cloud taskgraph created when a non-`None` `driver_mode` is passed to `query()`.
4248
If `False`, load index data in main memory locally. Note that you can still use a taskgraph for query execution; you'll just end up loading the data both on your local machine and in the cloud taskgraph.
@@ -48,15 +54,26 @@ def __init__(
4854
config: Optional[Mapping[str, Any]] = None,
4955
timestamp=None,
5056
memory_budget: int = -1,
57+
preload_k_factor_vectors: bool = False,
5158
open_for_remote_query_execution: bool = False,
5259
group: tiledb.Group = None,
5360
**kwargs,
5461
):
62+
if preload_k_factor_vectors and memory_budget != -1:
63+
raise ValueError(
64+
"preload_k_factor_vectors can only be used with memory_budget set to -1."
65+
)
66+
if preload_k_factor_vectors and open_for_remote_query_execution:
67+
raise ValueError(
68+
"preload_k_factor_vectors can only be used with open_for_remote_query_execution set to False."
69+
)
70+
5571
self.index_open_kwargs = {
5672
"uri": uri,
5773
"config": config,
5874
"timestamp": timestamp,
5975
"memory_budget": memory_budget,
76+
"preload_k_factor_vectors": preload_k_factor_vectors,
6077
}
6178
self.index_open_kwargs.update(kwargs)
6279
self.index_type = INDEX_TYPE
@@ -67,8 +84,21 @@ def __init__(
6784
open_for_remote_query_execution=open_for_remote_query_execution,
6885
group=group,
6986
)
70-
# TODO(SC-48710): Add support for `open_for_remote_query_execution`. We don't leave `self.index`` as `None` because we need to be able to call index.dimensions().
71-
self.index = vspy.IndexIVFPQ(self.ctx, uri, to_temporal_policy(timestamp))
87+
strategy = (
88+
vspy.IndexLoadStrategy.PQ_INDEX_AND_RERANKING_VECTORS
89+
if preload_k_factor_vectors
90+
else vspy.IndexLoadStrategy.PQ_OOC
91+
if open_for_remote_query_execution
92+
or (memory_budget != -1 and memory_budget != 0)
93+
else vspy.IndexLoadStrategy.PQ_INDEX
94+
)
95+
self.index = vspy.IndexIVFPQ(
96+
self.ctx,
97+
uri,
98+
strategy,
99+
0 if memory_budget == -1 else memory_budget,
100+
to_temporal_policy(timestamp),
101+
)
72102
self.db_uri = self.group[
73103
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
74104
].uri
@@ -127,16 +157,9 @@ def query_internal(
127157
if not queries.flags.f_contiguous:
128158
queries = queries.copy(order="F")
129159
queries_feature_vector_array = vspy.FeatureVectorArray(queries)
130-
131-
if self.memory_budget == -1:
132-
distances, ids = self.index.query_infinite_ram(
133-
queries_feature_vector_array, k, nprobe, k_factor
134-
)
135-
else:
136-
distances, ids = self.index.query_finite_ram(
137-
queries_feature_vector_array, k, nprobe, self.memory_budget, k_factor
138-
)
139-
160+
distances, ids = self.index.query(
161+
queries_feature_vector_array, k=k, nprobe=nprobe, k_factor=k_factor
162+
)
140163
return np.array(distances, copy=False), np.array(ids, copy=False)
141164

142165

@@ -203,7 +226,7 @@ def create(
203226
id_type=np.dtype(np.uint64).name,
204227
partitioning_index_type=np.dtype(np.uint64).name,
205228
dimensions=dimensions,
206-
n_list=partitions if (partitions is not None and partitions is not -1) else 0,
229+
n_list=partitions if (partitions is not None and partitions != -1) else 0,
207230
num_subspaces=num_subspaces,
208231
distance_metric=int(distance_metric),
209232
)

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "detail/linalg/tdb_matrix.h"
1919
#include "detail/linalg/tdb_partitioned_matrix.h"
2020
#include "detail/time/temporal_policy.h"
21+
#include "index/index_defs.h"
2122
#include "utils/seeder.h"
2223

2324
namespace py = pybind11;
@@ -1096,6 +1097,14 @@ PYBIND11_MODULE(_tiledbvspy, m) {
10961097
.value("L2", DistanceMetric::L2)
10971098
.export_values();
10981099

1100+
py::enum_<IndexLoadStrategy>(m, "IndexLoadStrategy")
1101+
.value("PQ_OOC", IndexLoadStrategy::PQ_OOC)
1102+
.value("PQ_INDEX", IndexLoadStrategy::PQ_INDEX)
1103+
.value(
1104+
"PQ_INDEX_AND_RERANKING_VECTORS",
1105+
IndexLoadStrategy::PQ_INDEX_AND_RERANKING_VECTORS)
1106+
.export_values();
1107+
10991108
/* === Module inits === */
11001109

11011110
init_kmeans(m);

apis/python/src/tiledb/vector_search/type_erased_module.cc

Lines changed: 25 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -368,10 +368,8 @@ void init_type_erased_module(py::module_& m) {
368368
.def("dimensions", &IndexFlatL2::dimensions)
369369
.def(
370370
"query",
371-
[](IndexFlatL2& index,
372-
const FeatureVectorArray& vectors,
373-
size_t top_k) {
374-
auto r = index.query(vectors, top_k);
371+
[](IndexFlatL2& index, const FeatureVectorArray& vectors, size_t k) {
372+
auto r = index.query(vectors, k);
375373
return make_python_pair(std::move(r));
376374
});
377375

@@ -422,13 +420,13 @@ void init_type_erased_module(py::module_& m) {
422420
"query",
423421
[](IndexVamana& index,
424422
const FeatureVectorArray& vectors,
425-
size_t top_k,
423+
size_t k,
426424
uint32_t l_search) {
427-
auto r = index.query(vectors, top_k, l_search);
425+
auto r = index.query(vectors, k, l_search);
428426
return make_python_pair(std::move(r));
429427
},
430428
py::arg("vectors"),
431-
py::arg("top_k"),
429+
py::arg("k"),
432430
py::arg("l_search"))
433431
.def(
434432
"write_index",
@@ -467,12 +465,21 @@ void init_type_erased_module(py::module_& m) {
467465
[](IndexIVFPQ& instance,
468466
const tiledb::Context& ctx,
469467
const std::string& group_uri,
468+
IndexLoadStrategy index_load_strategy,
469+
size_t memory_budget,
470470
std::optional<TemporalPolicy> temporal_policy) {
471-
new (&instance) IndexIVFPQ(ctx, group_uri, temporal_policy);
471+
new (&instance) IndexIVFPQ(
472+
ctx,
473+
group_uri,
474+
index_load_strategy,
475+
memory_budget,
476+
temporal_policy);
472477
},
473478
py::keep_alive<1, 2>(), // IndexIVFPQ should keep ctx alive.
474479
py::arg("ctx"),
475480
py::arg("group_uri"),
481+
py::arg("index_load_strategy") = IndexLoadStrategy::PQ_INDEX,
482+
py::arg("memory_budget") = 0,
476483
py::arg("temporal_policy") = std::nullopt)
477484
.def(
478485
"__init__",
@@ -494,41 +501,18 @@ void init_type_erased_module(py::module_& m) {
494501
},
495502
py::arg("vectors"))
496503
.def(
497-
"query_infinite_ram",
498-
[](IndexIVFPQ& index,
499-
const FeatureVectorArray& vectors,
500-
size_t top_k,
501-
size_t nprobe,
502-
float k_factor) {
503-
auto r = index.query(
504-
QueryType::InfiniteRAM, vectors, top_k, nprobe, 0, k_factor);
505-
return make_python_pair(std::move(r));
506-
},
507-
py::arg("vectors"),
508-
py::arg("top_k"),
509-
py::arg("nprobe"),
510-
py::arg("k_factor") = 1.f)
511-
.def(
512-
"query_finite_ram",
504+
"query",
513505
[](IndexIVFPQ& index,
514506
const FeatureVectorArray& vectors,
515-
size_t top_k,
507+
size_t k,
516508
size_t nprobe,
517-
size_t memory_budget,
518509
float k_factor) {
519-
auto r = index.query(
520-
QueryType::FiniteRAM,
521-
vectors,
522-
top_k,
523-
nprobe,
524-
memory_budget,
525-
k_factor);
510+
auto r = index.query(vectors, k, nprobe, k_factor);
526511
return make_python_pair(std::move(r));
527512
},
528513
py::arg("vectors"),
529-
py::arg("top_k"),
514+
py::arg("k"),
530515
py::arg("nprobe"),
531-
py::arg("memory_budget"),
532516
py::arg("k_factor") = 1.f)
533517
.def(
534518
"write_index",
@@ -603,24 +587,24 @@ void init_type_erased_module(py::module_& m) {
603587
"query_infinite_ram",
604588
[](IndexIVFFlat& index,
605589
const FeatureVectorArray& query,
606-
size_t top_k,
590+
size_t k,
607591
size_t nprobe) {
608-
auto r = index.query_infinite_ram(query, top_k, nprobe);
592+
auto r = index.query_infinite_ram(query, k, nprobe);
609593
return make_python_pair(std::move(r));
610-
}) // , py::arg("vectors"), py::arg("top_k") = 1, py::arg("nprobe")
594+
}) // , py::arg("vectors"), py::arg("k") = 1, py::arg("nprobe")
611595
// = 10)
612596
.def(
613597
"query_finite_ram",
614598
[](IndexIVFFlat& index,
615599
const FeatureVectorArray& query,
616-
size_t top_k,
600+
size_t k,
617601
size_t nprobe,
618602
size_t upper_bound) {
619-
auto r = index.query_finite_ram(query, top_k, nprobe, upper_bound);
603+
auto r = index.query_finite_ram(query, k, nprobe, upper_bound);
620604
return make_python_pair(std::move(r));
621605
},
622606
py::arg("vectors"),
623-
py::arg("top_k") = 1,
607+
py::arg("k") = 1,
624608
py::arg("nprobe") = 10,
625609
py::arg("upper_bound") = 0)
626610
.def("feature_type_string", &IndexIVFFlat::feature_type_string)

apis/python/test/local-benchmarks.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
from tiledb.vector_search.index import Index
2323
from tiledb.vector_search.ingestion import TrainingSamplingPolicy
2424
from tiledb.vector_search.ingestion import ingest
25+
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
26+
from tiledb.vector_search.ivf_pq_index import IVFPQIndex
2527
from tiledb.vector_search.utils import load_fvecs
28+
from tiledb.vector_search.vamana_index import VamanaIndex
2629

2730

2831
class RemoteURIType(Enum):
@@ -252,7 +255,7 @@ def save_charts(self):
252255
plt.xlabel("Average Query Accuracy")
253256
plt.ylabel("Time (seconds)")
254257
plt.title(f"Ingestion Time vs Average Query Accuracy {sift_string()}")
255-
for idx, timer in self.timers:
258+
for idx, timer in enumerate(self.timers):
256259
timer.add_data_to_ingestion_time_vs_average_query_accuracy(
257260
markers[idx % len(markers)]
258261
)
@@ -265,7 +268,7 @@ def save_charts(self):
265268
plt.xlabel("Accuracy")
266269
plt.ylabel("Time (seconds)")
267270
plt.title(f"Query Time vs Accuracy {sift_string()}")
268-
for idx, timer in self.timers:
271+
for idx, timer in enumerate(self.timers):
269272
timer.add_data_to_query_time_vs_accuracy(markers[idx % len(markers)])
270273
plt.legend()
271274
plt.savefig(os.path.join(RESULTS_DIR, "query_time_vs_accuracy.png"))
@@ -295,6 +298,7 @@ def download_and_extract(url, download_path, extract_path):
295298

296299

297300
def get_uri(tag):
301+
global config
298302
index_name = f"index_{tag.replace('=', '_')}"
299303
index_uri = ""
300304
if REMOTE_URI_TYPE == RemoteURIType.LOCAL:
@@ -346,7 +350,7 @@ def benchmark_ivf_flat():
346350
index_uri = get_uri(tag)
347351

348352
timer.start(tag, TimerMode.INGESTION)
349-
index = ingest(
353+
ingest(
350354
index_type=index_type,
351355
index_uri=index_uri,
352356
source_uri=SIFT_BASE_PATH,
@@ -356,6 +360,10 @@ def benchmark_ivf_flat():
356360
)
357361
ingest_time = timer.stop(tag, TimerMode.INGESTION)
358362

363+
# The index returned by ingest() automatically has memory_budget=1000000 set. Open
364+
# a fresh index so it's clear what config is being used.
365+
index = IVFFlatIndex(index_uri, config)
366+
359367
for nprobe in [1, 2, 3, 4, 5, 10, 20]:
360368
timer.start(tag, TimerMode.QUERY)
361369
_, result = index.query(queries, k=k, nprobe=nprobe)
@@ -386,7 +394,7 @@ def benchmark_vamana():
386394
index_uri = get_uri(tag)
387395

388396
timer.start(tag, TimerMode.INGESTION)
389-
index = ingest(
397+
ingest(
390398
index_type=index_type,
391399
index_uri=index_uri,
392400
source_uri=SIFT_BASE_PATH,
@@ -397,6 +405,8 @@ def benchmark_vamana():
397405
)
398406
ingest_time = timer.stop(tag, TimerMode.INGESTION)
399407

408+
index = VamanaIndex(index_uri, config)
409+
400410
for l_search in [k, k + 50, k + 100, k + 200, k + 400]:
401411
timer.start(tag, TimerMode.QUERY)
402412
_, result = index.query(queries, k=k, l_search=l_search)
@@ -429,7 +439,7 @@ def benchmark_ivf_pq():
429439
index_uri = get_uri(tag)
430440

431441
timer.start(tag, TimerMode.INGESTION)
432-
index = ingest(
442+
ingest(
433443
index_type=index_type,
434444
index_uri=index_uri,
435445
source_uri=SIFT_BASE_PATH,
@@ -440,6 +450,10 @@ def benchmark_ivf_pq():
440450
)
441451
ingest_time = timer.stop(tag, TimerMode.INGESTION)
442452

453+
# The index returned by ingest() automatically has memory_budget=1000000 set. Open
454+
# a fresh index so it's clear what config is being used.
455+
index = IVFPQIndex(index_uri, config)
456+
443457
for nprobe in [5, 10, 20, 40, 60]:
444458
timer.start(tag, TimerMode.QUERY)
445459
_, result = index.query(

apis/python/test/test_cloud.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def run_cloud_test(self, index_uri, index_type, index_class):
6363
input_vectors_per_work_item=5000,
6464
config=tiledb.cloud.Config().dict(),
6565
mode=Mode.BATCH,
66+
verbose=True,
6667
)
6768
tiledb_index_uri = groups.info(index_uri).tiledb_uri
6869

0 commit comments

Comments
 (0)