Skip to content

Commit 14fd1a8

Browse files
Configure memory budget for distributed OOC queries (#462)
Configure the memory budget for distributed out-of-core (OOC) queries. This is a no-op until the corresponding library change is released in TileDB Cloud.
1 parent e1b0526 commit 14fd1a8

File tree

3 files changed

+26
-31
lines changed

3 files changed

+26
-31
lines changed

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -371,33 +371,23 @@ def dist_qv_udf(
371371
k_nn: int,
372372
config: Optional[Mapping[str, Any]] = None,
373373
timestamp: int = 0,
374+
memory_budget: int = -1,
374375
):
375376
queries_m = array_to_matrix(np.transpose(query_vectors))
376-
if timestamp == 0:
377-
r = dist_qv(
378-
dtype=dtype,
379-
parts_uri=parts_uri,
380-
ids_uri=ids_uri,
381-
query_vectors=queries_m,
382-
active_partitions=active_partitions,
383-
active_queries=active_queries,
384-
indices=indices,
385-
k_nn=k_nn,
386-
ctx=Ctx(config),
387-
)
388-
else:
389-
r = dist_qv(
390-
dtype=dtype,
391-
parts_uri=parts_uri,
392-
ids_uri=ids_uri,
393-
query_vectors=queries_m,
394-
active_partitions=active_partitions,
395-
active_queries=active_queries,
396-
indices=indices,
397-
k_nn=k_nn,
398-
ctx=Ctx(config),
399-
timestamp=timestamp,
400-
)
377+
r = dist_qv(
378+
dtype=dtype,
379+
parts_uri=parts_uri,
380+
ids_uri=ids_uri,
381+
query_vectors=queries_m,
382+
active_partitions=active_partitions,
383+
active_queries=active_queries,
384+
indices=indices,
385+
k_nn=k_nn,
386+
ctx=Ctx(config),
387+
timestamp=timestamp,
388+
# TODO(nikos) add this after the library change gets released in TileDB cloud
389+
# upper_bound=0 if memory_budget == -1 else memory_budget,
390+
)
401391
results = []
402392
for q in range(len(r)):
403393
tmp_results = []
@@ -464,6 +454,7 @@ def dist_qv_udf(
464454
k_nn=k,
465455
config=config,
466456
timestamp=self.base_array_timestamp,
457+
memory_budget=self.memory_budget,
467458
resource_class="large"
468459
if (not resources and not resource_class)
469460
else resource_class,

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -496,10 +496,10 @@ static void declare_dist_qv(py::module& m, const std::string& suffix) {
496496
std::vector<indices_type>& indices, // 5
497497
const std::string& id_uri,
498498
size_t k_nn,
499-
uint64_t timestamp
499+
uint64_t timestamp,
500+
size_t upper_bound
500501
/* size_t nthreads TODO: optional arg w/ fallback to C++ default arg */
501502
) { /* TODO return type */
502-
size_t upper_bound{0};
503503
auto nthreads = std::thread::hardware_concurrency();
504504

505505
return detail::ivf::dist_qv_finite_ram_part<T, shuffled_ids_type>(
@@ -511,7 +511,8 @@ static void declare_dist_qv(py::module& m, const std::string& suffix) {
511511
indices,
512512
id_uri,
513513
k_nn,
514-
timestamp);
514+
timestamp,
515+
upper_bound);
515516
},
516517
py::keep_alive<1, 2>());
517518
m.def(
@@ -525,11 +526,11 @@ static void declare_dist_qv(py::module& m, const std::string& suffix) {
525526
std::vector<shuffled_ids_type>& indices,
526527
const std::string& id_uri,
527528
size_t k_nn,
528-
uint64_t timestamp
529+
uint64_t timestamp,
530+
size_t upper_bound
529531
/* size_t nthreads @todo: optional arg w/ fallback to C++ default arg
530532
*/
531533
) { /* @todo: return type */
532-
size_t upper_bound{0};
533534
auto nthreads = std::thread::hardware_concurrency();
534535
auto temporal_policy{
535536
(timestamp == 0) ? TemporalPolicy() :
@@ -564,7 +565,8 @@ static void declare_dist_qv(py::module& m, const std::string& suffix) {
564565
indices,
565566
id_uri,
566567
k_nn,
567-
timestamp);
568+
timestamp,
569+
upper_bound);
568570
},
569571
py::keep_alive<1, 2>());
570572
}

apis/python/src/tiledb/vector_search/module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,7 @@ def dist_qv(
448448
k_nn: int,
449449
ctx: "Ctx" = None,
450450
timestamp: int = 0,
451+
upper_bound: int = 0,
451452
):
452453
if ctx is None:
453454
ctx = vspy.Ctx({})
@@ -462,6 +463,7 @@ def dist_qv(
462463
ids_uri,
463464
k_nn,
464465
timestamp,
466+
upper_bound,
465467
]
466468
)
467469

0 commit comments

Comments
 (0)