Skip to content

Commit 3e0d06b

Browse files
Merge pull request #143 from TileDB-Inc/npapa/fix-uint64-ids
Fix query errors in tables with updates
2 parents 187c831 + 0189b4b commit 3e0d06b

File tree

5 files changed

+42
-28
lines changed

5 files changed

+42
-28
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,12 @@ def query(self, queries: np.ndarray, k, **kwargs):
7070
if res in updated_ids:
7171
internal_results_d[query_id, res_id] = MAX_FLOAT_32
7272
internal_results_i[query_id, res_id] = MAX_UINT64
73+
if (
74+
internal_results_d[query_id, res_id] == 0
75+
and internal_results_i[query_id, res_id] == 0
76+
):
77+
internal_results_d[query_id, res_id] = MAX_FLOAT_32
78+
internal_results_i[query_id, res_id] = MAX_UINT64
7379
res_id += 1
7480
query_id += 1
7581
sort_index = np.argsort(internal_results_d, axis=1)

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -695,11 +695,11 @@ def read_additions(
695695
logger.debug(
696696
"Reading additions vectors"
697697
)
698-
updates_array = tiledb.open(updates_uri, mode="r")
699-
q = updates_array.query(attrs=('vector',), coords=True)
700-
data = q[:]
701-
additions_filter = [len(item) > 0 for item in data["vector"]]
702-
return np.vstack(data["vector"][additions_filter]), data["external_id"][additions_filter]
698+
with tiledb.open(updates_uri, mode="r") as updates_array:
699+
q = updates_array.query(attrs=('vector',), coords=True)
700+
data = q[:]
701+
additions_filter = [len(item) > 0 for item in data["vector"]]
702+
return np.vstack(data["vector"][additions_filter]), data["external_id"][additions_filter]
703703

704704
def read_updated_ids(
705705
updates_uri: str,
@@ -713,10 +713,10 @@ def read_updated_ids(
713713
logger.debug(
714714
"Reading updated vector ids"
715715
)
716-
updates_array = tiledb.open(updates_uri, mode="r")
717-
q = updates_array.query(attrs=('vector',), coords=True)
718-
data = q[:]
719-
return data["external_id"]
716+
with tiledb.open(updates_uri, mode="r") as updates_array:
717+
q = updates_array.query(attrs=('vector',), coords=True)
718+
data = q[:]
719+
return data["external_id"]
720720

721721
def read_input_vectors(
722722
source_uri: str,
@@ -1729,10 +1729,15 @@ def consolidate_and_vacuum(
17291729
config: Optional[Mapping[str, Any]] = None,
17301730
):
17311731
group = tiledb.Group(index_group_uri)
1732-
if INPUT_VECTORS_ARRAY_NAME in group:
1733-
tiledb.Array.delete_array(group[INPUT_VECTORS_ARRAY_NAME].uri)
1734-
if EXTERNAL_IDS_ARRAY_NAME in group:
1735-
tiledb.Array.delete_array(group[EXTERNAL_IDS_ARRAY_NAME].uri)
1732+
try:
1733+
if INPUT_VECTORS_ARRAY_NAME in group:
1734+
tiledb.Array.delete_array(group[INPUT_VECTORS_ARRAY_NAME].uri)
1735+
if EXTERNAL_IDS_ARRAY_NAME in group:
1736+
tiledb.Array.delete_array(group[EXTERNAL_IDS_ARRAY_NAME].uri)
1737+
except tiledb.TileDBError as err:
1738+
message = str(err)
1739+
if "does not exist" not in message:
1740+
raise err
17361741
modes = ["fragment_meta", "commits", "array_meta"]
17371742
for mode in modes:
17381743
conf = tiledb.Config(config)

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -402,7 +402,7 @@ static void declare_vq_query_heap(py::module& m, const std::string& suffix) {
402402
const std::vector<uint64_t> &ids,
403403
int k,
404404
size_t nthreads) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> {
405-
auto r = detail::flat::vq_query_heap(data, query_vectors, ids, k, nthreads);
405+
auto r = detail::flat::vq_query_heap<tdbColMajorMatrix<T>, ColMajorMatrix<float>, uint64_t>(data, query_vectors, ids, k, nthreads);
406406
return r;
407407
});
408408
}
@@ -415,7 +415,7 @@ static void declare_vq_query_heap_pyarray(py::module& m, const std::string& suff
415415
const std::vector<uint64_t> &ids,
416416
int k,
417417
size_t nthreads) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> {
418-
auto r = detail::flat::vq_query_heap(data, query_vectors, ids, k, nthreads);
418+
auto r = detail::flat::vq_query_heap<ColMajorMatrix<T>, ColMajorMatrix<float>, uint64_t>(data, query_vectors, ids, k, nthreads);
419419
return r;
420420
});
421421
}

apis/python/test/test_ingestion.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from tiledb.cloud.dag import Mode
1010

1111
MINIMUM_ACCURACY = 0.85
12+
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
1213

1314

1415
def test_flat_ingestion_u8(tmp_path):
@@ -307,11 +308,12 @@ def test_ivf_flat_ingestion_with_updates(tmp_path):
307308
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
308309
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
309310

311+
update_ids_offset = MAX_UINT64-size
310312
updated_ids = {}
311313
for i in range(100):
312314
index.delete(external_id=i)
313-
index.update(vector=data[i].astype(dtype), external_id=i + 1000000)
314-
updated_ids[i + 1000000] = i
315+
index.update(vector=data[i].astype(dtype), external_id=i + update_ids_offset)
316+
updated_ids[i + update_ids_offset] = i
315317

316318
_, result = index.query(query_vectors, k=k, nprobe=nprobe)
317319
assert accuracy(result, gt_i, updated_ids=updated_ids) > MINIMUM_ACCURACY
@@ -346,9 +348,10 @@ def test_ivf_flat_ingestion_with_batch_updates(tmp_path):
346348

347349
update_ids = {}
348350
updated_ids = {}
351+
update_ids_offset = MAX_UINT64 - size
349352
for i in range(0, 100000, 2):
350-
update_ids[i] = i + 1000000
351-
updated_ids[i + 1000000] = i
353+
update_ids[i] = i + update_ids_offset
354+
updated_ids[i + update_ids_offset] = i
352355
external_ids = np.zeros((len(update_ids) * 2), dtype=np.uint64)
353356
updates = np.empty((len(update_ids) * 2), dtype='O')
354357
id = 0

src/include/detail/flat/vq.h

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -83,10 +83,10 @@ auto vq_query_heap(
8383
unsigned nthreads) {
8484
// @todo Need to get the total number of queries, not just the first block
8585
// @todo Use Matrix here rather than vector of vectors
86-
std::vector<std::vector<fixed_min_pair_heap<float, unsigned>>> scores(
86+
std::vector<std::vector<fixed_min_pair_heap<float, Index>>> scores(
8787
nthreads,
88-
std::vector<fixed_min_pair_heap<float, unsigned>>(
89-
size(q), fixed_min_pair_heap<float, unsigned>(k_nn)));
88+
std::vector<fixed_min_pair_heap<float, Index>>(
89+
size(q), fixed_min_pair_heap<float, Index>(k_nn)));
9090

9191
unsigned size_q = size(q);
9292
auto par = stdx::execution::indexed_parallel_policy{nthreads};
@@ -184,10 +184,10 @@ auto vq_query_heap_tiled(
184184
unsigned nthreads) {
185185
// @todo Need to get the total number of queries, not just the first block
186186
// @todo Use Matrix here rather than vector of vectors
187-
std::vector<std::vector<fixed_min_pair_heap<float, unsigned>>> scores(
187+
std::vector<std::vector<fixed_min_pair_heap<float, Index>>> scores(
188188
nthreads,
189-
std::vector<fixed_min_pair_heap<float, unsigned>>(
190-
size(q), fixed_min_pair_heap<float, unsigned>(k_nn)));
189+
std::vector<fixed_min_pair_heap<float, Index>>(
190+
size(q), fixed_min_pair_heap<float, Index>(k_nn)));
191191

192192
unsigned size_q = size(q);
193193
auto par = stdx::execution::indexed_parallel_policy{nthreads};
@@ -261,10 +261,10 @@ auto vq_query_heap_2(
261261
unsigned nthreads) {
262262
// @todo Need to get the total number of queries, not just the first block
263263
// @todo Use Matrix here rather than vector of vectors
264-
std::vector<std::vector<fixed_min_pair_heap<float, size_t>>> scores(
264+
std::vector<std::vector<fixed_min_pair_heap<float, Index>>> scores(
265265
nthreads,
266-
std::vector<fixed_min_pair_heap<float, size_t>>(
267-
size(q), fixed_min_pair_heap<float, size_t>(k_nn)));
266+
std::vector<fixed_min_pair_heap<float, Index>>(
267+
size(q), fixed_min_pair_heap<float, Index>(k_nn)));
268268

269269
unsigned size_q = size(q);
270270
auto par = stdx::execution::indexed_parallel_policy{nthreads};

0 commit comments

Comments
 (0)