Skip to content

Commit c8e68e6

Browse files
lums658ihnorton
authored andcommitted
Return distances from query functions
This PR updates the IVF queries to return top k distances in addition to top k neighbor indices. All of our search functions now use a fixed size min heap of to keep distances and indices. The heap is templated on the type of distance and type of index, which are kept as a tuple in the heap. The heap is sorted on the first element (the score). To create a matrix of top k neighbors, the search functions sorted the heap and then copied the neighbor indices into a matrix. With the new API, queries can now be called as ``` auto&& [ D, I ] = query(...) ``` To update the query functions to return distances in addition to indices, this PR * Created a new file scoring.h which includes most of the functionality previously in defs.h * Deleted defs.h * Created a function `consolidate` which consolidates distance and index information from a vector of vector of indices into a single vector (the 0th vector in the vector of vectors) * Factored out the code to copy the index information in the heap structure to a matrix into new functions `get_top_k_from_heap` and two overloads of `get_top_k` * Created augmented functions `get_top_k_from_heap_with_scores` and two overloads of `get_top_k_with_scores` that return a tuple of a distance matrix and an index matrix. * Replaced the final logic in all of the ivf queries with `consolidate` + `get_top_k_with_scores` (not all functions needed `consolidate`) * Added extensive new tests in unit_fixed_min_heap.cc, unit_linalg.cc, and unit_scoring.cc, notably to validate the new `get_top_k` family of functions. * Updated the ivf_flat C++ CLI program to use the new query API * Updated the Python bindings in module.cc to use the new query API In addition * Propagated the BLAS macro to exclude BLAS code from the library so that the CLIs can be built without BLAS * Added docstrings to the query functions in ivf/qv.h * Added an initializer list constructor to `Matrix` to aid in creating tests
1 parent 8e4f847 commit c8e68e6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+4597
-1574
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def query(
100100
assert targets.dtype == np.float32
101101

102102
targets_m = array_to_matrix(np.transpose(targets))
103-
r = query_vq_heap(self._db, targets_m, self._ids, k, nthreads)
103+
_, r = query_vq_heap(self._db, targets_m, self._ids, k, nthreads)
104104

105-
return np.transpose(np.array(r))
105+
return np.transpose(np.array(_)), np.transpose(np.array(r))
106106

107107

108108
class IVFFlatIndex(Index):
@@ -222,7 +222,7 @@ def query(
222222
if mode is None:
223223
queries_m = array_to_matrix(np.transpose(queries))
224224
if self.memory_budget == -1:
225-
r = ivf_query_ram(
225+
_, r = ivf_query_ram(
226226
self.dtype,
227227
self._db,
228228
self._centroids,
@@ -231,13 +231,12 @@ def query(
231231
self._ids,
232232
nprobe=nprobe,
233233
k_nn=k,
234-
nth=True, # ??
235234
nthreads=nthreads,
236235
ctx=self.ctx,
237236
use_nuv_implementation=use_nuv_implementation,
238237
)
239238
else:
240-
r = ivf_query(
239+
_, r = ivf_query(
241240
self.dtype,
242241
self.parts_db_uri,
243242
self._centroids,
@@ -247,13 +246,12 @@ def query(
247246
nprobe=nprobe,
248247
k_nn=k,
249248
memory_budget=self.memory_budget,
250-
nth=True, # ??
251249
nthreads=nthreads,
252250
ctx=self.ctx,
253251
use_nuv_implementation=use_nuv_implementation,
254252
)
255253

256-
return np.transpose(np.array(r))
254+
return np.transpose(np.array(_)), np.transpose(np.array(r))
257255
else:
258256
return self.taskgraph_query(
259257
queries=queries,

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,13 @@
55
#include <pybind11/stl.h>
66

77
#include "linalg.h"
8-
#include "ivf_index.h"
98
#include "ivf_query.h"
109
#include "flat_query.h"
1110

1211
namespace py = pybind11;
1312
using Ctx = tiledb::Context;
1413

1514
bool global_debug = false;
16-
double global_time_of_interest;
1715

1816
bool enable_stats = false;
1917
std::vector<json> core_stats;
@@ -113,8 +111,7 @@ static void declare_qv_query_heap_infinite_ram(py::module& m, const std::string&
113111
std::vector<Id_Type>& ids,
114112
size_t nprobe,
115113
size_t k_nn,
116-
bool nth,
117-
size_t nthreads) -> ColMajorMatrix<size_t> { // TODO change return type
114+
size_t nthreads) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> { // TODO change return type
118115

119116
auto r = detail::ivf::qv_query_heap_infinite_ram(
120117
parts,
@@ -124,7 +121,6 @@ static void declare_qv_query_heap_infinite_ram(py::module& m, const std::string&
124121
ids,
125122
nprobe,
126123
k_nn,
127-
nth,
128124
nthreads);
129125
return r;
130126
}, py::keep_alive<1,2>());
@@ -142,8 +138,7 @@ static void declare_qv_query_heap_finite_ram(py::module& m, const std::string& s
142138
size_t nprobe,
143139
size_t k_nn,
144140
size_t upper_bound,
145-
bool nth,
146-
size_t nthreads) -> ColMajorMatrix<size_t> { // TODO change return type
141+
size_t nthreads) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> { // TODO change return type
147142

148143
auto r = detail::ivf::qv_query_heap_finite_ram<T, Id_Type>(
149144
ctx,
@@ -155,7 +150,6 @@ static void declare_qv_query_heap_finite_ram(py::module& m, const std::string& s
155150
nprobe,
156151
k_nn,
157152
upper_bound,
158-
nth,
159153
nthreads);
160154
return r;
161155
}, py::keep_alive<1,2>());
@@ -171,8 +165,7 @@ static void declare_nuv_query_heap_infinite_ram(py::module& m, const std::string
171165
std::vector<Id_Type>& ids,
172166
size_t nprobe,
173167
size_t k_nn,
174-
bool nth,
175-
size_t nthreads) -> ColMajorMatrix<size_t> { // TODO change return type
168+
size_t nthreads) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> { // TODO change return type
176169

177170
auto r = detail::ivf::nuv_query_heap_infinite_ram_reg_blocked(
178171
parts,
@@ -182,7 +175,6 @@ static void declare_nuv_query_heap_infinite_ram(py::module& m, const std::string
182175
ids,
183176
nprobe,
184177
k_nn,
185-
nth,
186178
nthreads);
187179
return r;
188180
}, py::keep_alive<1,2>());
@@ -200,8 +192,7 @@ static void declare_nuv_query_heap_finite_ram(py::module& m, const std::string&
200192
size_t nprobe,
201193
size_t k_nn,
202194
size_t upper_bound,
203-
bool nth,
204-
size_t nthreads) -> ColMajorMatrix<size_t> { // TODO change return type
195+
size_t nthreads) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> { // TODO change return type
205196

206197
auto r = detail::ivf::nuv_query_heap_finite_ram_reg_blocked<T, Id_Type>(
207198
ctx,
@@ -213,7 +204,6 @@ static void declare_nuv_query_heap_finite_ram(py::module& m, const std::string&
213204
nprobe,
214205
k_nn,
215206
upper_bound,
216-
nth,
217207
nthreads);
218208
return r;
219209
}, py::keep_alive<1,2>());
@@ -394,7 +384,7 @@ static void declare_vq_query_heap(py::module& m, const std::string& suffix) {
394384
ColMajorMatrix<float>& query_vectors,
395385
const std::vector<uint64_t> &ids,
396386
int k,
397-
size_t nthreads) -> ColMajorMatrix<size_t> {
387+
size_t nthreads) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> {
398388
auto r = detail::flat::vq_query_heap(data, query_vectors, ids, k, nthreads);
399389
return r;
400390
});
@@ -477,17 +467,17 @@ PYBIND11_MODULE(_tiledbvspy, m) {
477467
[](ColMajorMatrix<float>& data,
478468
ColMajorMatrix<float>& query_vectors,
479469
int k,
480-
size_t nthreads) -> ColMajorMatrix<size_t> {
481-
auto r = detail::flat::vq_query_nth(data, query_vectors, k, true, nthreads);
470+
size_t nthreads) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> {
471+
auto r = detail::flat::vq_query_heap(data, query_vectors, k, nthreads);
482472
return r;
483473
});
484474

485475
m.def("query_vq_u8",
486476
[](tdbColMajorMatrix<uint8_t>& data,
487477
ColMajorMatrix<float>& query_vectors,
488478
int k,
489-
size_t nthreads) -> ColMajorMatrix<size_t> {
490-
auto r = detail::flat::vq_query_nth(data, query_vectors, k, true, nthreads);
479+
size_t nthreads) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> {
480+
auto r = detail::flat::vq_query_heap(data, query_vectors, k, nthreads);
491481
return r;
492482
});
493483

apis/python/src/tiledb/vector_search/module.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@ def ivf_query_ram(
188188
ids: "Vector",
189189
nprobe: int,
190190
k_nn: int,
191-
nth: bool,
192191
nthreads: int,
193192
ctx: "Ctx" = None,
194193
use_nuv_implementation: bool = False,
@@ -214,8 +213,6 @@ def ivf_query_ram(
214213
Number of probs
215214
k_nn: int
216215
Number of nn
217-
nth: bool
218-
Return nth records
219216
nthreads: int
220217
Number of theads
221218
ctx: Ctx
@@ -233,7 +230,6 @@ def ivf_query_ram(
233230
ids,
234231
nprobe,
235232
k_nn,
236-
nth,
237233
nthreads,
238234
]
239235
)
@@ -262,7 +258,6 @@ def ivf_query(
262258
nprobe: int,
263259
k_nn: int,
264260
memory_budget: int,
265-
nth: bool,
266261
nthreads: int,
267262
ctx: "Ctx" = None,
268263
use_nuv_implementation: bool = False,
@@ -290,8 +285,6 @@ def ivf_query(
290285
Number of nn
291286
memory_budget: int
292287
Main memory budget
293-
nth: bool
294-
Return nth records
295288
nthreads: int
296289
Number of theads
297290
ctx: Ctx
@@ -311,7 +304,6 @@ def ivf_query(
311304
nprobe,
312305
k_nn,
313306
memory_budget,
314-
nth,
315307
nthreads,
316308
]
317309
)

apis/python/test/test_ingestion.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_flat_ingestion_u8(tmp_path):
2525
index_uri=index_uri,
2626
source_uri=os.path.join(dataset_dir, "data.u8bin"),
2727
)
28-
result = index.query(query_vectors, k=k)
28+
_, result = index.query(query_vectors, k=k)
2929
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
3030

3131

@@ -44,11 +44,11 @@ def test_flat_ingestion_f32(tmp_path):
4444
index_uri=index_uri,
4545
source_uri=os.path.join(dataset_dir, "data.f32bin"),
4646
)
47-
result = index.query(query_vectors, k=k)
47+
_, result = index.query(query_vectors, k=k)
4848
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
4949

5050
index_ram = FlatIndex(uri=index_uri)
51-
result = index_ram.query(query_vectors, k=k)
51+
_, result = index_ram.query(query_vectors, k=k)
5252
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
5353

5454

@@ -71,7 +71,7 @@ def test_flat_ingestion_external_id_u8(tmp_path):
7171
source_uri=os.path.join(dataset_dir, "data.u8bin"),
7272
external_ids=external_ids
7373
)
74-
result = index.query(query_vectors, k=k)
74+
_, result = index.query(query_vectors, k=k)
7575
assert accuracy(result, gt_i, external_ids_offset=external_ids_offset) > MINIMUM_ACCURACY
7676

7777

@@ -96,22 +96,22 @@ def test_ivf_flat_ingestion_u8(tmp_path):
9696
partitions=partitions,
9797
input_vectors_per_work_item=int(size / 10),
9898
)
99-
result = index.query(query_vectors, k=k, nprobe=nprobe)
99+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
100100
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
101101

102102
index_ram = IVFFlatIndex(uri=index_uri, memory_budget=int(size / 10))
103-
result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
103+
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
104104
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
105105

106-
result = index_ram.query(
106+
_, result = index_ram.query(
107107
query_vectors,
108108
k=k,
109109
nprobe=nprobe,
110110
use_nuv_implementation=True,
111111
)
112112
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
113113

114-
result = index_ram.query(
114+
_, result = index_ram.query(
115115
query_vectors,
116116
k=k,
117117
nprobe=nprobe,
@@ -144,26 +144,26 @@ def test_ivf_flat_ingestion_f32(tmp_path):
144144
input_vectors_per_work_item=int(size / 10),
145145
)
146146

147-
result = index.query(query_vectors, k=k, nprobe=nprobe)
147+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
148148
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
149149

150150
index_ram = IVFFlatIndex(uri=index_uri, memory_budget=int(size / 10))
151-
result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
151+
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
152152
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
153153

154154
index_ram = IVFFlatIndex(uri=index_uri, memory_budget=int(size / 10))
155-
result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
155+
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
156156
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
157157

158-
result = index_ram.query(
158+
_, result = index_ram.query(
159159
query_vectors,
160160
k=k,
161161
nprobe=nprobe,
162162
use_nuv_implementation=True,
163163
)
164164
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
165165

166-
result = index_ram.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL)
166+
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL)
167167
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
168168

169169

@@ -186,26 +186,26 @@ def test_ivf_flat_ingestion_fvec(tmp_path):
186186
source_uri=source_uri,
187187
partitions=partitions,
188188
)
189-
result = index.query(query_vectors, k=k, nprobe=nprobe)
189+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
190190
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
191191

192192
# Test single query vector handling
193-
result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
193+
_, result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
194194
assert accuracy(result1, np.array([gt_i[10]])) > MINIMUM_ACCURACY
195195

196196
index_ram = IVFFlatIndex(uri=index_uri)
197-
result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
197+
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
198198
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
199199

200-
result = index_ram.query(
200+
_, result = index_ram.query(
201201
query_vectors,
202202
k=k,
203203
nprobe=nprobe,
204204
use_nuv_implementation=True,
205205
)
206206
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
207207

208-
result = index_ram.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL)
208+
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL)
209209
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
210210

211211

@@ -230,26 +230,26 @@ def test_ivf_flat_ingestion_numpy(tmp_path):
230230
input_vectors=input_vectors,
231231
partitions=partitions,
232232
)
233-
result = index.query(query_vectors, k=k, nprobe=nprobe)
233+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
234234
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
235235

236236
# Test single query vector handling
237-
result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
237+
_, result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
238238
assert accuracy(result1, np.array([gt_i[10]])) > MINIMUM_ACCURACY
239239

240240
index_ram = IVFFlatIndex(uri=index_uri)
241-
result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
241+
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
242242
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
243243

244-
result = index_ram.query(
244+
_, result = index_ram.query(
245245
query_vectors,
246246
k=k,
247247
nprobe=nprobe,
248248
use_nuv_implementation=True,
249249
)
250250
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
251251

252-
result = index_ram.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL)
252+
_, result = index_ram.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL)
253253
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
254254
def test_ivf_flat_ingestion_external_ids_numpy(tmp_path):
255255
source_uri = "test/data/siftsmall/siftsmall_base.fvecs"
@@ -275,5 +275,5 @@ def test_ivf_flat_ingestion_external_ids_numpy(tmp_path):
275275
partitions=partitions,
276276
external_ids=external_ids
277277
)
278-
result = index.query(query_vectors, k=k, nprobe=nprobe)
278+
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
279279
assert accuracy(result, gt_i, external_ids_offset) > MINIMUM_ACCURACY

0 commit comments

Comments
 (0)