Skip to content

Commit 6b118ef

Browse files
author
Nikos Papailiou
committed
Format
1 parent a39410b commit 6b118ef

File tree

1 file changed

+58
-40
lines changed

1 file changed

+58
-40
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 58 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class Index:
2121
config: Optional[Mapping[str, Any]]
2222
config dictionary, defaults to None
2323
"""
24+
2425
def __init__(
2526
self,
2627
uri: str,
@@ -43,20 +44,26 @@ def query(self, queries: np.ndarray, k, **kwargs):
4344
return self.query_internal(queries, k, **kwargs)
4445

4546
# Query with updates
46-
updated_ids = set(self.read_updated_ids())
47-
retrieval_k = k
48-
if len(updated_ids) > 0:
49-
retrieval_k = 2*k
50-
5147
# Perform the queries in parallel
52-
kwargs["nthreads"] = int(mp.cpu_count()/2)
48+
retrieval_k = 2 * k
49+
kwargs["nthreads"] = int(mp.cpu_count() / 2)
5350
parent_conn, child_conn = mp.Pipe()
5451
p = mp.Process(
5552
target=Index.query_additions,
56-
args=(child_conn, queries, k, self.dtype, self.update_arrays_uri, int(mp.cpu_count()/2)))
53+
args=(
54+
child_conn,
55+
queries,
56+
k,
57+
self.dtype,
58+
self.update_arrays_uri,
59+
int(mp.cpu_count() / 2),
60+
),
61+
)
5762
p.start()
58-
internal_results_d, internal_results_i = self.query_internal(queries, retrieval_k, **kwargs)
59-
addition_results_d, addition_results_i = parent_conn.recv()
63+
internal_results_d, internal_results_i = self.query_internal(
64+
queries, retrieval_k, **kwargs
65+
)
66+
addition_results_d, addition_results_i, updated_ids = parent_conn.recv()
6067
p.join()
6168

6269
# Filter updated vectors
@@ -81,13 +88,15 @@ def query(self, queries: np.ndarray, k, **kwargs):
8188
for query in addition_results_d:
8289
res_id = 0
8390
for res in query:
84-
if addition_results_d[query_id, res_id] == 0 and addition_results_i[query_id, res_id] == 0:
91+
if (
92+
addition_results_d[query_id, res_id] == 0
93+
and addition_results_i[query_id, res_id] == 0
94+
):
8595
addition_results_d[query_id, res_id] = MAX_FLOAT_32
8696
addition_results_i[query_id, res_id] = MAX_UINT64
8797
res_id += 1
8898
query_id += 1
8999

90-
91100
results_d = np.hstack((internal_results_d, addition_results_d))
92101
results_i = np.hstack((internal_results_i, addition_results_i))
93102
sort_index = np.argsort(results_d, axis=1)
@@ -96,95 +105,103 @@ def query(self, queries: np.ndarray, k, **kwargs):
96105
return results_d[:, 0:k], results_i[:, 0:k]
97106

98107
@staticmethod
def query_additions(
    conn, queries: np.ndarray, k, dtype, update_arrays_uri, nthreads=8
):
    """Child-process worker: brute-force query the staged additions.

    Reads pending additions from `update_arrays_uri`, queries them
    against `queries`, and sends exactly one 3-tuple
    `(distances, ids, updated_ids)` through `conn`, then closes it.
    When there are no staged additions, distances/ids are sent as None
    so the parent can still unpack three values.

    Parameters:
        conn: multiprocessing Pipe connection to send results through.
        queries: float32 query matrix (one query per row).
        k: number of neighbors to retrieve per query.
        dtype: vector dtype of the index.
        update_arrays_uri: URI of the updates array (may be None).
        nthreads: worker threads for the brute-force query.
    """
    assert queries.dtype == np.float32
    additions_vectors, additions_external_ids, updated_ids = Index.read_additions(
        update_arrays_uri
    )
    if additions_vectors is None:
        # BUG FIX: Connection.send() takes a single object; the original
        # `conn.send(None, None, updated_ids)` raised TypeError. Send one
        # 3-tuple so the parent's `a, b, c = parent_conn.recv()` unpack
        # works in this branch exactly as in the branch below.
        conn.send((None, None, updated_ids))
        conn.close()
        return

    queries_m = array_to_matrix(np.transpose(queries))
    d, i = query_vq_heap_pyarray(
        array_to_matrix(np.transpose(additions_vectors).astype(dtype)),
        queries_m,
        StdVector_u64(additions_external_ids),
        k,
        nthreads,
    )
    conn.send((np.transpose(np.array(d)), np.transpose(np.array(i)), updated_ids))
    conn.close()
113130

114131
@staticmethod
def read_additions(update_arrays_uri) -> tuple:
    """Read staged additions and the full set of updated ids.

    FIX: the previous annotation `-> (np.ndarray, np.array)` was a tuple
    literal (not a valid type annotation) and claimed two return values,
    while the function returns three.

    Parameters:
        update_arrays_uri: URI of the updates array, or None if no
            updates have ever been staged.

    Returns:
        A 3-tuple ``(vectors, external_ids, updated_ids)``:
        - vectors: np.ndarray of addition vectors, or None when there
          are no non-empty (i.e. non-deletion) staged vectors.
        - external_ids: ids matching `vectors`, or None likewise.
        - updated_ids: every external_id present in the updates array
          (empty uint64 array when `update_arrays_uri` is None).
    """
    if update_arrays_uri is None:
        return None, None, np.array([], np.uint64)
    updates_array = tiledb.open(update_arrays_uri, mode="r")
    q = updates_array.query(attrs=("vector",), coords=True)
    data = q[:]
    updates_array.close()
    updated_ids = data["external_id"]
    # Deletions are staged as empty vectors; keep only real additions.
    additions_filter = [len(item) > 0 for item in data["vector"]]
    if len(data["external_id"][additions_filter]) > 0:
        return (
            np.vstack(data["vector"][additions_filter]),
            data["external_id"][additions_filter],
            updated_ids,
        )
    else:
        return None, None, updated_ids
127149

128150
def query_internal(self, queries: np.ndarray, k, **kwargs):
    """Backend-specific query hook; concrete Index subclasses must override."""
    raise NotImplementedError
130152

131153
def update(self, vector: np.array, external_id: np.uint64):
    """Stage a single-vector update keyed by `external_id`.

    The vector is boxed into a 1-element object array because the
    updates array stores variable-length vectors.
    """
    arr = self.open_updates_array()
    boxed = np.empty(1, dtype="O")
    boxed[0] = vector
    arr[external_id] = {"vector": boxed}
    arr.close()
    self.consolidate_update_fragments()
138160

139161
def update_batch(self, vectors: np.ndarray, external_ids: np.array):
    """Stage a batch of vector updates, one per id in `external_ids`."""
    arr = self.open_updates_array()
    arr[external_ids] = {"vector": vectors}
    arr.close()
    self.consolidate_update_fragments()
144166

145167
def delete(self, external_id: np.uint64):
    """Stage a deletion: an empty vector marks `external_id` as removed."""
    arr = self.open_updates_array()
    tombstone = np.empty(1, dtype="O")
    tombstone[0] = np.array([], dtype=self.dtype)
    arr[external_id] = {"vector": tombstone}
    arr.close()
    self.consolidate_update_fragments()
152174

153175
def delete_batch(self, external_ids: np.array):
    """Stage deletions (empty vectors) for every id in `external_ids`."""
    arr = self.open_updates_array()
    tombstones = np.empty(len(external_ids), dtype="O")
    for idx, _ in enumerate(external_ids):
        tombstones[idx] = np.array([], dtype=self.dtype)
    arr[external_ids] = {"vector": tombstones}
    arr.close()
    self.consolidate_update_fragments()
161183

162184
def consolidate_update_fragments(self):
    """Consolidate and vacuum the updates array once it exceeds 10 fragments.

    Keeps read amplification bounded as many small update writes pile up.
    """
    frags = tiledb.array_fragments(self.update_arrays_uri)
    if len(frags) <= 10:
        return
    tiledb.consolidate(self.update_arrays_uri)
    tiledb.vacuum(self.update_arrays_uri)
167189

168190
def get_updates_uri(self):
    """Return the URI of the staged-updates array (None if never created)."""
    return self.update_arrays_uri
170192

171-
def read_updated_ids(self) -> np.array:
172-
if self.update_arrays_uri is None:
173-
return np.array([], np.uint64)
174-
updates_array = tiledb.open(self.update_arrays_uri, mode="r")
175-
q = updates_array.query(attrs=('vector',), coords=True)
176-
data = q[:]
177-
updates_array.close()
178-
return data["external_id"]
179-
180193
def open_updates_array(self):
181194
if self.update_arrays_uri is None:
182-
updates_array_name = storage_formats[self.storage_version]["UPDATES_ARRAY_NAME"]
195+
updates_array_name = storage_formats[self.storage_version][
196+
"UPDATES_ARRAY_NAME"
197+
]
183198
updates_array_uri = f"{self.group.uri}/{updates_array_name}"
184199
if tiledb.array_exists(updates_array_uri):
185200
raise RuntimeError(f"Array {updates_array_uri} already exists.")
186201
external_id_dim = tiledb.Dim(
187-
name="external_id", domain=(0, MAX_UINT64-1), dtype=np.dtype(np.uint64)
202+
name="external_id",
203+
domain=(0, MAX_UINT64 - 1),
204+
dtype=np.dtype(np.uint64),
188205
)
189206
dom = tiledb.Domain(external_id_dim)
190207
vector_attr = tiledb.Attr(name="vector", dtype=self.dtype, var=True)
@@ -205,13 +222,14 @@ def open_updates_array(self):
205222

206223
def consolidate_updates(self):
207224
from tiledb.vector_search.ingestion import ingest
225+
208226
new_index = ingest(
209227
index_type=self.index_type,
210228
index_uri=self.uri,
211229
size=self.size,
212230
source_uri=self.db_uri,
213231
external_ids_uri=self.ids_uri,
214-
updates_uri=self.update_arrays_uri
232+
updates_uri=self.update_arrays_uri,
215233
)
216234
tiledb.Array.delete_array(self.update_arrays_uri)
217235
self.group.close()

0 commit comments

Comments
 (0)