Skip to content

Commit 58d2ce7

Browse files
author
Nikos Papailiou
committed
WIP
1 parent de8faf6 commit 58d2ce7

File tree

4 files changed

+117
-7
lines changed

4 files changed

+117
-7
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import math
23

34
import numpy as np
45
from tiledb.vector_search.module import *
@@ -44,7 +45,7 @@ def query(
4445
nprobe: int = 1,
4546
):
4647
"""
47-
Open a flat index
48+
Query a flat index
4849
4950
Parameters
5051
----------
@@ -117,7 +118,7 @@ def query(
117118
use_nuv_implementation: bool = False,
118119
):
119120
"""
120-
Open a flat index
121+
Query an IVF_FLAT index
121122
122123
Parameters
123124
----------
@@ -128,7 +129,7 @@ def query(
128129
nqueries: int
129130
Number of queries
130131
nthreads: int
131-
Number of threads to use for queyr
132+
Number of threads to use for query
132133
nprobe: int
133134
number of probes
134135
use_nuv_implementation: bool
@@ -170,3 +171,66 @@ def query(
170171
)
171172

172173
return np.array(r)
174+
175+
def distributed_query(
176+
self,
177+
targets: np.ndarray,
178+
k=10,
179+
nthreads=8,
180+
nprobe=1,
181+
num_nodes=5,
182+
):
183+
"""
184+
Distributed Query on top of an IVF_FLAT index
185+
186+
Parameters
187+
----------
188+
targets: numpy.ndarray
189+
ND Array of query targets
190+
k: int
191+
Number of top results to return per target
192+
nqueries: int
193+
Number of queries
194+
nthreads: int
195+
Number of threads to use for query
196+
nprobe: int
197+
number of probes
198+
"""
199+
assert targets.dtype == np.float32
200+
201+
targets_m = array_to_matrix(targets)
202+
active_partitions, active_queries = partition_ivf_index(
203+
centroids=self._centroids,
204+
query=targets_m,
205+
nprobe=nprobe,
206+
nthreads=nthreads)
207+
num_parts = len(active_partitions)
208+
209+
parts_per_node = int(math.ceil(num_parts / num_nodes))
210+
results = []
211+
for part in range(0, num_parts, parts_per_node):
212+
part_end = part + parts_per_node
213+
if part_end > num_parts:
214+
part_end = num_parts
215+
results.append(dist_qv(
216+
dtype=self.dtype,
217+
parts_uri=self.parts_db_uri,
218+
ids_uri=self.ids_uri,
219+
query_vectors=targets_m,
220+
active_partitions=active_partitions[part:part_end],
221+
active_queries=active_queries[part:part_end],
222+
indices=self._index,
223+
k_nn=k,
224+
ctx=self.ctx,
225+
))
226+
227+
results_per_query = []
228+
for q in range(targets.shape[1]):
229+
tmp_results = []
230+
for j in range(k):
231+
for r in results:
232+
if len(r[q]) > 0:
233+
if r[q][j][0] > 0:
234+
tmp_results.append(r[q][j])
235+
results_per_query.append(sorted(tmp_results, key=lambda t: t[0])[0:k])
236+
return results_per_query

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,16 @@ static void declare_ivf_index_tdb(py::module& m, const std::string& suffix) {
268268
}, py::keep_alive<1,2>());
269269
}
270270

271+
template <class T=float, class U=size_t>
272+
static void declareFixedMinPairHeap(py::module& mod) {
273+
using PyFixedMinPairHeap = py::class_<fixed_min_pair_heap<T, U>>;
274+
PyFixedMinPairHeap cls(mod, "FixedMinPairHeap", py::buffer_protocol());
275+
276+
cls.def(py::init<unsigned>());
277+
cls.def("insert", &fixed_min_pair_heap<T, U>::insert);
278+
cls.def("__len__", [](const fixed_min_pair_heap<T, U> &v) { return v.size(); });
279+
cls.def("__getitem__", [](fixed_min_pair_heap<T, U>& v, size_t i) { return v[i]; });
280+
}
271281

272282
// Declarations for typed subclasses of ColMajorMatrix
273283
template <typename P>
@@ -346,14 +356,14 @@ void declarePartitionIvfIndex(py::module& m, const std::string& suffix) {
346356
);
347357
}
348358

349-
template <typename query_type, typename shuffled_ids_type = size_t>
359+
template <typename query_type, typename shuffled_ids_type = uint64_t>
350360
static void declare_dist_qv(py::module& m, const std::string& suffix) {
351361
m.def(("dist_qv_" + suffix).c_str(),
352362
[](tiledb::Context& ctx,
353363
const std::string& part_uri,
354-
std::vector<shuffled_ids_type>& active_partitions,
364+
std::vector<int>& active_partitions,
355365
ColMajorMatrix<query_type>& query,
356-
std::vector<std::vector<shuffled_ids_type>>& active_queries,
366+
std::vector<std::vector<int>>& active_queries,
357367
std::vector<shuffled_ids_type>& indices,
358368
const std::string& id_uri,
359369
size_t k_nn
@@ -508,5 +518,7 @@ PYBIND11_MODULE(_tiledbvspy, m) {
508518
declarePartitionedMatrix<tdbColMajorPartitionedMatrix<uint8_t, uint64_t, uint64_t, uint64_t > >(m, "tdbPartitionedMatrix", "u8");
509519
declarePartitionedMatrix<tdbColMajorPartitionedMatrix<float, uint64_t, uint64_t, uint64_t> >(m, "tdbPartitionedMatrix", "f32");
510520

521+
declare_dist_qv<uint8_t>(m, "u8");
511522
declare_dist_qv<float>(m, "f32");
523+
declareFixedMinPairHeap(m);
512524
}

apis/python/src/tiledb/vector_search/module.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,36 @@ def partition_ivf_index(centroids, query, nprobe=1, nthreads=0):
302302
else:
303303
raise TypeError("Unsupported type!")
304304

305+
def dist_qv(
306+
dtype: np.dtype,
307+
parts_uri: str,
308+
ids_uri: str,
309+
query_vectors: "colMajorMatrix",
310+
active_partitions: "Vector",
311+
active_queries: "Vector",
312+
indices: "Vector",
313+
k_nn: int,
314+
ctx: "Ctx" = None):
315+
if ctx is None:
316+
ctx = Ctx({})
317+
args = tuple(
318+
[
319+
ctx,
320+
parts_uri,
321+
active_partitions,
322+
query_vectors,
323+
active_queries,
324+
indices,
325+
ids_uri,
326+
k_nn
327+
]
328+
)
329+
if dtype == np.float32:
330+
return dist_qv_f32(*args)
331+
elif dtype == np.uint8:
332+
return dist_qv_u8(*args)
333+
else:
334+
raise TypeError("Unsupported type!")
305335

306336
def validate_top_k(results: np.ndarray, ground_truth: np.ndarray):
307337
if results.dtype == np.uint64:

apis/python/test/test_ingestion.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def test_flat_ingestion_f32(tmp_path):
4646
result = np.transpose(index.query(np.transpose(query_vectors), k=k))
4747
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
4848

49-
5049
def test_ivf_flat_ingestion_u8(tmp_path):
5150
dataset_dir = os.path.join(tmp_path, "dataset")
5251
array_uri = os.path.join(tmp_path, "array")
@@ -109,6 +108,10 @@ def test_ivf_flat_ingestion_f32(tmp_path):
109108
partitions=partitions,
110109
input_vectors_per_work_item=int(size / 10),
111110
)
111+
112+
result = index.distributed_query(np.transpose(query_vectors), k=k, nprobe=partitions)
113+
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
114+
112115
result = np.transpose(
113116
index.query(np.transpose(query_vectors), k=k, nprobe=partitions)
114117
)
@@ -130,6 +133,7 @@ def test_ivf_flat_ingestion_f32(tmp_path):
130133
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
131134

132135

136+
133137
def test_ivf_flat_ingestion_fvec(tmp_path):
134138
source_uri = "test/data/siftsmall/siftsmall_base.fvecs"
135139
queries_uri = "test/data/siftsmall/siftsmall_query.fvecs"

0 commit comments

Comments
 (0)