Skip to content

Commit a0ce15d

Browse files
author
Nikos Papailiou
committed
Add automatic update array fragment consolidation and fix query corner cases
1 parent 2985579 commit a0ce15d

File tree

1 file changed

+41
-6
lines changed

1 file changed

+41
-6
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,43 @@ def __init__(
4040

4141
def query(self, queries: np.ndarray, k, **kwargs):
4242
updated_ids = set(self.read_updated_ids())
43-
internal_results_d, internal_results_i = self.query_internal(queries, k, **kwargs)
43+
retrieval_k = k
44+
if len(updated_ids) > 0:
45+
retrieval_k = 2*k
46+
internal_results_d, internal_results_i = self.query_internal(queries, retrieval_k, **kwargs)
4447
if self.update_arrays_uri is None:
45-
return internal_results_d, internal_results_i
48+
return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
4649

47-
addition_results_d, addition_results_i = self.query_additions(queries, k)
4850
# Filter updated vectors
4951
query_id = 0
5052
for query in internal_results_i:
5153
res_id = 0
5254
for res in query:
5355
if res in updated_ids:
5456
internal_results_d[query_id, res_id] = MAX_FLOAT_32
55-
internal_results_i[query_id, res_id] = 0
57+
internal_results_i[query_id, res_id] = MAX_UINT64
5658
res_id += 1
5759
query_id += 1
60+
sort_index = np.argsort(internal_results_d, axis=1)
61+
internal_results_d = np.take_along_axis(internal_results_d, sort_index, axis=1)
62+
internal_results_i = np.take_along_axis(internal_results_i, sort_index, axis=1)
63+
5864
# Merge update results
65+
addition_results_d, addition_results_i = self.query_additions(queries, k)
66+
if addition_results_d is None:
67+
return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
68+
69+
query_id = 0
70+
for query in addition_results_d:
71+
res_id = 0
72+
for res in query:
73+
if addition_results_d[query_id, res_id] == 0 and addition_results_i[query_id, res_id] == 0:
74+
addition_results_d[query_id, res_id] = MAX_FLOAT_32
75+
addition_results_i[query_id, res_id] = MAX_UINT64
76+
res_id += 1
77+
query_id += 1
78+
79+
5980
results_d = np.hstack((internal_results_d, addition_results_d))
6081
results_i = np.hstack((internal_results_i, addition_results_i))
6182
sort_index = np.argsort(results_d, axis=1)
@@ -69,6 +90,8 @@ def query_internal(self, queries: np.ndarray, k, **kwargs):
6990
def query_additions(self, queries: np.ndarray, k):
7091
assert queries.dtype == np.float32
7192
additions_vectors, additions_external_ids = self.read_additions()
93+
if additions_vectors is None:
94+
return None, None
7295
queries_m = array_to_matrix(np.transpose(queries))
7396
d, i = query_vq_heap_pyarray(
7497
array_to_matrix(np.transpose(additions_vectors).astype(self.dtype)),
@@ -82,16 +105,19 @@ def update(self, vector: np.array, external_id: np.uint64):
82105
updates_array = self.open_updates_array()
83106
updates_array[external_id] = vector
84107
updates_array.close()
108+
self.consolidate_update_fragments()
85109

86110
def update_batch(self, vectors: np.ndarray, external_ids: np.array):
87111
updates_array = self.open_updates_array()
88112
updates_array[external_ids] = {'vector': vectors}
89113
updates_array.close()
114+
self.consolidate_update_fragments()
90115

91116
def delete(self, external_id: np.uint64):
92117
updates_array = self.open_updates_array()
93118
updates_array[external_id] = np.array([], dtype=self.dtype)
94119
updates_array.close()
120+
self.consolidate_update_fragments()
95121

96122
def delete_batch(self, external_ids: np.array):
97123
updates_array = self.open_updates_array()
@@ -100,6 +126,13 @@ def delete_batch(self, external_ids: np.array):
100126
deletes[i] = np.array([], dtype=self.dtype)
101127
updates_array[external_ids] = {'vector': deletes}
102128
updates_array.close()
129+
self.consolidate_update_fragments()
130+
131+
def consolidate_update_fragments(self):
132+
fragments_info = tiledb.array_fragments(self.update_arrays_uri)
133+
if(len(fragments_info) > 10):
134+
tiledb.consolidate(self.update_arrays_uri)
135+
tiledb.vacuum(self.update_arrays_uri)
103136

104137
def get_updates_uri(self):
105138
return self.update_arrays_uri
@@ -111,8 +144,10 @@ def read_additions(self) -> (np.ndarray, np.array):
111144
q = updates_array.query(attrs=('vector',), coords=True)
112145
data = q[:]
113146
additions_filter = [len(item) > 0 for item in data["vector"]]
114-
return np.vstack(data["vector"][additions_filter]), data["external_id"][additions_filter]
115-
147+
if len(data["external_id"][additions_filter]) > 0:
148+
return np.vstack(data["vector"][additions_filter]), data["external_id"][additions_filter]
149+
else:
150+
return None, None
116151
def read_updated_ids(self) -> np.array:
117152
if self.update_arrays_uri is None:
118153
return np.array([], np.uint64)

0 commit comments

Comments
 (0)